Update modeling_Llamoe.py
modeling_Llamoe.py CHANGED (+2 / -3)
```diff
@@ -660,9 +660,8 @@ class LlamoeSdpaAttention(LlamoeAttention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-
-
-        causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
+
+        causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
```
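For context, the hunk's two context lines call `repeat_kv`, which expands grouped key/value heads so that grouped-query attention tensors match the query head count before attention is computed. The sketch below follows the standard Llama-style implementation of this helper in `transformers`; it is illustrative and not copied from this repository.

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq_len, head_dim) to
    (batch, num_kv_heads * n_rep, seq_len, head_dim); equivalent to
    torch.repeat_interleave(hidden_states, n_rep, dim=1)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a broadcast dimension and expand (no data copy), then flatten
    # the (num_kv_heads, n_rep) pair back into a single head axis.
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
```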
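The trailing comment in the hunk points at a torch==2.1.2 bug with non-contiguous inputs and a custom `attn_mask`. Assuming the new `causal_mask` feeds `torch.nn.functional.scaled_dot_product_attention` (the enclosing class is `LlamoeSdpaAttention`), a minimal sketch of how it would typically be consumed, including the usual contiguity workaround, is below. One caveat worth noting: SDPA *adds* a float mask to the attention scores, so the 0/1 mask that `torch.tril` produces only masks correctly once cast to `bool` (True = attend). All shapes here are hypothetical.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes, for illustration only.
bsz, num_heads, q_len, head_dim = 2, 8, 16, 64
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# Same construction as in the diff: lower-triangular 0/1 float mask.
causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
# Cast to bool (True = attend) and add a head dimension so the mask
# broadcasts over num_heads; a 0/1 float mask would not mask anything.
attn_mask = causal_mask.bool()[:, None, :, :]

# Workaround for https://github.com/pytorch/pytorch/issues/112577:
# torch==2.1.2's memory-efficient backend mishandles non-contiguous
# q/k/v on CUDA when a custom attn_mask is supplied.
if query_states.device.type == "cuda" and attn_mask is not None:
    query_states = query_states.contiguous()
    key_states = key_states.contiguous()
    value_states = value_states.contiguous()

attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=attn_mask
)
print(attn_output.shape)  # torch.Size([2, 8, 16, 64])
```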