Update modeling_neollm.py
Browse files- modeling_neollm.py +17 -34
modeling_neollm.py
CHANGED
|
@@ -4701,38 +4701,11 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4701 |
persistent=True,
|
| 4702 |
)
|
| 4703 |
|
| 4704 |
-
# ββ Non-persistent buffers (recomputed from fixed formula on load) β
|
| 4705 |
-
# RoPE cos/sin for intra-token positions 0..MAX_BYTES-1.
|
| 4706 |
-
# Shape [MAX_BYTES, d//2] β applied over the 256-type axis in
|
| 4707 |
-
# _build_rope_bytes, not over the batch/sequence axis.
|
| 4708 |
-
half = d // 2
|
| 4709 |
-
theta = 1.0 / (base ** (torch.arange(0, half, dtype=torch.float) * 2.0 / d))
|
| 4710 |
-
pos = torch.arange(self.MAX_BYTES, dtype=torch.float)
|
| 4711 |
-
freqs = torch.outer(pos, theta) # [MAX_BYTES, half]
|
| 4712 |
-
self.register_buffer("intra_cos", freqs.cos(), persistent=False)
|
| 4713 |
-
self.register_buffer("intra_sin", freqs.sin(), persistent=False)
|
| 4714 |
-
|
| 4715 |
-
# Static position index [MAX_BYTES] used as the column index in the
|
| 4716 |
-
# vocab-level gather. Registered as buffer to avoid dynamic tensor
|
| 4717 |
-
# creation inside forward (which would trigger torch.compile retracing).
|
| 4718 |
-
self.register_buffer(
|
| 4719 |
-
"pos_idx",
|
| 4720 |
-
torch.arange(self.MAX_BYTES, dtype=torch.long),
|
| 4721 |
-
persistent=False,
|
| 4722 |
-
)
|
| 4723 |
-
|
| 4724 |
# LayerNorm over character embeddings β mirrors the reference impl
|
| 4725 |
# (character_norm=True by default in littletrainingloop). Runs in
|
| 4726 |
# float32 for stability, applied at vocab level before the batch gather.
|
| 4727 |
self.char_norm = nn.LayerNorm(d)
|
| 4728 |
|
| 4729 |
-
# ββ Hook post-carga ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 4730 |
-
# Los buffers non-persistent (intra_cos, intra_sin, pos_idx) se
|
| 4731 |
-
# calculan desde una fΓ³rmula fija y NUNCA deben venir del safetensors.
|
| 4732 |
-
# Si el checkpoint fue guardado con persistent=True (versiΓ³n anterior),
|
| 4733 |
-
# from_pretrained los sobreescribirΓa con valores corruptos.
|
| 4734 |
-
# _load_from_state_dict los elimina antes de aplicar el state_dict.
|
| 4735 |
-
|
| 4736 |
def _load_from_state_dict(
|
| 4737 |
self, state_dict, prefix, local_metadata, strict, missing_keys,
|
| 4738 |
unexpected_keys, error_msgs,
|
|
@@ -4800,8 +4773,11 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4800 |
byte_emb.weight. All shapes are fully static, so torch.compile can
|
| 4801 |
fuse this into a single kernel.
|
| 4802 |
|
| 4803 |
-
|
| 4804 |
-
|
|
|
|
|
|
|
|
|
|
| 4805 |
|
| 4806 |
Returns:
|
| 4807 |
rope_bytes [256, MAX_BYTES, d]
|
|
@@ -4810,8 +4786,14 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4810 |
half = w.shape[-1] // 2
|
| 4811 |
w1 = w[:, :half].unsqueeze(1) # [256, 1, half]
|
| 4812 |
w2 = w[:, half:].unsqueeze(1) # [256, 1, half]
|
| 4813 |
-
|
| 4814 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4815 |
return torch.cat(
|
| 4816 |
[w1 * cos - w2 * sin,
|
| 4817 |
w1 * sin + w2 * cos],
|
|
@@ -4843,8 +4825,8 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4843 |
# vocab token and each position, the RoPE-rotated embedding of that
|
| 4844 |
# byte at that position. Result [V, MAX_BYTES, d], then sum β [V, d].
|
| 4845 |
e_chars_vocab = rope_bytes[
|
| 4846 |
-
self.token_bytes,
|
| 4847 |
-
self.
|
| 4848 |
].sum(1) # [V, d]
|
| 4849 |
|
| 4850 |
# ββ Step 3: apply precomputed 1/βbyte_len per vocab type ββββββββββββ
|
|
@@ -4894,8 +4876,9 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4894 |
rope_bytes = self._build_rope_bytes() # [256, MAX_BYTES, d]
|
| 4895 |
e_chars_vocab = rope_bytes[
|
| 4896 |
self.token_bytes,
|
| 4897 |
-
self.
|
| 4898 |
].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
|
|
|
|
| 4899 |
return (token_emb_weight + e_chars_vocab) * 0.5
|
| 4900 |
|
| 4901 |
|
|
|
|
| 4701 |
persistent=True,
|
| 4702 |
)
|
| 4703 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4704 |
# LayerNorm over character embeddings β mirrors the reference impl
|
| 4705 |
# (character_norm=True by default in littletrainingloop). Runs in
|
| 4706 |
# float32 for stability, applied at vocab level before the batch gather.
|
| 4707 |
self.char_norm = nn.LayerNorm(d)
|
| 4708 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4709 |
def _load_from_state_dict(
|
| 4710 |
self, state_dict, prefix, local_metadata, strict, missing_keys,
|
| 4711 |
unexpected_keys, error_msgs,
|
|
|
|
| 4773 |
byte_emb.weight. All shapes are fully static, so torch.compile can
|
| 4774 |
fuse this into a single kernel.
|
| 4775 |
|
| 4776 |
+
cos/sin se computan inline aquΓ en vez de usar buffers registrados.
|
| 4777 |
+
Con device_map + accelerate, from_pretrained materializa tensores
|
| 4778 |
+
del safetensors directamente β non-persistent buffers que no estΓ‘n
|
| 4779 |
+
en el checkpoint quedan como memoria sin inicializar. Computar inline
|
| 4780 |
+
elimina esa dependencia sin overhead apreciable ([16, d//2] es Γnfimo).
|
| 4781 |
|
| 4782 |
Returns:
|
| 4783 |
rope_bytes [256, MAX_BYTES, d]
|
|
|
|
| 4786 |
half = w.shape[-1] // 2
|
| 4787 |
w1 = w[:, :half].unsqueeze(1) # [256, 1, half]
|
| 4788 |
w2 = w[:, half:].unsqueeze(1) # [256, 1, half]
|
| 4789 |
+
# Computar RoPE inline β formas estΓ‘ticas, torch.compile lo fusiona.
|
| 4790 |
+
theta = 1.0 / (self._rope_base ** (
|
| 4791 |
+
torch.arange(0, half, dtype=torch.float32, device=w.device) * 2.0 / (half * 2)
|
| 4792 |
+
))
|
| 4793 |
+
pos = torch.arange(self.MAX_BYTES, dtype=torch.float32, device=w.device)
|
| 4794 |
+
freqs = torch.outer(pos, theta) # [MAX_BYTES, half]
|
| 4795 |
+
cos = freqs.cos().to(w.dtype).unsqueeze(0) # [1, MAX_BYTES, half]
|
| 4796 |
+
sin = freqs.sin().to(w.dtype).unsqueeze(0) # [1, MAX_BYTES, half]
|
| 4797 |
return torch.cat(
|
| 4798 |
[w1 * cos - w2 * sin,
|
| 4799 |
w1 * sin + w2 * cos],
|
|
|
|
| 4825 |
# vocab token and each position, the RoPE-rotated embedding of that
|
| 4826 |
# byte at that position. Result [V, MAX_BYTES, d], then sum β [V, d].
|
| 4827 |
e_chars_vocab = rope_bytes[
|
| 4828 |
+
self.token_bytes, # [V, MAX_BYTES] β row index
|
| 4829 |
+
torch.arange(self.MAX_BYTES, device=rope_bytes.device), # [MAX_BYTES] β col index
|
| 4830 |
].sum(1) # [V, d]
|
| 4831 |
|
| 4832 |
# ββ Step 3: apply precomputed 1/βbyte_len per vocab type ββββββββββββ
|
|
|
|
| 4876 |
rope_bytes = self._build_rope_bytes() # [256, MAX_BYTES, d]
|
| 4877 |
e_chars_vocab = rope_bytes[
|
| 4878 |
self.token_bytes,
|
| 4879 |
+
torch.arange(self.MAX_BYTES, device=rope_bytes.device),
|
| 4880 |
].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
|
| 4881 |
+
e_chars_vocab = self.char_norm(e_chars_vocab.float()).to(token_emb_weight.dtype)
|
| 4882 |
return (token_emb_weight + e_chars_vocab) * 0.5
|
| 4883 |
|
| 4884 |
|