KitsuVp committed on
Commit
e8d234f
·
verified ·
1 Parent(s): e4a1d99

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +17 -34
modeling_neollm.py CHANGED
@@ -4701,38 +4701,11 @@ class SpellingBeeEmbedding(nn.Module):
4701
  persistent=True,
4702
  )
4703
 
4704
- # ── Non-persistent buffers (recomputed from fixed formula on load) ─
4705
- # RoPE cos/sin for intra-token positions 0..MAX_BYTES-1.
4706
- # Shape [MAX_BYTES, d//2] — applied over the 256-type axis in
4707
- # _build_rope_bytes, not over the batch/sequence axis.
4708
- half = d // 2
4709
- theta = 1.0 / (base ** (torch.arange(0, half, dtype=torch.float) * 2.0 / d))
4710
- pos = torch.arange(self.MAX_BYTES, dtype=torch.float)
4711
- freqs = torch.outer(pos, theta) # [MAX_BYTES, half]
4712
- self.register_buffer("intra_cos", freqs.cos(), persistent=False)
4713
- self.register_buffer("intra_sin", freqs.sin(), persistent=False)
4714
-
4715
- # Static position index [MAX_BYTES] used as the column index in the
4716
- # vocab-level gather. Registered as buffer to avoid dynamic tensor
4717
- # creation inside forward (which would trigger torch.compile retracing).
4718
- self.register_buffer(
4719
- "pos_idx",
4720
- torch.arange(self.MAX_BYTES, dtype=torch.long),
4721
- persistent=False,
4722
- )
4723
-
4724
  # LayerNorm over character embeddings — mirrors the reference impl
4725
  # (character_norm=True by default in littletrainingloop). Runs in
4726
  # float32 for stability, applied at vocab level before the batch gather.
4727
  self.char_norm = nn.LayerNorm(d)
4728
 
