katuni4ka commited on
Commit
7ad5482
·
verified ·
1 Parent(s): 19f906c

Upload modeling_baichuan.py

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +3 -3
modeling_baichuan.py CHANGED
@@ -405,10 +405,10 @@ class BaichuanModel(BaichuanPreTrainedModel):
405
 
406
  if attention_mask is not None:
407
  if len(attention_mask.shape) == 2:
408
- expanded_mask = attention_mask.to(alibi_mask.dtype)
409
  expanded_mask = torch.tril(
410
- torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
411
- ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
412
  else:
413
  expanded_mask = attention_mask
414
  bsz = inputs_embeds.size(0)
 
405
 
406
  if attention_mask is not None:
407
  if len(attention_mask.shape) == 2:
408
+ expanded_mask = attention_mask.to(torch.float32)
409
  expanded_mask = torch.tril(
410
+ torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0).to(torch.float32)
411
+ ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0).to(torch.float32)
412
  else:
413
  expanded_mask = attention_mask
414
  bsz = inputs_embeds.size(0)