Fix _init_weights and RotaryEmbedding for transformers v5.x compatibility
#10
by apsys — opened
- modeling_bailing_moe_v2.py +19 -7
modeling_bailing_moe_v2.py
CHANGED
|
@@ -201,9 +201,21 @@ class BailingMoeV2RotaryEmbedding(nn.Module):
|
|
| 201 |
self.original_max_seq_len = config.max_position_embeddings
|
| 202 |
|
| 203 |
self.config = config
|
| 204 |
-
self.
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 208 |
self.original_inv_freq = self.inv_freq
|
| 209 |
|
|
@@ -1060,13 +1072,13 @@ class BailingMoeV2PreTrainedModel(PreTrainedModel):
|
|
| 1060 |
def _init_weights(self, module):
|
| 1061 |
std = self.config.initializer_range
|
| 1062 |
if isinstance(module, nn.Linear):
|
| 1063 |
-
|
| 1064 |
if module.bias is not None:
|
| 1065 |
-
|
| 1066 |
elif isinstance(module, nn.Embedding):
|
| 1067 |
-
|
| 1068 |
if module.padding_idx is not None:
|
| 1069 |
-
module.weight
|
| 1070 |
|
| 1071 |
|
| 1072 |
BAILINGMOEV2_INPUTS_DOCSTRING = r"""
|
|
|
|
| 201 |
self.original_max_seq_len = config.max_position_embeddings
|
| 202 |
|
| 203 |
self.config = config
|
| 204 |
+
if self.rope_type == "default":
|
| 205 |
+
# ROPE_INIT_FUNCTIONS does not contain "default"; compute inv_freq
|
| 206 |
+
# directly. Use explicit float32 dtype to prevent bf16 overflow
|
| 207 |
+
# when the model is instantiated under torch_dtype=bfloat16 (e.g.
|
| 208 |
+
# via from_pretrained). With rope_theta=600000, base**x overflows
|
| 209 |
+
# bf16 (max ~65504) for most frequency bands.
|
| 210 |
+
base = float(getattr(config, "rope_theta", 10000.0))
|
| 211 |
+
dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 212 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
| 213 |
+
if device is not None:
|
| 214 |
+
inv_freq = inv_freq.to(device)
|
| 215 |
+
self.attention_scaling = 1.0
|
| 216 |
+
else:
|
| 217 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 218 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 219 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 220 |
self.original_inv_freq = self.inv_freq
|
| 221 |
|
|
|
|
| 1072 |
def _init_weights(self, module):
|
| 1073 |
std = self.config.initializer_range
|
| 1074 |
if isinstance(module, nn.Linear):
|
| 1075 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 1076 |
if module.bias is not None:
|
| 1077 |
+
torch.nn.init.zeros_(module.bias)
|
| 1078 |
elif isinstance(module, nn.Embedding):
|
| 1079 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 1080 |
if module.padding_idx is not None:
|
| 1081 |
+
torch.nn.init.zeros_(module.weight[module.padding_idx])
|
| 1082 |
|
| 1083 |
|
| 1084 |
BAILINGMOEV2_INPUTS_DOCSTRING = r"""
|