H-Liu1997 committed on
Commit
7d73321
·
1 Parent(s): 6c6483b

fix: move k_lens to GPU in SDPA fallback (tested locally)

Browse files
Files changed (1) hide show
  1. model_manager.py +1 -0
model_manager.py CHANGED
@@ -130,6 +130,7 @@ class ModelManager:
130
  " attn_mask = None\n"
131
  " is_causal_flag = causal\n"
132
  " if k_lens is not None:\n"
 
133
  " valid = torch.arange(lk, device=q.device).unsqueeze(0) < k_lens.unsqueeze(1)\n"
134
  " attn_mask = torch.where(valid[:, None, None, :], 0.0, float('-inf')).to(dtype=dtype)\n"
135
  " is_causal_flag = False\n"
 
130
  " attn_mask = None\n"
131
  " is_causal_flag = causal\n"
132
  " if k_lens is not None:\n"
133
+ " k_lens = k_lens.to(q.device)\n"
134
  " valid = torch.arange(lk, device=q.device).unsqueeze(0) < k_lens.unsqueeze(1)\n"
135
  " attn_mask = torch.where(valid[:, None, None, :], 0.0, float('-inf')).to(dtype=dtype)\n"
136
  " is_causal_flag = False\n"