KitsuVp commited on
Commit
b299097
·
verified ·
1 Parent(s): 61f6640

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +25 -0
modeling_neollm.py CHANGED
@@ -4679,6 +4679,9 @@ class SpellingBeeEmbedding(nn.Module):
4679
  d = config.hidden_size
4680
  base = getattr(config, "rope_theta", 10000.0)
4681
 
 
 
 
4682
  # 256 × d byte embedding lookup (one per UTF-8 byte value 0..255).
4683
  self.byte_emb = nn.Embedding(256, d)
4684
 
@@ -4723,6 +4726,28 @@ class SpellingBeeEmbedding(nn.Module):
4723
  # float32 for stability, applied at vocab level before the batch gather.
4724
  self.char_norm = nn.LayerNorm(d)
4725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4726
  def set_byte_table(self, tokenizer) -> None:
4727
  """
4728
  Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
 
4679
  d = config.hidden_size
4680
  base = getattr(config, "rope_theta", 10000.0)
4681
 
4682
+ # Guardado para poder recomputar los buffers RoPE en _reset_rope_buffers.
4683
+ self._rope_base = base
4684
+
4685
  # 256 × d byte embedding lookup (one per UTF-8 byte value 0..255).
4686
  self.byte_emb = nn.Embedding(256, d)
4687
 
 
4726
  # float32 for stability, applied at vocab level before the batch gather.
4727
  self.char_norm = nn.LayerNorm(d)
4728
 
4729
+ # ── Hook post-carga ────────────────────────────────────────────────
4730
+ # Los buffers non-persistent (intra_cos, intra_sin, pos_idx) se
4731
+ # calculan desde una fórmula fija y NUNCA deben venir del safetensors.
4732
+ # Si el checkpoint fue guardado con persistent=True (versión anterior),
4733
+ # from_pretrained los sobreescribiría con valores corruptos.
4734
+ # Este hook los elimina del state_dict entrante antes de la carga.
4735
+ self._register_load_state_dict_pre_hook(self._reset_rope_buffers_hook)
4736
+
4737
+ def _reset_rope_buffers_hook(self, state_dict, prefix, *args, **kwargs):
4738
+ """
4739
+ Pre-hook de carga: elimina intra_cos, intra_sin y pos_idx del
4740
+ state_dict entrante para que no sobreescriban los valores correctos
4741
+ calculados en __init__.
4742
+
4743
+ Necesario porque versiones anteriores los guardaban como persistent=True,
4744
+ dejando valores corruptos en el safetensors que from_pretrained cargaba
4745
+ silenciosamente, corrompiendo los buffers tras cada load.
4746
+ """
4747
+ for key in [f"{prefix}intra_cos", f"{prefix}intra_sin", f"{prefix}pos_idx"]:
4748
+ if key in state_dict:
4749
+ del state_dict[key]
4750
+
4751
  def set_byte_table(self, tokenizer) -> None:
4752
  """
4753
  Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.