KitsuVp commited on
Commit
e4a1d99
verified
1 Parent(s): b299097

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +19 -13
modeling_neollm.py CHANGED
@@ -4731,22 +4731,28 @@ class SpellingBeeEmbedding(nn.Module):
4731
  # calculan desde una f贸rmula fija y NUNCA deben venir del safetensors.
4732
  # Si el checkpoint fue guardado con persistent=True (versi贸n anterior),
4733
  # from_pretrained los sobreescribir铆a con valores corruptos.
4734
- # Este hook los elimina del state_dict entrante antes de la carga.
4735
- self._register_load_state_dict_pre_hook(self._reset_rope_buffers_hook)
4736
 
4737
- def _reset_rope_buffers_hook(self, state_dict, prefix, *args, **kwargs):
 
 
 
4738
  """
4739
- Pre-hook de carga: elimina intra_cos, intra_sin y pos_idx del
4740
- state_dict entrante para que no sobreescriban los valores correctos
4741
- calculados en __init__.
4742
-
4743
- Necesario porque versiones anteriores los guardaban como persistent=True,
4744
- dejando valores corruptos en el safetensors que from_pretrained cargaba
4745
- silenciosamente, corrompiendo los buffers tras cada load.
 
4746
  """
4747
- for key in [f"{prefix}intra_cos", f"{prefix}intra_sin", f"{prefix}pos_idx"]:
4748
- if key in state_dict:
4749
- del state_dict[key]
 
 
 
4750
 
4751
  def set_byte_table(self, tokenizer) -> None:
4752
  """
 
4731
  # calculan desde una f贸rmula fija y NUNCA deben venir del safetensors.
4732
  # Si el checkpoint fue guardado con persistent=True (versi贸n anterior),
4733
  # from_pretrained los sobreescribir铆a con valores corruptos.
4734
+ # _load_from_state_dict los elimina antes de aplicar el state_dict.
 
4735
 
4736
+ def _load_from_state_dict(
4737
+ self, state_dict, prefix, local_metadata, strict, missing_keys,
4738
+ unexpected_keys, error_msgs,
4739
+ ):
4740
  """
4741
+ Sobreescribe la carga de state_dict para eliminar los tres buffers
4742
+ non-persistent (intra_cos, intra_sin, pos_idx) antes de aplicar el
4743
+ state_dict, evitando que versiones anteriores del checkpoint
4744
+ (donde eran persistent=True) sobreescriban los valores correctos
4745
+ calculados en __init__ con valores corruptos del safetensors.
4746
+
4747
+ from_pretrained de HuggingFace bypasea _register_load_state_dict_pre_hook
4748
+ y carga directamente por nombre, por lo que este override es necesario.
4749
  """
4750
+ for key in ("intra_cos", "intra_sin", "pos_idx"):
4751
+ state_dict.pop(prefix + key, None)
4752
+ super()._load_from_state_dict(
4753
+ state_dict, prefix, local_metadata, strict,
4754
+ missing_keys, unexpected_keys, error_msgs,
4755
+ )
4756
 
4757
  def set_byte_table(self, tokenizer) -> None:
4758
  """