Fix _init_rope compatibility with transformers >= 5.x rope_scaling standardization
#10
by DennisHuang648 - opened
- modeling_minicpm_sala.py +29 -13
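For context on the review: the patch replaces the old `rope_scaling is None` check with type detection that accepts both the legacy `"type"` key and the standardized `"rope_type"` key. The helper below is an editor's sketch that mirrors that detection in isolation; the function name and example dicts are illustrative, not part of the diff.

```python
# Sketch of the detection logic this PR adds (names and example dicts are
# illustrative, not copied from the diff or any real checkpoint config).
def uses_default_rope(rope_scaling):
    """Return True when rope_scaling means "no scaling" under either convention."""
    scaling_type = None
    if isinstance(rope_scaling, dict):
        # Legacy configs use "type"; standardized transformers configs use "rope_type".
        scaling_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
    return rope_scaling is None or scaling_type in (None, "default")

print(uses_default_rope(None))                                     # True: legacy "no scaling"
print(uses_default_rope({"rope_type": "default", "factor": 1.0}))  # True: standardized fill-in
print(uses_default_rope({"type": "longrope", "factor": 1.0}))      # False: real scaling config
```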
modeling_minicpm_sala.py
CHANGED
@@ -877,15 +877,23 @@ class MiniCPMAttention(nn.Module):
         )

     def _init_rope(self):
-        if self.config.rope_scaling is None:
+        # transformers>=4.43 standardizes rope_scaling: a missing/None
+        # rope_scaling is auto-filled to {"rope_type": "default", "factor": 1.0}
+        # at config-load time. Treat both the original None case and the
+        # standardized "default" as no scaling so loading does not raise on
+        # newer transformers releases.
+        rope_scaling = self.config.rope_scaling
+        scaling_type = None
+        if isinstance(rope_scaling, dict):
+            scaling_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
+        if rope_scaling is None or scaling_type in (None, "default"):
             self.rotary_emb = MiniCPMRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 base=self.rope_theta,
             )
         else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling.get("factor", None)
+            scaling_factor = rope_scaling.get("factor", None)
             if scaling_type == "linear":
                 self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding(
                     self.head_dim,
@@ -904,10 +912,10 @@ class MiniCPMAttention(nn.Module):
             self.rotary_emb = MiniCPMLongRoPE(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
-                short_factor=self.config.rope_scaling["short_factor"],
-                long_factor=self.config.rope_scaling["long_factor"],
+                short_factor=rope_scaling["short_factor"],
+                long_factor=rope_scaling["long_factor"],
                 base=self.rope_theta,
-                original_max_position_embeddings=self.config.rope_scaling[
+                original_max_position_embeddings=rope_scaling[
                     "original_max_position_embeddings"
                 ],
             )
@@ -2142,15 +2150,23 @@ class LightningAttention(nn.Module):
         self._init_rope()

     def _init_rope(self):
-        if self.config.rope_scaling is None:
+        # transformers>=4.43 standardizes rope_scaling: a missing/None
+        # rope_scaling is auto-filled to {"rope_type": "default", "factor": 1.0}
+        # at config-load time. Treat both the original None case and the
+        # standardized "default" as no scaling so loading does not raise on
+        # newer transformers releases.
+        rope_scaling = self.config.rope_scaling
+        scaling_type = None
+        if isinstance(rope_scaling, dict):
+            scaling_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
+        if rope_scaling is None or scaling_type in (None, "default"):
             self.rotary_emb = MiniCPMRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.config.max_position_embeddings,
                 base=self.config.rope_theta,
             )
         else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling.get("factor", None)
+            scaling_factor = rope_scaling.get("factor", None)
             if scaling_type == "linear":
                 self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding(
                     self.head_dim,
@@ -2169,10 +2185,10 @@ class LightningAttention(nn.Module):
             self.rotary_emb = MiniCPMLongRoPE(
                 self.head_dim,
                 max_position_embeddings=self.config.max_position_embeddings,
-                short_factor=self.config.rope_scaling["short_factor"],
-                long_factor=self.config.rope_scaling["long_factor"],
+                short_factor=rope_scaling["short_factor"],
+                long_factor=rope_scaling["long_factor"],
                 base=self.config.rope_theta,
-                original_max_position_embeddings=self.config.rope_scaling[
+                original_max_position_embeddings=rope_scaling[
                     "original_max_position_embeddings"
                 ],
             )
@@ -3274,4 +3290,4 @@ class MiniCPMSALAForSequenceClassification(MiniCPMSALAPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
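A quick way to exercise the fix is to reload the config and model on a recent transformers release. The snippet below is a hypothetical smoke test, with `<repo-id>` standing in for the actual model repository; the auto class choice is an assumption, not taken from this PR.

```python
# Hypothetical smoke test for this patch; "<repo-id>" is a placeholder and
# trust_remote_code pulls in the patched modeling_minicpm_sala.py.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("<repo-id>", trust_remote_code=True)
print(config.rope_scaling)  # may be a standardized dict rather than None on newer releases

# Before the fix, _init_rope could fail here: the None check missed an
# auto-filled dict, and the old rope_scaling["type"] lookup raised KeyError
# when the standardized dict only carries "rope_type".
model = AutoModelForCausalLM.from_pretrained("<repo-id>", trust_remote_code=True)
```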