Port to newer version of transformers package
#7
by
man2machine - opened
- configuration_slimmoe.py +1 -1
- modeling_slimmoe.py +12 -6
configuration_slimmoe.py
CHANGED
|
@@ -111,7 +111,7 @@ class PhiMoEConfig(PretrainedConfig):
|
|
| 111 |
>>> configuration = model.config
|
| 112 |
```"""
|
| 113 |
|
| 114 |
-
model_type = "phimoe"
|
| 115 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 116 |
|
| 117 |
def __init__(
|
|
|
|
| 111 |
>>> configuration = model.config
|
| 112 |
```"""
|
| 113 |
|
| 114 |
+
model_type = "phimoe_slim" # renamed from "phimoe" to bypass transformers >=4.46 conversion_mapping
|
| 115 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 116 |
|
| 117 |
def __init__(
|
modeling_slimmoe.py
CHANGED
|
@@ -46,7 +46,6 @@ from transformers.utils import (
|
|
| 46 |
logging,
|
| 47 |
replace_return_docstrings,
|
| 48 |
)
|
| 49 |
-
from transformers.utils.import_utils import is_torch_fx_available
|
| 50 |
from .configuration_slimmoe import PhiMoEConfig
|
| 51 |
|
| 52 |
from einops import rearrange
|
|
@@ -61,11 +60,10 @@ if is_flash_attn_2_available():
|
|
| 61 |
|
| 62 |
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
| 63 |
# It means that the function will not be traced through and simply appear as a node in the graph.
|
| 64 |
-
if is_torch_fx_available():
|
| 65 |
-
    if not is_torch_greater_or_equal_than_1_13:
|
| 66 |
-
import torch.fx
|
| 67 |
|
| 68 |
-
    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
| 69 |
|
| 70 |
|
| 71 |
logger = logging.get_logger(__name__)
|
|
@@ -332,9 +330,17 @@ class PhiMoEAttention(nn.Module):
|
|
| 332 |
base=self.rope_theta,
|
| 333 |
)
|
| 334 |
else:
|
| 335 |
-
            scaling_type = self.config.rope_scaling["type"]
|
| 336 |
if scaling_type == "longrope":
|
| 337 |
self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
else:
|
| 339 |
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
| 340 |
|
|
|
|
| 46 |
logging,
|
| 47 |
replace_return_docstrings,
|
| 48 |
)
|
|
|
|
| 49 |
from .configuration_slimmoe import PhiMoEConfig
|
| 50 |
|
| 51 |
from einops import rearrange
|
|
|
|
| 60 |
|
| 61 |
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
| 62 |
# It means that the function will not be traced through and simply appear as a node in the graph.
|
| 63 |
+
if not is_torch_greater_or_equal_than_1_13:
|
| 64 |
+
import torch.fx
|
|
|
|
| 65 |
|
| 66 |
+
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
| 67 |
|
| 68 |
|
| 69 |
logger = logging.get_logger(__name__)
|
|
|
|
| 330 |
base=self.rope_theta,
|
| 331 |
)
|
| 332 |
else:
|
| 333 |
+
# "type" key was renamed to "rope_type" in transformers >=4.46; handle both
|
| 334 |
+
scaling_type = self.config.rope_scaling.get("type") or self.config.rope_scaling.get("rope_type", "")
|
| 335 |
if scaling_type == "longrope":
|
| 336 |
self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
|
| 337 |
+
elif not scaling_type or scaling_type == "default":
|
| 338 |
+
# newer transformers injects {"rope_type": "default"} when rope_scaling is absent
|
| 339 |
+
self.rotary_emb = PhiMoERotaryEmbedding(
|
| 340 |
+
self.head_dim,
|
| 341 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 342 |
+
base=self.rope_theta,
|
| 343 |
+
)
|
| 344 |
else:
|
| 345 |
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
| 346 |
|