Spaces:
Running on Zero
Running on Zero
Harry Coultas Blum commited on
Commit ·
00bf9af
1
Parent(s): fee1df4
Fix dtype mismatch for ZeroGPU SDPA patching
Browse files- vui/model.py +3 -1
vui/model.py
CHANGED
|
@@ -100,7 +100,7 @@ class MHA(nn.Module):
|
|
| 100 |
|
| 101 |
dropout_p = self.dropout if self.training else 0.0
|
| 102 |
|
| 103 |
-
qkv = self.Wqkv(x)
|
| 104 |
if self.n_heads == self.n_kv_heads:
|
| 105 |
qkv = rearrange(
|
| 106 |
qkv, "B T (three h d) -> B three h T d", three=3, h=self.n_heads
|
|
@@ -125,6 +125,8 @@ class MHA(nn.Module):
|
|
| 125 |
if self.kv_cache is not None:
|
| 126 |
k, v = self.kv_cache.update(input_pos, k, v)
|
| 127 |
|
|
|
|
|
|
|
| 128 |
if self.n_reps > 1:
|
| 129 |
k = repeat_kv(k, self.n_reps)
|
| 130 |
v = repeat_kv(v, self.n_reps)
|
|
|
|
| 100 |
|
| 101 |
dropout_p = self.dropout if self.training else 0.0
|
| 102 |
|
| 103 |
+
qkv = self.Wqkv(x).to(x.dtype)
|
| 104 |
if self.n_heads == self.n_kv_heads:
|
| 105 |
qkv = rearrange(
|
| 106 |
qkv, "B T (three h d) -> B three h T d", three=3, h=self.n_heads
|
|
|
|
| 125 |
if self.kv_cache is not None:
|
| 126 |
k, v = self.kv_cache.update(input_pos, k, v)
|
| 127 |
|
| 128 |
+
q, k, v = q.to(x.dtype), k.to(x.dtype), v.to(x.dtype)
|
| 129 |
+
|
| 130 |
if self.n_reps > 1:
|
| 131 |
k = repeat_kv(k, self.n_reps)
|
| 132 |
v = repeat_kv(v, self.n_reps)
|