harness / diffs /35615.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 8eb2d7439ef3..573a9322efd1 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4336,26 +4336,31 @@ def from_pretrained(
return model
@staticmethod
- def _fix_state_dict_key_on_load(key):
+ def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
- if "beta" in key:
- return key.replace("beta", "bias")
- if "gamma" in key:
- return key.replace("gamma", "weight")
+ # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
+ # This rename is logged.
+ if key.endswith("LayerNorm.beta"):
+ return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
+ if key.endswith("LayerNorm.gamma"):
+ return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
- # to avoid logging parametrized weight norm renaming
+ # Rename weight norm parametrizations to match changes across torch versions.
+ # Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
+ # This rename is not logged.
if hasattr(nn.utils.parametrizations, "weight_norm"):
- if "weight_g" in key:
- return key.replace("weight_g", "parametrizations.weight.original0")
- if "weight_v" in key:
- return key.replace("weight_v", "parametrizations.weight.original1")
+ if key.endswith("weight_g"):
+ return key.replace("weight_g", "parametrizations.weight.original0"), True
+ if key.endswith("weight_v"):
+ return key.replace("weight_v", "parametrizations.weight.original1"), True
else:
- if "parametrizations.weight.original0" in key:
- return key.replace("parametrizations.weight.original0", "weight_g")
- if "parametrizations.weight.original1" in key:
- return key.replace("parametrizations.weight.original1", "weight_v")
- return key
+ if key.endswith("parametrizations.weight.original0"):
+ return key.replace("parametrizations.weight.original0", "weight_g"), True
+ if key.endswith("parametrizations.weight.original1"):
+ return key.replace("parametrizations.weight.original1", "weight_v"), True
+
+ return key, False
@classmethod
def _fix_state_dict_keys_on_load(cls, state_dict):
@@ -4366,15 +4371,15 @@ def _fix_state_dict_keys_on_load(cls, state_dict):
renamed_keys = {}
state_dict_keys = list(state_dict.keys())
for key in state_dict_keys:
- new_key = cls._fix_state_dict_key_on_load(key)
- if new_key != key:
+ new_key, has_changed = cls._fix_state_dict_key_on_load(key)
+ if has_changed:
state_dict[new_key] = state_dict.pop(key)
- # add it once for logging
- if "gamma" in key and "gamma" not in renamed_keys:
- renamed_keys["gamma"] = (key, new_key)
- if "beta" in key and "beta" not in renamed_keys:
- renamed_keys["beta"] = (key, new_key)
+ # track gamma/beta rename for logging
+ if key.endswith("LayerNorm.gamma"):
+ renamed_keys["LayerNorm.gamma"] = (key, new_key)
+ elif key.endswith("LayerNorm.beta"):
+ renamed_keys["LayerNorm.beta"] = (key, new_key)
if renamed_keys:
warning_msg = f"A pretrained model of type `{cls.__name__}` "
@@ -4387,19 +4392,19 @@ def _fix_state_dict_keys_on_load(cls, state_dict):
return state_dict
@staticmethod
- def _fix_state_dict_key_on_save(key):
+ def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]:
"""
Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
- Do nothing by default, but can be overriden in particular models.
+ Do nothing by default, but can be overridden in particular models.
"""
- return key
+ return key, False
def _fix_state_dict_keys_on_save(self, state_dict):
"""
Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
"""
- return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
+ return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}
@classmethod
def _load_pretrained_model(
@@ -4457,7 +4462,7 @@ def _load_pretrained_model(
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
original_loaded_keys = loaded_keys
- loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys]
+ loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_keys]
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
index 47e8944583b4..a74202ce5aa5 100644
--- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
+++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -90,22 +90,22 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@staticmethod
- def _fix_state_dict_key_on_load(key):
+ def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
"""
Overrides original method that renames `gamma` and `beta` to `weight` and `bias`.
We don't want this behavior for timm wrapped models. Instead, this method adds a
"timm_model." prefix to enable loading official timm Hub checkpoints.
"""
if "timm_model." not in key:
- return f"timm_model.{key}"
- return key
+ return f"timm_model.{key}", True
+ return key, False
def _fix_state_dict_key_on_save(self, key):
"""
Overrides original method to remove "timm_model." prefix from state_dict keys.
Makes the saved checkpoint compatible with the `timm` library.
"""
- return key.replace("timm_model.", "")
+ return key.replace("timm_model.", ""), True
def load_state_dict(self, state_dict, *args, **kwargs):
"""
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 383f0cbe60e1..e90f8aa7d039 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -1562,57 +1562,47 @@ def test_model_from_pretrained_from_mlx(self):
self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))
def test_warning_for_beta_gamma_parameters(self):
- class TestModelGamma(PreTrainedModel):
+ class TestGammaBetaNorm(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.gamma = torch.nn.Parameter(torch.ones(1))
+ self.beta = torch.nn.Parameter(torch.zeros(1))
+
+ def forward(self):
+ return self.gamma.sum() + self.beta.sum()
+
+ class TestModelGammaBeta(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
- self.gamma_param = nn.Parameter(torch.ones(10))
+ self.LayerNorm = TestGammaBetaNorm()
self.post_init()
def forward(self):
- return self.gamma_param.sum()
+ return self.LayerNorm()
logger = logging.get_logger("transformers.modeling_utils")
config = PretrainedConfig()
- warning_msg_gamma = "`gamma_param` -> `weight_param`"
- model = TestModelGamma(config)
+ warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`"
+ warning_msg_beta = "`LayerNorm.beta` -> `LayerNorm.bias`"
+ model = TestModelGammaBeta(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as cl1:
- _, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)
+ _, loading_info = TestModelGammaBeta.from_pretrained(
+ tmp_dir, config=config, output_loading_info=True
+ )
missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"]
- self.assertIn("`TestModelGamma`", cl1.out)
+ self.assertIn("`TestModelGammaBeta`", cl1.out)
self.assertIn(warning_msg_gamma, cl1.out)
- self.assertIn("gamma_param", missing_keys)
- self.assertIn("weight_param", unexpected_keys)
-
- class TestModelBeta(PreTrainedModel):
- def __init__(self, config):
- super().__init__(config)
- self.beta_param = nn.Parameter(torch.ones(10))
- self.post_init()
-
- def forward(self):
- return self.beta_param.sum()
-
- warning_msg_beta = "`beta_param` -> `bias_param`"
- model = TestModelBeta(config)
-
- with tempfile.TemporaryDirectory() as tmp_dir:
- model.save_pretrained(tmp_dir)
- with LoggingLevel(logging.INFO):
- with CaptureLogger(logger) as cl2:
- _, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)
-
- missing_keys = loading_info["missing_keys"]
- unexpected_keys = loading_info["unexpected_keys"]
- self.assertIn("`TestModelBeta`", cl2.out)
- self.assertIn(warning_msg_beta, cl2.out)
- self.assertIn("beta_param", missing_keys)
- self.assertIn("bias_param", unexpected_keys)
+ self.assertIn(warning_msg_beta, cl1.out)
+ self.assertIn("LayerNorm.gamma", missing_keys)
+ self.assertIn("LayerNorm.weight", unexpected_keys)
+ self.assertIn("LayerNorm.beta", missing_keys)
+ self.assertIn("LayerNorm.bias", unexpected_keys)
def test_isin_mps_friendly(self):
"""tests that our custom `isin_mps_friendly` matches `torch.isin`"""