damerajee committed · Commit 0d335f8 · verified · 1 Parent(s): fe8a29e

Update modeling_Llamoe.py

Files changed (1):
  modeling_Llamoe.py +2 -3
modeling_Llamoe.py CHANGED
@@ -660,9 +660,8 @@ class LlamoeSdpaAttention(LlamoeAttention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        causal_mask = attention_mask
-        if attention_mask is not None and cache_position is not None:
-            causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
+
+        causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
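
For context, a minimal runnable sketch, not taken from the model file, of how a torch.tril mask of this shape can be passed to PyTorch's scaled_dot_product_attention; the shapes and tensor values here are illustrative assumptions. One subtlety: SDPA interprets a floating-point attn_mask as an additive bias, so the 0/1 mask produced by torch.tril is converted to boolean (True = may attend) before use.

import torch
import torch.nn.functional as F

# Illustrative shapes standing in for the diff's bsz, q_len, etc.
bsz, num_heads, q_len, head_dim = 2, 4, 8, 16
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# Mask as built in the new code: 1.0 on and below the diagonal.
causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
# SDPA treats a float mask as additive, so convert 0/1 -> bool
# (True = position may be attended to) and broadcast over heads.
attn_mask = causal_mask.bool().unsqueeze(1)  # (bsz, 1, q_len, q_len)

attn_output = F.scaled_dot_product_attention(
    query_states.contiguous(),  # .contiguous() works around the
    key_states.contiguous(),    # torch==2.1.2 custom-mask bug cited above
    value_states,
    attn_mask=attn_mask,
)
print(attn_output.shape)  # torch.Size([2, 4, 8, 16])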