Update modeling_neollm.py
Browse files- modeling_neollm.py +25 -0
modeling_neollm.py
CHANGED
|
@@ -4679,6 +4679,9 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4679 |
d = config.hidden_size
|
| 4680 |
base = getattr(config, "rope_theta", 10000.0)
|
| 4681 |
|
|
|
|
|
|
|
|
|
|
| 4682 |
# 256 × d byte embedding lookup (one per UTF-8 byte value 0..255).
|
| 4683 |
self.byte_emb = nn.Embedding(256, d)
|
| 4684 |
|
|
@@ -4723,6 +4726,28 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4723 |
# float32 for stability, applied at vocab level before the batch gather.
|
| 4724 |
self.char_norm = nn.LayerNorm(d)
|
| 4725 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4726 |
def set_byte_table(self, tokenizer) -> None:
|
| 4727 |
"""
|
| 4728 |
Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
|
|
|
|
| 4679 |
d = config.hidden_size
|
| 4680 |
base = getattr(config, "rope_theta", 10000.0)
|
| 4681 |
|
| 4682 |
+
# Guardado para poder recomputar los buffers RoPE en _reset_rope_buffers.
|
| 4683 |
+
self._rope_base = base
|
| 4684 |
+
|
| 4685 |
# 256 × d byte embedding lookup (one per UTF-8 byte value 0..255).
|
| 4686 |
self.byte_emb = nn.Embedding(256, d)
|
| 4687 |
|
|
|
|
| 4726 |
# float32 for stability, applied at vocab level before the batch gather.
|
| 4727 |
self.char_norm = nn.LayerNorm(d)
|
| 4728 |
|
| 4729 |
+
# ── Hook post-carga ────────────────────────────────────────────────
|
| 4730 |
+
# Los buffers non-persistent (intra_cos, intra_sin, pos_idx) se
|
| 4731 |
+
# calculan desde una fórmula fija y NUNCA deben venir del safetensors.
|
| 4732 |
+
# Si el checkpoint fue guardado con persistent=True (versión anterior),
|
| 4733 |
+
# from_pretrained los sobreescribiría con valores corruptos.
|
| 4734 |
+
# Este hook los elimina del state_dict entrante antes de la carga.
|
| 4735 |
+
self._register_load_state_dict_pre_hook(self._reset_rope_buffers_hook)
|
| 4736 |
+
|
| 4737 |
+
def _reset_rope_buffers_hook(self, state_dict, prefix, *args, **kwargs):
    """
    Load pre-hook that strips the derived RoPE buffers from a state_dict.

    Drops the entries ``{prefix}intra_cos``, ``{prefix}intra_sin`` and
    ``{prefix}pos_idx`` from ``state_dict`` in place, when present, so
    they cannot overwrite the values computed in ``__init__``. All other
    entries are left untouched.

    Rationale (per the original author): earlier versions registered
    these buffers with ``persistent=True``, so older checkpoints carry
    stale copies that would otherwise be loaded silently and corrupt
    the buffers after every load.

    Args:
        state_dict: Incoming mapping of parameter/buffer names to
            values; mutated in place.
        prefix: Key prefix of this module inside ``state_dict``.
        *args, **kwargs: Remaining hook arguments; accepted and ignored.

    Returns:
        None.
    """
    for buffer_name in ("intra_cos", "intra_sin", "pos_idx"):
        # pop() with a default is a no-op when the key is absent.
        state_dict.pop(f"{prefix}{buffer_name}", None)
|
| 4750 |
+
|
| 4751 |
def set_byte_table(self, tokenizer) -> None:
|
| 4752 |
"""
|
| 4753 |
Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
|