Spaces:
Running on T4
Running on T4
fix: move k_lens to GPU in SDPA fallback (tested locally)
Browse files
- model_manager.py +1 -0
model_manager.py
CHANGED
@@ -130,6 +130,7 @@ class ModelManager:
  " attn_mask = None\n"
  " is_causal_flag = causal\n"
  " if k_lens is not None:\n"
+ " k_lens = k_lens.to(q.device)\n"
  " valid = torch.arange(lk, device=q.device).unsqueeze(0) < k_lens.unsqueeze(1)\n"
  " attn_mask = torch.where(valid[:, None, None, :], 0.0, float('-inf')).to(dtype=dtype)\n"
  " is_causal_flag = False\n"