Fix: Align mask preparation logic for Eager attention to prevent corrupted outputs

#4
Files changed (1)
  1. modeling_llada2_moe.py +4 -35
modeling_llada2_moe.py CHANGED
@@ -29,7 +29,6 @@ from torch.nn import CrossEntropyLoss
29
  from transformers.activations import ACT2FN
30
  from transformers.cache_utils import Cache, DynamicCache
31
  from transformers.modeling_attn_mask_utils import (
32
- _prepare_4d_causal_attention_mask,
33
  _prepare_4d_causal_attention_mask_for_sdpa,
34
  )
35
  from transformers.modeling_outputs import (
@@ -41,7 +40,6 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
41
  from transformers.processing_utils import Unpack
42
  from transformers.pytorch_utils import (
43
  ALL_LAYERNORM_LAYERS,
44
- is_torch_greater_or_equal_than_1_13,
45
  )
46
  from transformers.utils import (
47
  TransformersKwargs,
@@ -50,20 +48,10 @@ from transformers.utils import (
50
  logging,
51
  replace_return_docstrings,
52
  )
53
- from transformers.utils.import_utils import is_torch_fx_available
54
  from .configuration_llada2_moe import LLaDA2MoeConfig
55
  from transformers.generation.utils import GenerationMixin
56
 
57
 
58
- # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
59
- # It means that the function will not be traced through and simply appear as a node in the graph.
60
- if is_torch_fx_available():
61
- if not is_torch_greater_or_equal_than_1_13:
62
- import torch.fx
63
-
64
- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
65
-
66
-
67
  logger = logging.get_logger(__name__)
68
 
69
  _CONFIG_FOR_DOC = "LLaDA2MoeConfig"
@@ -403,9 +391,7 @@ def eager_attention_forward(
403
  key_states = repeat_kv(key, module.num_key_value_groups)
404
  value_states = repeat_kv(value, module.num_key_value_groups)
405
 
406
- attn_weights = (
407
- torch.matmul(query, key_states.transpose(2, 3)) * scaling
408
- )
409
  if attention_mask is not None:
410
  attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
411
 
@@ -877,18 +863,7 @@ class LLaDA2MoeModel(LLaDA2MoePreTrainedModel):
877
  device=inputs_embeds.device,
878
  )
879
  position_ids = position_ids.unsqueeze(0)
880
-
881
- if self._use_flex_attention:
882
- if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
883
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
884
- attention_mask,
885
- (batch_size, seq_length),
886
- inputs_embeds,
887
- past_seen_tokens,
888
- )
889
- elif self._use_sdpa and not output_attentions:
890
- # output_attentions=True can not be supported when using SDPA, and we fall back on
891
- # the manual implementation that requires a 4D causal mask in all cases.
892
  attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
893
  attention_mask,
894
  (batch_size, seq_length),
@@ -896,14 +871,9 @@ class LLaDA2MoeModel(LLaDA2MoePreTrainedModel):
896
  past_seen_tokens,
897
  )
898
  else:
899
- # 4d mask is passed through the layers
900
- attention_mask = _prepare_4d_causal_attention_mask(
901
- attention_mask,
902
- (batch_size, seq_length),
903
- inputs_embeds,
904
- past_seen_tokens,
905
  )
906
-
907
  # embed positions
908
  hidden_states = inputs_embeds
909
 
@@ -1431,4 +1401,3 @@ class LLaDA2MoeModelLM(LLaDA2MoePreTrainedModel, GenerationMixin):
1431
  return generated_answer[
1432
  :, input_ids.shape[1] : input_ids.shape[1] + first_mask_position + 1
1433
  ]
1434
-
 
29
  from transformers.activations import ACT2FN
30
  from transformers.cache_utils import Cache, DynamicCache
31
  from transformers.modeling_attn_mask_utils import (
 
32
  _prepare_4d_causal_attention_mask_for_sdpa,
33
  )
34
  from transformers.modeling_outputs import (
 
40
  from transformers.processing_utils import Unpack
41
  from transformers.pytorch_utils import (
42
  ALL_LAYERNORM_LAYERS,
 
43
  )
44
  from transformers.utils import (
45
  TransformersKwargs,
 
48
  logging,
49
  replace_return_docstrings,
50
  )
 
51
  from .configuration_llada2_moe import LLaDA2MoeConfig
52
  from transformers.generation.utils import GenerationMixin
53
 
54
 
 
 
 
 
 
 
 
 
 
55
  logger = logging.get_logger(__name__)
56
 
57
  _CONFIG_FOR_DOC = "LLaDA2MoeConfig"
 
391
  key_states = repeat_kv(key, module.num_key_value_groups)
392
  value_states = repeat_kv(value, module.num_key_value_groups)
393
 
394
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
 
 
395
  if attention_mask is not None:
396
  attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
397
 
 
863
  device=inputs_embeds.device,
864
  )
865
  position_ids = position_ids.unsqueeze(0)
866
+ if attention_mask.size() == (batch_size, 1, seq_length, seq_length):
 
 
 
 
 
 
 
 
 
 
 
867
  attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
868
  attention_mask,
869
  (batch_size, seq_length),
 
871
  past_seen_tokens,
872
  )
873
  else:
874
+ raise ValueError(
875
+ f"LLaDA2.0 only support block attention mask with shape: {(batch_size, 1, seq_length, seq_length)}, the input attention with shape {attention_mask.size()=}!"
 
 
 
 
876
  )
 
877
  # embed positions
878
  hidden_states = inputs_embeds
879
 
 
1401
  return generated_answer[
1402
  :, input_ids.shape[1] : input_ids.shape[1] + first_mask_position + 1
1403
  ]