KitsuVp committed on
Commit
78c5358
·
verified ·
1 Parent(s): 0d7f6d3

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +3 -3
modeling_neollm.py CHANGED
@@ -4836,8 +4836,8 @@ class SpellingBeeEmbedding(nn.Module):
4836
  # ── Step 3b: LayerNorm over character embeddings (float32 for stability)
4837
  # Mirrors reference impl (character_norm=True). Ensures E[‖e_chars‖²]
4838
  # matches E[‖e_tok‖²] regardless of token byte-length distribution.
4839
- e_chars_vocab = self.char_norm(e_chars_vocab.float()).to(token_embeds.dtype)
4840
-
4841
  # ── Step 4: gather only the tokens present in this batch ────────────
4842
  # This is the only B×S operation — a single embedding lookup.
4843
  e_chars = e_chars_vocab[token_ids] # [B, S, d] or [N, d]
@@ -4878,7 +4878,7 @@ class SpellingBeeEmbedding(nn.Module):
4878
  self.token_bytes,
4879
  torch.arange(self.MAX_BYTES, device=rope_bytes.device),
4880
  ].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
4881
- e_chars_vocab = self.char_norm(e_chars_vocab.float()).to(token_emb_weight.dtype)
4882
  return (token_emb_weight + e_chars_vocab) * 0.5
4883
 
4884
 
 
4836
  # ── Step 3b: LayerNorm over character embeddings (float32 for stability)
4837
  # Mirrors reference impl (character_norm=True). Ensures E[‖e_chars‖²]
4838
  # matches E[‖e_tok‖²] regardless of token byte-length distribution.
4839
+ e_chars_vocab = self.char_norm(e_chars_vocab.to(self.char_norm.weight.dtype)).to(token_embeds.dtype)
4840
+
4841
  # ── Step 4: gather only the tokens present in this batch ────────────
4842
  # This is the only B×S operation — a single embedding lookup.
4843
  e_chars = e_chars_vocab[token_ids] # [B, S, d] or [N, d]
 
4878
  self.token_bytes,
4879
  torch.arange(self.MAX_BYTES, device=rope_bytes.device),
4880
  ].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
4881
+ e_chars_vocab = self.char_norm(e_chars_vocab.to(self.char_norm.weight.dtype)).to(token_emb_weight.dtype)
4882
  return (token_emb_weight + e_chars_vocab) * 0.5
4883
 
4884