Port to a newer version of the transformers package

#7
Files changed (2) hide show
  1. configuration_slimmoe.py +1 -1
  2. modeling_slimmoe.py +12 -6
configuration_slimmoe.py CHANGED
@@ -111,7 +111,7 @@ class PhiMoEConfig(PretrainedConfig):
111
  >>> configuration = model.config
112
  ```"""
113
 
114
- model_type = "phimoe"
115
  keys_to_ignore_at_inference = ["past_key_values"]
116
 
117
  def __init__(
 
111
  >>> configuration = model.config
112
  ```"""
113
 
114
+ model_type = "phimoe_slim" # renamed from "phimoe" so that transformers >=4.46 does not apply its built-in "phimoe" conversion_mapping to this custom remote-code implementation
115
  keys_to_ignore_at_inference = ["past_key_values"]
116
 
117
  def __init__(
modeling_slimmoe.py CHANGED
@@ -46,7 +46,6 @@ from transformers.utils import (
46
  logging,
47
  replace_return_docstrings,
48
  )
49
- from transformers.utils.import_utils import is_torch_fx_available
50
  from .configuration_slimmoe import PhiMoEConfig
51
 
52
  from einops import rearrange
@@ -61,11 +60,10 @@ if is_flash_attn_2_available():
61
 
62
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
63
  # It means that the function will not be traced through and simply appear as a node in the graph.
64
- if is_torch_fx_available():
65
- if not is_torch_greater_or_equal_than_1_13:
66
- import torch.fx
67
 
68
- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
69
 
70
 
71
  logger = logging.get_logger(__name__)
@@ -332,9 +330,17 @@ class PhiMoEAttention(nn.Module):
332
  base=self.rope_theta,
333
  )
334
  else:
335
- scaling_type = self.config.rope_scaling["type"]
 
336
  if scaling_type == "longrope":
337
  self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
 
 
 
 
 
 
 
338
  else:
339
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
340
 
 
46
  logging,
47
  replace_return_docstrings,
48
  )
 
49
  from .configuration_slimmoe import PhiMoEConfig
50
 
51
  from einops import rearrange
 
60
 
61
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
62
  # It means that the function will not be traced through and simply appear as a node in the graph.
63
+ if not is_torch_greater_or_equal_than_1_13:
64
+ import torch.fx
 
65
 
66
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
67
 
68
 
69
  logger = logging.get_logger(__name__)
 
330
  base=self.rope_theta,
331
  )
332
  else:
333
+ # "type" key was renamed to "rope_type" in transformers >=4.46; handle both
334
+ scaling_type = self.config.rope_scaling.get("type") or self.config.rope_scaling.get("rope_type", "")
335
  if scaling_type == "longrope":
336
  self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
337
+ elif not scaling_type or scaling_type == "default":
338
+ # newer transformers injects {"rope_type": "default"} when rope_scaling is absent
339
+ self.rotary_emb = PhiMoERotaryEmbedding(
340
+ self.head_dim,
341
+ max_position_embeddings=self.max_position_embeddings,
342
+ base=self.rope_theta,
343
+ )
344
  else:
345
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
346