fix: align RotaryEmbedding and _init_weights with Qwen2Moe for transformers compat

#2
by kashif (HF Staff) - opened
Files changed (1)
  1. modeling_llada2_moe.py +32 -27
modeling_llada2_moe.py CHANGED
@@ -20,7 +20,7 @@

 import math
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -52,7 +52,6 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-from transformers.utils.import_utils import is_torch_fx_available
 from .configuration_llada2_moe import LLaDA2MoeConfig
 from transformers.generation.utils import GenerationMixin

@@ -62,13 +61,6 @@ if is_flash_attn_2_available():
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa


-# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
-# It means that the function will not be traced through and simply appear as a node in the graph.
-if is_torch_fx_available():
-    if not is_torch_greater_or_equal_than_1_13:
-        import torch.fx
-
-    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)


 logger = logging.get_logger(__name__)
@@ -111,24 +103,41 @@ ALL_LAYERNORM_LAYERS.append(LLaDA2MoeRMSNorm)


 class LLaDA2MoeRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor # fix linting for register_buffer
+
     def __init__(self, config: LLaDA2MoeConfig, device=None):
         super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get(
-                "rope_type", config.rope_scaling.get("type")
-            )
-        else:
-            self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.rope_type = self.config.rope_parameters["rope_type"]
+        rope_init_fn: Callable = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: LLaDA2MoeConfig = None,
+        device=None,
+        seq_len: int = None,
+    ):
+        base = config.rope_parameters["rope_theta"]
+        partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+        dim = int(head_dim * partial_rotary_factor)
+
+        attention_factor = 1.0 # Unused in this type of RoPE
+
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor

     @torch.no_grad()
     @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -1052,16 +1061,12 @@ class LLaDA2MoePreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True

+    @torch.no_grad()
     def _init_weights(self, module):
+        super()._init_weights(module)
         std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+        if isinstance(module, LLaDA2MoeGate):
+            nn.init.normal_(module.weight, mean=0.0, std=std)


 LLADA2MOE_INPUTS_DOCSTRING = r"""
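For a quick sanity check of the new default RoPE path, here is a minimal standalone sketch of the inverse-frequency computation that compute_default_rope_parameters performs. The head_dim, rope_theta, and partial_rotary_factor values below are illustrative placeholders, not values taken from the actual LLaDA2Moe config.

import torch

# Assumed example values; in the model they come from config.rope_parameters
# ("rope_theta", "partial_rotary_factor") and config.head_dim / hidden_size.
head_dim = 128
rope_theta = 10000.0
partial_rotary_factor = 1.0

dim = int(head_dim * partial_rotary_factor)

# Same formula as compute_default_rope_parameters in the diff:
# inv_freq[i] = 1 / rope_theta ** (2i / dim)
inv_freq = 1.0 / (
    rope_theta ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
)

print(inv_freq.shape)             # torch.Size([64]) for head_dim=128
print(inv_freq[0], inv_freq[-1])  # 1.0 down to roughly 1 / rope_theta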