Update modeling_Llamoe.py
Browse files
- modeling_Llamoe.py +1 -0
modeling_Llamoe.py
CHANGED
|
@@ -662,6 +662,7 @@ class LlamoeSdpaAttention(LlamoeAttention):
|
|
| 662 |
|
| 663 |
|
| 664 |
causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
|
|
|
|
| 665 |
|
| 666 |
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 667 |
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
|
|
|
| 662 |
|
| 663 |
|
| 664 |
causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
|
| 665 |
+ causal_mask = causal_mask.to(dtype=query_states.dtype)
|
| 666 |
|
| 667 |
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 668 |
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|