Update modeling_Llamoe.py
modeling_Llamoe.py (+1 -1)
```diff
@@ -662,7 +662,7 @@ class LlamoeSdpaAttention(LlamoeAttention):
 
         causal_mask = attention_mask
         if attention_mask is not None and cache_position is not None:
-            causal_mask =
+            causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
```
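For context, the added line completes the previously truncated assignment by building a dense lower-triangular causal mask with `torch.tril`. A minimal sketch of what it produces, assuming `bsz` and `q_len` are the batch size and query length defined earlier in the attention forward pass (the values below are hypothetical):

```python
import torch

bsz, q_len = 2, 4  # hypothetical batch size and query length

# Lower-triangular batch of ones: query position i may attend
# to key positions j <= i, i.e. no attention to future tokens.
causal_mask = torch.tril(torch.ones((bsz, q_len, q_len)))
print(causal_mask[0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```

One caveat on this form: `torch.nn.functional.scaled_dot_product_attention` treats a floating-point `attn_mask` additively (it is added to the attention scores), so a 0/1 float mask like this does not actually block future positions; the usual choice is a boolean mask or a float mask of zeros and large negative values.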