| |
| |
| |
| |
| @@ -4336,26 +4336,31 @@ def from_pretrained( |
| return model |
| |
| @staticmethod |
| - def _fix_state_dict_key_on_load(key): |
| + def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]: |
| """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight.""" |
| |
| - if "beta" in key: |
| - return key.replace("beta", "bias") |
| - if "gamma" in key: |
| - return key.replace("gamma", "weight") |
| + # Rename LayerNorm beta & gamma params for some early models ported from TensorFlow (e.g. Bert) |
| + # This rename is logged. |
| + if key.endswith("LayerNorm.beta"): |
| + return key.replace("LayerNorm.beta", "LayerNorm.bias"), True |
| + if key.endswith("LayerNorm.gamma"): |
| + return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True |
| |
| - # to avoid logging parametrized weight norm renaming |
| + # Rename weight norm parametrizations to match changes across torch versions. |
| + # Impacts a number of speech/wav2vec models, e.g. Hubert, Wav2Vec2, and others. |
| + # This rename is not logged. |
| if hasattr(nn.utils.parametrizations, "weight_norm"): |
| - if "weight_g" in key: |
| - return key.replace("weight_g", "parametrizations.weight.original0") |
| - if "weight_v" in key: |
| - return key.replace("weight_v", "parametrizations.weight.original1") |
| + if key.endswith("weight_g"): |
| + return key.replace("weight_g", "parametrizations.weight.original0"), True |
| + if key.endswith("weight_v"): |
| + return key.replace("weight_v", "parametrizations.weight.original1"), True |
| else: |
| - if "parametrizations.weight.original0" in key: |
| - return key.replace("parametrizations.weight.original0", "weight_g") |
| - if "parametrizations.weight.original1" in key: |
| - return key.replace("parametrizations.weight.original1", "weight_v") |
| - return key |
| + if key.endswith("parametrizations.weight.original0"): |
| + return key.replace("parametrizations.weight.original0", "weight_g"), True |
| + if key.endswith("parametrizations.weight.original1"): |
| + return key.replace("parametrizations.weight.original1", "weight_v"), True |
| + |
| + return key, False |
| |
| @classmethod |
| def _fix_state_dict_keys_on_load(cls, state_dict): |
| @@ -4366,15 +4371,15 @@ def _fix_state_dict_keys_on_load(cls, state_dict): |
| renamed_keys = {} |
| state_dict_keys = list(state_dict.keys()) |
| for key in state_dict_keys: |
| - new_key = cls._fix_state_dict_key_on_load(key) |
| - if new_key != key: |
| + new_key, has_changed = cls._fix_state_dict_key_on_load(key) |
| + if has_changed: |
| state_dict[new_key] = state_dict.pop(key) |
| |
| - # add it once for logging |
| - if "gamma" in key and "gamma" not in renamed_keys: |
| - renamed_keys["gamma"] = (key, new_key) |
| - if "beta" in key and "beta" not in renamed_keys: |
| - renamed_keys["beta"] = (key, new_key) |
| + # track gamma/beta rename for logging |
| + if key.endswith("LayerNorm.gamma"): |
| + renamed_keys["LayerNorm.gamma"] = (key, new_key) |
| + elif key.endswith("LayerNorm.beta"): |
| + renamed_keys["LayerNorm.beta"] = (key, new_key) |
| |
| if renamed_keys: |
| warning_msg = f"A pretrained model of type `{cls.__name__}` " |
| @@ -4387,19 +4392,19 @@ def _fix_state_dict_keys_on_load(cls, state_dict): |
| return state_dict |
| |
| @staticmethod |
| - def _fix_state_dict_key_on_save(key): |
| + def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]: |
| """ |
| Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save. |
| - Do nothing by default, but can be overriden in particular models. |
| + Do nothing by default, but can be overridden in particular models. |
| """ |
| - return key |
| + return key, False |
| |
| def _fix_state_dict_keys_on_save(self, state_dict): |
| """ |
| Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save. |
| Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`. |
| """ |
| - return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()} |
| + return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()} |
| |
| @classmethod |
| def _load_pretrained_model( |
| @@ -4457,7 +4462,7 @@ def _load_pretrained_model( |
| expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys) |
| |
| original_loaded_keys = loaded_keys |
| - loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys] |
| + loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_keys] |
| |
| if len(prefix) > 0: |
| has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) |
| |
| |
| |
| |
| @@ -90,22 +90,22 @@ def __init__(self, *args, **kwargs): |
| super().__init__(*args, **kwargs) |
| |
| @staticmethod |
| - def _fix_state_dict_key_on_load(key): |
| + def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]: |
| """ |
| Overrides original method that renames `gamma` and `beta` to `weight` and `bias`. |
| We don't want this behavior for timm wrapped models. Instead, this method adds a |
| "timm_model." prefix to enable loading official timm Hub checkpoints. |
| """ |
| if "timm_model." not in key: |
| - return f"timm_model.{key}" |
| - return key |
| + return f"timm_model.{key}", True |
| + return key, False |
| |
| def _fix_state_dict_key_on_save(self, key): |
| """ |
| Overrides original method to remove "timm_model." prefix from state_dict keys. |
| Makes the saved checkpoint compatible with the `timm` library. |
| """ |
| - return key.replace("timm_model.", "") |
| + return key.replace("timm_model.", ""), True |
| |
| def load_state_dict(self, state_dict, *args, **kwargs): |
| """ |
| |
| |
| |
| |
| @@ -1562,57 +1562,47 @@ def test_model_from_pretrained_from_mlx(self): |
| self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"])) |
| |
| def test_warning_for_beta_gamma_parameters(self): |
| - class TestModelGamma(PreTrainedModel): |
| + class TestGammaBetaNorm(torch.nn.Module): |
| + def __init__(self): |
| + super().__init__() |
| + self.gamma = torch.nn.Parameter(torch.ones(1)) |
| + self.beta = torch.nn.Parameter(torch.zeros(1)) |
| + |
| + def forward(self): |
| + return self.gamma.sum() + self.beta.sum() |
| + |
| + class TestModelGammaBeta(PreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
| - self.gamma_param = nn.Parameter(torch.ones(10)) |
| + self.LayerNorm = TestGammaBetaNorm() |
| self.post_init() |
| |
| def forward(self): |
| - return self.gamma_param.sum() |
| + return self.LayerNorm() |
| |
| logger = logging.get_logger("transformers.modeling_utils") |
| config = PretrainedConfig() |
| - warning_msg_gamma = "`gamma_param` -> `weight_param`" |
| - model = TestModelGamma(config) |
| + warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`" |
| + warning_msg_beta = "`LayerNorm.beta` -> `LayerNorm.bias`" |
| + model = TestModelGammaBeta(config) |
| |
| with tempfile.TemporaryDirectory() as tmp_dir: |
| model.save_pretrained(tmp_dir) |
| with LoggingLevel(logging.INFO): |
| with CaptureLogger(logger) as cl1: |
| - _, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True) |
| + _, loading_info = TestModelGammaBeta.from_pretrained( |
| + tmp_dir, config=config, output_loading_info=True |
| + ) |
| |
| missing_keys = loading_info["missing_keys"] |
| unexpected_keys = loading_info["unexpected_keys"] |
| - self.assertIn("`TestModelGamma`", cl1.out) |
| + self.assertIn("`TestModelGammaBeta`", cl1.out) |
| self.assertIn(warning_msg_gamma, cl1.out) |
| - self.assertIn("gamma_param", missing_keys) |
| - self.assertIn("weight_param", unexpected_keys) |
| - |
| - class TestModelBeta(PreTrainedModel): |
| - def __init__(self, config): |
| - super().__init__(config) |
| - self.beta_param = nn.Parameter(torch.ones(10)) |
| - self.post_init() |
| - |
| - def forward(self): |
| - return self.beta_param.sum() |
| - |
| - warning_msg_beta = "`beta_param` -> `bias_param`" |
| - model = TestModelBeta(config) |
| - |
| - with tempfile.TemporaryDirectory() as tmp_dir: |
| - model.save_pretrained(tmp_dir) |
| - with LoggingLevel(logging.INFO): |
| - with CaptureLogger(logger) as cl2: |
| - _, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True) |
| - |
| - missing_keys = loading_info["missing_keys"] |
| - unexpected_keys = loading_info["unexpected_keys"] |
| - self.assertIn("`TestModelBeta`", cl2.out) |
| - self.assertIn(warning_msg_beta, cl2.out) |
| - self.assertIn("beta_param", missing_keys) |
| - self.assertIn("bias_param", unexpected_keys) |
| + self.assertIn(warning_msg_beta, cl1.out) |
| + self.assertIn("LayerNorm.gamma", missing_keys) |
| + self.assertIn("LayerNorm.weight", unexpected_keys) |
| + self.assertIn("LayerNorm.beta", missing_keys) |
| + self.assertIn("LayerNorm.bias", unexpected_keys) |
| |
| def test_isin_mps_friendly(self): |
| """tests that our custom `isin_mps_friendly` matches `torch.isin`""" |
|
|