kashif HF Staff committed on
Commit
d8a5a92
·
verified ·
1 Parent(s): c555c2f

Use create_bidirectional_mask for backend-agnostic attention mask handling

Browse files
Files changed (1) hide show
  1. modeling_llada2_moe.py +6 -14
modeling_llada2_moe.py CHANGED
@@ -28,9 +28,7 @@ from torch.nn import CrossEntropyLoss
28
 
29
  from transformers.activations import ACT2FN
30
  from transformers.cache_utils import Cache, DynamicCache
31
- from transformers.modeling_attn_mask_utils import (
32
- _prepare_4d_causal_attention_mask_for_sdpa,
33
- )
34
  from transformers.modeling_outputs import (
35
  MoeModelOutputWithPast,
36
  MoeCausalLMOutputWithPast,
@@ -876,17 +874,11 @@ class LLaDA2MoeModel(LLaDA2MoePreTrainedModel):
876
  device=inputs_embeds.device,
877
  )
878
  position_ids = position_ids.unsqueeze(0)
879
- if attention_mask.size() == (batch_size, 1, seq_length, seq_length):
880
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
881
- attention_mask,
882
- (batch_size, seq_length),
883
- inputs_embeds,
884
- past_seen_tokens,
885
- )
886
- else:
887
- raise ValueError(
888
- f"LLaDA2.0 only support block attention mask with shape: {(batch_size, 1, seq_length, seq_length)}, the input attention with shape {attention_mask.size()=}!"
889
- )
890
  # embed positions
891
  hidden_states = inputs_embeds
892
 
 
28
 
29
  from transformers.activations import ACT2FN
30
  from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.masking_utils import create_bidirectional_mask
 
 
32
  from transformers.modeling_outputs import (
33
  MoeModelOutputWithPast,
34
  MoeCausalLMOutputWithPast,
 
874
  device=inputs_embeds.device,
875
  )
876
  position_ids = position_ids.unsqueeze(0)
877
+ attention_mask = create_bidirectional_mask(
878
+ config=self.config,
879
+ inputs_embeds=inputs_embeds,
880
+ attention_mask=attention_mask,
881
+ )
 
 
 
 
 
 
882
  # embed positions
883
  hidden_states = inputs_embeds
884