Update modeling_neollm.py
Browse files- modeling_neollm.py +2 -2
modeling_neollm.py
CHANGED
|
@@ -4836,7 +4836,7 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4836 |
# ββ Step 3b: LayerNorm over character embeddings (float32 for stability)
|
| 4837 |
# Mirrors reference impl (character_norm=True). Ensures E[βe_charsβΒ²]
|
| 4838 |
# matches E[βe_tokβΒ²] regardless of token byte-length distribution.
|
| 4839 |
-
e_chars_vocab = self.char_norm(e_chars_vocab.
|
| 4840 |
|
| 4841 |
# ββ Step 4: gather only the tokens present in this batch ββββββββββββ
|
| 4842 |
# This is the only BΓS operation β a single embedding lookup.
|
|
@@ -4878,7 +4878,7 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4878 |
self.token_bytes,
|
| 4879 |
torch.arange(self.MAX_BYTES, device=rope_bytes.device),
|
| 4880 |
].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
|
| 4881 |
-
e_chars_vocab = self.char_norm(e_chars_vocab.
|
| 4882 |
return (token_emb_weight + e_chars_vocab) * 0.5
|
| 4883 |
|
| 4884 |
|
|
|
|
| 4836 |
# ββ Step 3b: LayerNorm over character embeddings (float32 for stability)
|
| 4837 |
# Mirrors reference impl (character_norm=True). Ensures E[βe_charsβΒ²]
|
| 4838 |
# matches E[βe_tokβΒ²] regardless of token byte-length distribution.
|
| 4839 |
+
e_chars_vocab = self.char_norm(e_chars_vocab.to(self.char_norm.weight.dtype)).to(token_embeds.dtype)
|
| 4840 |
|
| 4841 |
# ββ Step 4: gather only the tokens present in this batch ββββββββββββ
|
| 4842 |
# This is the only BΓS operation β a single embedding lookup.
|
|
|
|
| 4878 |
self.token_bytes,
|
| 4879 |
torch.arange(self.MAX_BYTES, device=rope_bytes.device),
|
| 4880 |
].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
|
| 4881 |
+
e_chars_vocab = self.char_norm(e_chars_vocab.to(self.char_norm.weight.dtype)).to(token_emb_weight.dtype)
|
| 4882 |
return (token_emb_weight + e_chars_vocab) * 0.5
|
| 4883 |
|
| 4884 |
|