damerajee committed
Commit e09acd1 · verified · 1 Parent(s): cce8817

Update modeling_Llamoe.py

Files changed (1)
  1. modeling_Llamoe.py +1 -1
modeling_Llamoe.py CHANGED
@@ -662,7 +662,7 @@ class LlamoeSdpaAttention(LlamoeAttention):
 
         causal_mask = attention_mask
         if attention_mask is not None and cache_position is not None:
-            causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
+            causal_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device))
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
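
For context, here is a minimal, self-contained sketch (not the model's code) of how a `torch.tril` causal mask like the one added above feeds into `scaled_dot_product_attention`. The tensor names mirror the diff; the cast to `bool` is an assumption the diff itself does not make, noted because SDPA interprets float and boolean masks differently:

```python
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 2, 4, 8, 16
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# torch.tril keeps the lower triangle, so query position i may only
# attend to key positions 0..i. SDPA treats a *boolean* mask as
# "True = may attend", while a *float* mask (what torch.ones alone
# produces, as in the diff) is added to the attention logits instead.
causal_mask = torch.tril(
    torch.ones((q_len, q_len), dtype=torch.bool, device=query_states.device)
)

attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=causal_mask,  # 2D mask broadcasts over batch and heads
)
print(attn_output.shape)  # torch.Size([2, 4, 8, 16])
```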
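The comment preserved at the end of the hunk refers to a known PyTorch 2.1.2 bug (pytorch/pytorch#112577) with non-contiguous inputs and a custom `attn_mask`. A sketch of the usual mitigation in SDPA attention code, assuming the same variable names as the diff, is to force the tensors contiguous on CUDA when a custom mask is present:

```python
# Sketch of the common workaround for pytorch/pytorch#112577: with a
# custom attn_mask on CUDA, non-contiguous q/k/v can trigger the SDPA
# bug, so they are made contiguous before the attention call.
# Variable names are assumptions carried over from the diff above.
if query_states.device.type == "cuda" and causal_mask is not None:
    query_states = query_states.contiguous()
    key_states = key_states.contiguous()
    value_states = value_states.contiguous()
```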