Fix _init_weights and RotaryEmbedding for transformers v5.x compatibility

#10
by apsys
Files changed (1)
  1. modeling_bailing_moe_v2.py +19 -7
modeling_bailing_moe_v2.py CHANGED
@@ -201,9 +201,21 @@ class BailingMoeV2RotaryEmbedding(nn.Module):
         self.original_max_seq_len = config.max_position_embeddings
 
         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        if self.rope_type == "default":
+            # ROPE_INIT_FUNCTIONS does not contain "default"; compute inv_freq
+            # directly. Use explicit float32 so the result is unaffected by a
+            # half-precision default dtype (torch_dtype=float16/bfloat16 via
+            # from_pretrained): with rope_theta=600000, base**x exceeds the fp16
+            # max (~65504) at large exponents, and bf16 loses mantissa precision.
+            base = float(getattr(config, "rope_theta", 10000.0))
+            dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+            if device is not None:
+                inv_freq = inv_freq.to(device)
+            self.attention_scaling = 1.0
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq
 
@@ -1060,13 +1072,13 @@ class BailingMoeV2PreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         std = self.config.initializer_range
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
-                module.bias.data.zero_()
+                torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                torch.nn.init.zeros_(module.weight[module.padding_idx])
 
 
 BAILINGMOEV2_INPUTS_DOCSTRING = r"""
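
Reviewer note: the dtype issue the new comment describes can be reproduced with a minimal check. This sketch is not part of the PR; base=600000.0 and dim=128 are assumed to match this repo's rope_theta and head_dim.

import torch

base, dim = 600000.0, 128  # assumed values for rope_theta and head_dim
exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim

inv_freq_fp32 = 1.0 / (base ** exponents)          # float32 reference
pow_fp16 = base ** exponents.to(torch.float16)     # 600000**x > 65504 -> inf
pow_bf16 = base ** exponents.to(torch.bfloat16)    # in range, but only 8 mantissa bits

print(torch.isinf(pow_fp16).sum().item())                       # largest exponents overflow in fp16
print((inv_freq_fp32 - (1.0 / pow_bf16).float()).abs().max())   # bf16 precision loss vs float32

Computing the exponents and the power in float32, as the patch does, sidesteps both failure modes regardless of the ambient default dtype.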
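On the _init_weights half: the new lines are numerically identical to the old ones. The difference is that the torch.nn.init functions run under torch.no_grad() and avoid the legacy .data attribute (discouraged since PyTorch 0.4), which is presumably what the transformers v5.x init path trips over. A quick sanity check, mine and with illustrative values:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4, padding_idx=0)
torch.nn.init.normal_(emb.weight, mean=0.0, std=0.02)
# Indexing a Parameter yields a view; zeros_ fills it in place under
# no_grad, the same pattern nn.Embedding.reset_parameters uses for
# padding_idx.
torch.nn.init.zeros_(emb.weight[emb.padding_idx])
assert emb.weight[0].abs().sum().item() == 0.0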