Fix _init_rope compatibility with transformers >= 5.x rope_scaling standardization

#10
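Context: on transformers >= 4.43 the config loader standardizes rope_scaling, so a config that previously loaded with rope_scaling = None now arrives as {"rope_type": "default", "factor": 1.0}. The old _init_rope then took the scaled branch and failed on the unrecognized "default" type. A hypothetical repro sketch (the model id is a placeholder for any checkpoint that ships this modeling file):

from transformers import AutoConfig, AutoModelForCausalLM

# The loader auto-fills rope_scaling even when the checkpoint's config.json
# leaves it unset / null.
cfg = AutoConfig.from_pretrained("org/minicpm-sala-checkpoint", trust_remote_code=True)
print(cfg.rope_scaling)  # {"rope_type": "default", "factor": 1.0}

# Before this patch, _init_rope saw a non-None rope_scaling, entered the
# scaling branch, and raised while constructing the rotary embedding.
model = AutoModelForCausalLM.from_pretrained(
    "org/minicpm-sala-checkpoint", trust_remote_code=True
)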
Files changed (1)
  1. modeling_minicpm_sala.py +29 -13
modeling_minicpm_sala.py CHANGED
@@ -877,15 +877,23 @@ class MiniCPMAttention(nn.Module):
         )
 
     def _init_rope(self):
-        if self.config.rope_scaling is None:
+        # transformers>=4.43 standardizes rope_scaling: a missing/None
+        # rope_scaling is auto-filled to {"rope_type": "default", "factor": 1.0}
+        # at config-load time. Treat both the original None case and the
+        # standardized "default" as no scaling so loading does not raise on
+        # newer transformers releases.
+        rope_scaling = self.config.rope_scaling
+        scaling_type = None
+        if isinstance(rope_scaling, dict):
+            scaling_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
+        if rope_scaling is None or scaling_type in (None, "default"):
             self.rotary_emb = MiniCPMRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 base=self.rope_theta,
             )
         else:
-            scaling_type = self.config.rope_scaling["rope_type"]
-            scaling_factor = self.config.rope_scaling.get("factor", None)
+            scaling_factor = rope_scaling.get("factor", None)
             if scaling_type == "linear":
                 self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding(
                     self.head_dim,
@@ -904,10 +912,10 @@ class MiniCPMAttention(nn.Module):
             self.rotary_emb = MiniCPMLongRoPE(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
-                short_factor=self.config.rope_scaling["short_factor"],
-                long_factor=self.config.rope_scaling["long_factor"],
+                short_factor=rope_scaling["short_factor"],
+                long_factor=rope_scaling["long_factor"],
                 base=self.rope_theta,
-                original_max_position_embeddings=self.config.rope_scaling[
+                original_max_position_embeddings=rope_scaling[
                     "original_max_position_embeddings"
                 ],
             )
@@ -2142,15 +2150,23 @@ class LightningAttention(nn.Module):
         self._init_rope()
 
     def _init_rope(self):
-        if self.config.rope_scaling is None:
+        # transformers>=4.43 standardizes rope_scaling: a missing/None
+        # rope_scaling is auto-filled to {"rope_type": "default", "factor": 1.0}
+        # at config-load time. Treat both the original None case and the
+        # standardized "default" as no scaling so loading does not raise on
+        # newer transformers releases.
+        rope_scaling = self.config.rope_scaling
+        scaling_type = None
+        if isinstance(rope_scaling, dict):
+            scaling_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
+        if rope_scaling is None or scaling_type in (None, "default"):
             self.rotary_emb = MiniCPMRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.config.max_position_embeddings,
                 base=self.config.rope_theta,
             )
         else:
-            scaling_type = self.config.rope_scaling["rope_type"]
-            scaling_factor = self.config.rope_scaling.get("factor", None)
+            scaling_factor = rope_scaling.get("factor", None)
             if scaling_type == "linear":
                 self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding(
                     self.head_dim,
@@ -2169,10 +2185,10 @@ class LightningAttention(nn.Module):
             self.rotary_emb = MiniCPMLongRoPE(
                 self.head_dim,
                 max_position_embeddings=self.config.max_position_embeddings,
-                short_factor=self.config.rope_scaling["short_factor"],
-                long_factor=self.config.rope_scaling["long_factor"],
+                short_factor=rope_scaling["short_factor"],
+                long_factor=rope_scaling["long_factor"],
                 base=self.config.rope_theta,
-                original_max_position_embeddings=self.config.rope_scaling[
+                original_max_position_embeddings=rope_scaling[
                     "original_max_position_embeddings"
                 ],
             )
@@ -3274,4 +3290,4 @@ class MiniCPMSALAForSequenceClassification(MiniCPMSALAPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
\ No newline at end of file
+        )
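Reviewer note: a minimal, standalone sketch of the guard both _init_rope methods now share (resolve_rope_scaling_type is a hypothetical name; the patch inlines this logic rather than factoring it out):

def resolve_rope_scaling_type(rope_scaling):
    """Return None when no real scaling is configured, else the scaling type."""
    scaling_type = None
    if isinstance(rope_scaling, dict):
        # transformers < 4.43 configs used the "type" key; >= 4.43 standardizes
        # on "rope_type", so accept either spelling.
        scaling_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
    if rope_scaling is None or scaling_type in (None, "default"):
        return None  # plain MiniCPMRotaryEmbedding, no scaling
    return scaling_type

# Legacy config with no scaling configured at all.
assert resolve_rope_scaling_type(None) is None
# The dict transformers >= 4.43 auto-fills; previously this fell through to
# the scaled branch and raised.
assert resolve_rope_scaling_type({"rope_type": "default", "factor": 1.0}) is None
# Genuine scaling configs still dispatch, whichever key spelling they use.
assert resolve_rope_scaling_type({"type": "linear", "factor": 2.0}) == "linear"
assert resolve_rope_scaling_type({"rope_type": "longrope"}) == "longrope"

Accepting both "type" and "rope_type" keeps older checkpoints, whose config.json still carries the legacy key, loading unchanged.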