4729
- # ── Hook post-carga ────────────────────────────────────────────────
4730
- # Los buffers non-persistent (intra_cos, intra_sin, pos_idx) se
4731
- # calculan desde una fórmula fija y NUNCA deben venir del safetensors.
4732
- # Si el checkpoint fue guardado con persistent=True (versión anterior),
4733
- # from_pretrained los sobreescribiría con valores corruptos.
4734
- # _load_from_state_dict los elimina antes de aplicar el state_dict.
4735
-
4736
  def _load_from_state_dict(
4737
  self, state_dict, prefix, local_metadata, strict, missing_keys,
4738
  unexpected_keys, error_msgs,
@@ -4800,8 +4773,11 @@ class SpellingBeeEmbedding(nn.Module):
4800
  byte_emb.weight. All shapes are fully static, so torch.compile can
4801
  fuse this into a single kernel.
4802
 
4803
- Called once per forward pass; the result is discarded afterward.
4804
- The cost is two broadcast elementwise ops + one cat over fixed dims.
 
 
 
4805
 
4806
  Returns:
4807
  rope_bytes [256, MAX_BYTES, d]
@@ -4810,8 +4786,14 @@ class SpellingBeeEmbedding(nn.Module):
4810
  half = w.shape[-1] // 2
4811
  w1 = w[:, :half].unsqueeze(1) # [256, 1, half]
4812
  w2 = w[:, half:].unsqueeze(1) # [256, 1, half]
4813
- cos = self.intra_cos.unsqueeze(0) # [1, MAX_BYTES, half]
4814
- sin = self.intra_sin.unsqueeze(0) # [1, MAX_BYTES, half]
 
 
 
 
 
 
4815
  return torch.cat(
4816
  [w1 * cos - w2 * sin,
4817
  w1 * sin + w2 * cos],
@@ -4843,8 +4825,8 @@ class SpellingBeeEmbedding(nn.Module):
4843
  # vocab token and each position, the RoPE-rotated embedding of that
4844
  # byte at that position. Result [V, MAX_BYTES, d], then sum → [V, d].
4845
  e_chars_vocab = rope_bytes[
4846
- self.token_bytes, # [V, MAX_BYTES] — row index
4847
- self.pos_idx.unsqueeze(0), # [1, MAX_BYTES] → broadcast [V, MAX_BYTES]
4848
  ].sum(1) # [V, d]
4849
 
4850
  # ── Step 3: apply precomputed 1/√byte_len per vocab type ────────────
@@ -4894,8 +4876,9 @@ class SpellingBeeEmbedding(nn.Module):
4894
  rope_bytes = self._build_rope_bytes() # [256, MAX_BYTES, d]
4895
  e_chars_vocab = rope_bytes[
4896
  self.token_bytes,
4897
- self.pos_idx.unsqueeze(0),
4898
  ].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
 
4899
  return (token_emb_weight + e_chars_vocab) * 0.5
4900
 
4901
 
 
4701
  persistent=True,
4702
  )
4703
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4704
  # LayerNorm over character embeddings — mirrors the reference impl
4705
  # (character_norm=True by default in littletrainingloop). Runs in
4706
  # float32 for stability, applied at vocab level before the batch gather.
4707
  self.char_norm = nn.LayerNorm(d)
4708
 
 
 
 
 
 
 
 
4709
  def _load_from_state_dict(
4710
  self, state_dict, prefix, local_metadata, strict, missing_keys,
4711
  unexpected_keys, error_msgs,
 
4773
  byte_emb.weight. All shapes are fully static, so torch.compile can
4774
  fuse this into a single kernel.
4775
 
4776
+ cos/sin se computan inline aquΓ­ en vez de usar buffers registrados.
4777
+ Con device_map + accelerate, from_pretrained materializa tensores
4778
+ del safetensors directamente — non-persistent buffers que no están
4779
+ en el checkpoint quedan como memoria sin inicializar. Computar inline
4780
+ elimina esa dependencia sin overhead apreciable ([16, d//2] es ínfimo).
4781
 
4782
  Returns:
4783
  rope_bytes [256, MAX_BYTES, d]
 
4786
  half = w.shape[-1] // 2
4787
  w1 = w[:, :half].unsqueeze(1) # [256, 1, half]
4788
  w2 = w[:, half:].unsqueeze(1) # [256, 1, half]
4789
+ # Computar RoPE inline — formas estáticas, torch.compile lo fusiona.
4790
+ theta = 1.0 / (self._rope_base ** (
4791
+ torch.arange(0, half, dtype=torch.float32, device=w.device) * 2.0 / (half * 2)
4792
+ ))
4793
+ pos = torch.arange(self.MAX_BYTES, dtype=torch.float32, device=w.device)
4794
+ freqs = torch.outer(pos, theta) # [MAX_BYTES, half]
4795
+ cos = freqs.cos().to(w.dtype).unsqueeze(0) # [1, MAX_BYTES, half]
4796
+ sin = freqs.sin().to(w.dtype).unsqueeze(0) # [1, MAX_BYTES, half]
4797
  return torch.cat(
4798
  [w1 * cos - w2 * sin,
4799
  w1 * sin + w2 * cos],
 
4825
  # vocab token and each position, the RoPE-rotated embedding of that
4826
  # byte at that position. Result [V, MAX_BYTES, d], then sum → [V, d].
4827
  e_chars_vocab = rope_bytes[
4828
+ self.token_bytes, # [V, MAX_BYTES] — row index
4829
+ torch.arange(self.MAX_BYTES, device=rope_bytes.device), # [MAX_BYTES] — col index
4830
  ].sum(1) # [V, d]
4831
 
4832
  # ── Step 3: apply precomputed 1/√byte_len per vocab type ────────────
 
4876
  rope_bytes = self._build_rope_bytes() # [256, MAX_BYTES, d]
4877
  e_chars_vocab = rope_bytes[
4878
  self.token_bytes,
4879
+ torch.arange(self.MAX_BYTES, device=rope_bytes.device),
4880
  ].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
4881
+ e_chars_vocab = self.char_norm(e_chars_vocab.float()).to(token_emb_weight.dtype)
4882
  return (token_emb_weight + e_chars_vocab) * 0.5
4883
 
4884