Fix: Align mask preparation logic for Eager attention to prevent corrupted outputs
#4
by haibo8 — opened
- modeling_llada2_moe.py +4 -35
modeling_llada2_moe.py
CHANGED
|
@@ -29,7 +29,6 @@ from torch.nn import CrossEntropyLoss
|
|
| 29 |
from transformers.activations import ACT2FN
|
| 30 |
from transformers.cache_utils import Cache, DynamicCache
|
| 31 |
from transformers.modeling_attn_mask_utils import (
|
| 32 |
-
_prepare_4d_causal_attention_mask,
|
| 33 |
_prepare_4d_causal_attention_mask_for_sdpa,
|
| 34 |
)
|
| 35 |
from transformers.modeling_outputs import (
|
|
@@ -41,7 +40,6 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
| 41 |
from transformers.processing_utils import Unpack
|
| 42 |
from transformers.pytorch_utils import (
|
| 43 |
ALL_LAYERNORM_LAYERS,
|
| 44 |
-
is_torch_greater_or_equal_than_1_13,
|
| 45 |
)
|
| 46 |
from transformers.utils import (
|
| 47 |
TransformersKwargs,
|
|
@@ -50,20 +48,10 @@ from transformers.utils import (
|
|
| 50 |
logging,
|
| 51 |
replace_return_docstrings,
|
| 52 |
)
|
| 53 |
-
from transformers.utils.import_utils import is_torch_fx_available
|
| 54 |
from .configuration_llada2_moe import LLaDA2MoeConfig
|
| 55 |
from transformers.generation.utils import GenerationMixin
|
| 56 |
|
| 57 |
|
| 58 |
-
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
| 59 |
-
# It means that the function will not be traced through and simply appear as a node in the graph.
|
| 60 |
-
if is_torch_fx_available():
|
| 61 |
-
if not is_torch_greater_or_equal_than_1_13:
|
| 62 |
-
import torch.fx
|
| 63 |
-
|
| 64 |
-
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
| 65 |
-
|
| 66 |
-
|
| 67 |
logger = logging.get_logger(__name__)
|
| 68 |
|
| 69 |
_CONFIG_FOR_DOC = "LLaDA2MoeConfig"
|
|
@@ -403,9 +391,7 @@ def eager_attention_forward(
|
|
| 403 |
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 404 |
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 405 |
|
| 406 |
-
attn_weights = (
|
| 407 |
-
torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 408 |
-
)
|
| 409 |
if attention_mask is not None:
|
| 410 |
attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
|
| 411 |
|
|
@@ -877,18 +863,7 @@ class LLaDA2MoeModel(LLaDA2MoePreTrainedModel):
|
|
| 877 |
device=inputs_embeds.device,
|
| 878 |
)
|
| 879 |
position_ids = position_ids.unsqueeze(0)
|
| 880 |
-
|
| 881 |
-
if self._use_flex_attention:
|
| 882 |
-
if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
|
| 883 |
-
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
| 884 |
-
attention_mask,
|
| 885 |
-
(batch_size, seq_length),
|
| 886 |
-
inputs_embeds,
|
| 887 |
-
past_seen_tokens,
|
| 888 |
-
)
|
| 889 |
-
elif self._use_sdpa and not output_attentions:
|
| 890 |
-
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
| 891 |
-
# the manual implementation that requires a 4D causal mask in all cases.
|
| 892 |
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
| 893 |
attention_mask,
|
| 894 |
(batch_size, seq_length),
|
|
@@ -896,14 +871,9 @@ class LLaDA2MoeModel(LLaDA2MoePreTrainedModel):
|
|
| 896 |
past_seen_tokens,
|
| 897 |
)
|
| 898 |
else:
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
attention_mask,
|
| 902 |
-
(batch_size, seq_length),
|
| 903 |
-
inputs_embeds,
|
| 904 |
-
past_seen_tokens,
|
| 905 |
)
|
| 906 |
-
|
| 907 |
# embed positions
|
| 908 |
hidden_states = inputs_embeds
|
| 909 |
|
|
@@ -1431,4 +1401,3 @@ class LLaDA2MoeModelLM(LLaDA2MoePreTrainedModel, GenerationMixin):
|
|
| 1431 |
return generated_answer[
|
| 1432 |
:, input_ids.shape[1] : input_ids.shape[1] + first_mask_position + 1
|
| 1433 |
]
|
| 1434 |
-
|
|
|
|
| 29 |
from transformers.activations import ACT2FN
|
| 30 |
from transformers.cache_utils import Cache, DynamicCache
|
| 31 |
from transformers.modeling_attn_mask_utils import (
|
|
|
|
| 32 |
_prepare_4d_causal_attention_mask_for_sdpa,
|
| 33 |
)
|
| 34 |
from transformers.modeling_outputs import (
|
|
|
|
| 40 |
from transformers.processing_utils import Unpack
|
| 41 |
from transformers.pytorch_utils import (
|
| 42 |
ALL_LAYERNORM_LAYERS,
|
|
|
|
| 43 |
)
|
| 44 |
from transformers.utils import (
|
| 45 |
TransformersKwargs,
|
|
|
|
| 48 |
logging,
|
| 49 |
replace_return_docstrings,
|
| 50 |
)
|
|
|
|
| 51 |
from .configuration_llada2_moe import LLaDA2MoeConfig
|
| 52 |
from transformers.generation.utils import GenerationMixin
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
logger = logging.get_logger(__name__)
|
| 56 |
|
| 57 |
_CONFIG_FOR_DOC = "LLaDA2MoeConfig"
|
|
|
|
| 391 |
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 392 |
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 393 |
|
| 394 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
|
|
|
|
|
|
| 395 |
if attention_mask is not None:
|
| 396 |
attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
|
| 397 |
|
|
|
|
| 863 |
device=inputs_embeds.device,
|
| 864 |
)
|
| 865 |
position_ids = position_ids.unsqueeze(0)
|
| 866 |
+
if attention_mask.size() == (batch_size, 1, seq_length, seq_length):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 867 |
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
| 868 |
attention_mask,
|
| 869 |
(batch_size, seq_length),
|
|
|
|
| 871 |
past_seen_tokens,
|
| 872 |
)
|
| 873 |
else:
|
| 874 |
+
raise ValueError(
|
| 875 |
+
f"LLaDA2.0 only support block attention mask with shape: {(batch_size, 1, seq_length, seq_length)}, the input attention with shape {attention_mask.size()=}!"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 876 |
)
|
|
|
|
| 877 |
# embed positions
|
| 878 |
hidden_states = inputs_embeds
|
| 879 |
|
|
|
|
| 1401 |
return generated_answer[
|
| 1402 |
:, input_ids.shape[1] : input_ids.shape[1] + first_mask_position + 1
|
| 1403 |
]
|
|
|