Update modeling_neollm.py
Browse files- modeling_neollm.py +19 -13
modeling_neollm.py
CHANGED
|
@@ -4731,22 +4731,28 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4731 |
# calculan desde una f贸rmula fija y NUNCA deben venir del safetensors.
|
| 4732 |
# Si el checkpoint fue guardado con persistent=True (versi贸n anterior),
|
| 4733 |
# from_pretrained los sobreescribir铆a con valores corruptos.
|
| 4734 |
-
#
|
| 4735 |
-
self._register_load_state_dict_pre_hook(self._reset_rope_buffers_hook)
|
| 4736 |
|
| 4737 |
-
def
|
|
|
|
|
|
|
|
|
|
| 4738 |
"""
|
| 4739 |
-
|
| 4740 |
-
|
| 4741 |
-
|
| 4742 |
-
|
| 4743 |
-
|
| 4744 |
-
|
| 4745 |
-
|
|
|
|
| 4746 |
"""
|
| 4747 |
-
for key in
|
| 4748 |
-
|
| 4749 |
-
|
|
|
|
|
|
|
|
|
|
| 4750 |
|
| 4751 |
def set_byte_table(self, tokenizer) -> None:
|
| 4752 |
"""
|
|
|
|
| 4731 |
# calculan desde una f贸rmula fija y NUNCA deben venir del safetensors.
|
| 4732 |
# Si el checkpoint fue guardado con persistent=True (versi贸n anterior),
|
| 4733 |
# from_pretrained los sobreescribir铆a con valores corruptos.
|
| 4734 |
+
# _load_from_state_dict los elimina antes de aplicar el state_dict.
|
|
|
|
| 4735 |
|
| 4736 |
+
def _load_from_state_dict(
|
| 4737 |
+
self, state_dict, prefix, local_metadata, strict, missing_keys,
|
| 4738 |
+
unexpected_keys, error_msgs,
|
| 4739 |
+
):
|
| 4740 |
"""
|
| 4741 |
+
Sobreescribe la carga de state_dict para eliminar los tres buffers
|
| 4742 |
+
non-persistent (intra_cos, intra_sin, pos_idx) antes de aplicar el
|
| 4743 |
+
state_dict, evitando que versiones anteriores del checkpoint
|
| 4744 |
+
(donde eran persistent=True) sobreescriban los valores correctos
|
| 4745 |
+
calculados en __init__ con valores corruptos del safetensors.
|
| 4746 |
+
|
| 4747 |
+
from_pretrained de HuggingFace bypasea _register_load_state_dict_pre_hook
|
| 4748 |
+
y carga directamente por nombre, por lo que este override es necesario.
|
| 4749 |
"""
|
| 4750 |
+
for key in ("intra_cos", "intra_sin", "pos_idx"):
|
| 4751 |
+
state_dict.pop(prefix + key, None)
|
| 4752 |
+
super()._load_from_state_dict(
|
| 4753 |
+
state_dict, prefix, local_metadata, strict,
|
| 4754 |
+
missing_keys, unexpected_keys, error_msgs,
|
| 4755 |
+
)
|
| 4756 |
|
| 4757 |
def set_byte_table(self, tokenizer) -> None:
|
| 4758 |
"""
|