macto committed on
Commit
d4cd5f8
·
verified ·
1 Parent(s): 035a83a

Update modeling_qwen.py (remove the bfloat16 cast of q, k, v in FlashSelfAttention.forward)

Browse files
Files changed (1) hide show
  1. modeling_qwen.py +0 -1
modeling_qwen.py CHANGED
@@ -180,7 +180,6 @@ class FlashSelfAttention(torch.nn.Module):
180
  return rearrange(output, '(b s) ... -> b s ...', b=batch)
181
 
182
  def forward(self, q, k, v, attention_mask=None):
183
- q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(to(torch.bfloat16))
184
  assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
185
  assert all((i.is_cuda for i in (q, k, v)))
186
  batch_size, seqlen_q = q.shape[0], q.shape[1]
 
180
  return rearrange(output, '(b s) ... -> b s ...', b=batch)
181
 
182
  def forward(self, q, k, v, attention_mask=None):
 
183
  assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
184
  assert all((i.is_cuda for i in (q, k, v)))
185
  batch_size, seqlen_q = q.shape[0], q.shape[1]