damerajee commited on
Commit
475027c
·
verified ·
1 Parent(s): 0d335f8

Update modeling_Llamoe.py

Browse files
Files changed (1) hide show
  1. modeling_Llamoe.py +1 -0
modeling_Llamoe.py CHANGED
@@ -662,6 +662,7 @@ class LlamoeSdpaAttention(LlamoeAttention):
662
 
663
 
664
  causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
 
665
 
666
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
667
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
 
662
 
663
 
664
  causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
665
+ causal_mask = causal_mask.to(dtype=query_states.dtype)
666
 
667
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
668
  # Reference: https://github.com/pytorch/pytorch/issues/112577.