Harry Coultas Blum commited on
Commit
00bf9af
·
1 Parent(s): fee1df4

Fix dtype mismatch for ZeroGPU SDPA patching

Browse files
Files changed (1) hide show
  1. vui/model.py +3 -1
vui/model.py CHANGED
@@ -100,7 +100,7 @@ class MHA(nn.Module):
100
 
101
  dropout_p = self.dropout if self.training else 0.0
102
 
103
- qkv = self.Wqkv(x)
104
  if self.n_heads == self.n_kv_heads:
105
  qkv = rearrange(
106
  qkv, "B T (three h d) -> B three h T d", three=3, h=self.n_heads
@@ -125,6 +125,8 @@ class MHA(nn.Module):
125
  if self.kv_cache is not None:
126
  k, v = self.kv_cache.update(input_pos, k, v)
127
 
 
 
128
  if self.n_reps > 1:
129
  k = repeat_kv(k, self.n_reps)
130
  v = repeat_kv(v, self.n_reps)
 
100
 
101
  dropout_p = self.dropout if self.training else 0.0
102
 
103
+ qkv = self.Wqkv(x).to(x.dtype)
104
  if self.n_heads == self.n_kv_heads:
105
  qkv = rearrange(
106
  qkv, "B T (three h d) -> B three h T d", three=3, h=self.n_heads
 
125
  if self.kv_cache is not None:
126
  k, v = self.kv_cache.update(input_pos, k, v)
127
 
128
+ q, k, v = q.to(x.dtype), k.to(x.dtype), v.to(x.dtype)
129
+
130
  if self.n_reps > 1:
131
  k = repeat_kv(k, self.n_reps)
132
  v = repeat_kv(v, self.n_reps)