Update modeling_neollm.py
Browse files- modeling_neollm.py +5 -47
modeling_neollm.py
CHANGED
|
@@ -4727,22 +4727,6 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4727 |
missing_keys, unexpected_keys, error_msgs,
|
| 4728 |
)
|
| 4729 |
|
| 4730 |
-
def get_char_embeddings_vocab(self) -> torch.Tensor:
|
| 4731 |
-
"""
|
| 4732 |
-
Devuelve e_chars para todo el vocabulario [V, d], sin mezclar con
|
| 4733 |
-
token embeddings. Usado por el output path cuando tie_word_embeddings=False
|
| 4734 |
-
para aplicar spelling bee en la proyección de salida, siguiendo la
|
| 4735 |
-
implementación de referencia (littletrainingloop, spelling_bee_out).
|
| 4736 |
-
"""
|
| 4737 |
-
rope_bytes = self._build_rope_bytes() # [256, MAX_BYTES, d]
|
| 4738 |
-
e_chars_vocab = rope_bytes[
|
| 4739 |
-
self.token_bytes,
|
| 4740 |
-
torch.arange(self.MAX_BYTES, device=rope_bytes.device),
|
| 4741 |
-
].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
|
| 4742 |
-
return self.char_norm(
|
| 4743 |
-
e_chars_vocab.to(self.char_norm.weight.dtype)
|
| 4744 |
-
).to(self.byte_emb.weight.dtype) # [V, d]
|
| 4745 |
-
|
| 4746 |
def set_byte_table(self, tokenizer) -> None:
|
| 4747 |
"""
|
| 4748 |
Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
|
|
@@ -4852,7 +4836,7 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4852 |
# ── Step 3b: LayerNorm over character embeddings (float32 for stability)
|
| 4853 |
# Mirrors reference impl (character_norm=True). Ensures E[‖e_chars‖²]
|
| 4854 |
# matches E[‖e_tok‖²] regardless of token byte-length distribution.
|
| 4855 |
-
e_chars_vocab = self.char_norm(e_chars_vocab.
|
| 4856 |
|
| 4857 |
# ── Step 4: gather only the tokens present in this batch ────────────
|
| 4858 |
# This is the only B×S operation — a single embedding lookup.
|
|
@@ -4894,7 +4878,7 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4894 |
self.token_bytes,
|
| 4895 |
torch.arange(self.MAX_BYTES, device=rope_bytes.device),
|
| 4896 |
].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
|
| 4897 |
-
e_chars_vocab = self.char_norm(e_chars_vocab.
|
| 4898 |
return (token_emb_weight + e_chars_vocab) * 0.5
|
| 4899 |
|
| 4900 |
|
|
@@ -5452,7 +5436,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 5452 |
self.vocab_size = config.vocab_size
|
| 5453 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 5454 |
|
| 5455 |
-
if config.use_token_generator
|
| 5456 |
self._tied_weights_keys = {}
|
| 5457 |
|
| 5458 |
# ── Analysis infrastructure ───────────────────────────────────────
|
|
@@ -5519,30 +5503,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 5519 |
|
| 5520 |
# ── Standard model API ────────────────────────────────────────────────
|
| 5521 |
|
| 5522 |
-
def _get_lm_head_weight(self) -> torch.Tensor:
|
| 5523 |
-
"""
|
| 5524 |
-
Devuelve el weight efectivo del lm_head.
|
| 5525 |
-
|
| 5526 |
-
- tie_word_embeddings=True: lm_head.weight == embed_tokens.weight
|
| 5527 |
-
(tying estándar, spelling bee solo en entrada).
|
| 5528 |
-
|
| 5529 |
-
- tie_word_embeddings=False + use_spelling_bee_embeddings=True:
|
| 5530 |
-
blendea lm_head.weight con e_chars del vocabulario completo,
|
| 5531 |
-
siguiendo la implementación de referencia del paper
|
| 5532 |
-
(littletrainingloop, spelling_bee_out, separate_token_embedding=False).
|
| 5533 |
-
El byte_emb compartido recibe gradiente desde ambos paths.
|
| 5534 |
-
|
| 5535 |
-
- tie_word_embeddings=False sin spelling bee:
|
| 5536 |
-
devuelve lm_head.weight directamente.
|
| 5537 |
-
"""
|
| 5538 |
-
weights = self.lm_head.weight # [V, d]
|
| 5539 |
-
if (not self.config.tie_word_embeddings
|
| 5540 |
-
and getattr(self.config, "use_spelling_bee_embeddings", False)
|
| 5541 |
-
and self.model.spelling_bee is not None):
|
| 5542 |
-
e_chars = self.model.spelling_bee.get_char_embeddings_vocab() # [V, d]
|
| 5543 |
-
weights = (weights + e_chars.to(weights.dtype)) * 0.5
|
| 5544 |
-
return weights
|
| 5545 |
-
|
| 5546 |
def get_input_embeddings(self):
|
| 5547 |
return self.model.get_input_embeddings()
|
| 5548 |
|
|
@@ -5622,7 +5582,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 5622 |
loss = None
|
| 5623 |
if labels is not None:
|
| 5624 |
loss = compute_cce_loss(
|
| 5625 |
-
hidden_states, labels, self.
|
| 5626 |
getattr(self.lm_head, "bias", None), self.config.pad_token_id,
|
| 5627 |
)
|
| 5628 |
# Add JTok-M load-balancing auxiliary loss
|
|
@@ -5657,9 +5617,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 5657 |
slice(-logits_to_keep, None)
|
| 5658 |
if isinstance(logits_to_keep, int) else logits_to_keep
|
| 5659 |
)
|
| 5660 |
-
logits =
|
| 5661 |
-
hidden_states[:, slice_indices, :], self._get_lm_head_weight()
|
| 5662 |
-
)
|
| 5663 |
|
| 5664 |
# ── Finalise and store analysis state ─────────────────────────────
|
| 5665 |
if analysis_state is not None:
|
|
|
|
| 4727 |
missing_keys, unexpected_keys, error_msgs,
|
| 4728 |
)
|
| 4729 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4730 |
def set_byte_table(self, tokenizer) -> None:
|
| 4731 |
"""
|
| 4732 |
Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
|
|
|
|
| 4836 |
# ── Step 3b: LayerNorm over character embeddings (float32 for stability)
|
| 4837 |
# Mirrors reference impl (character_norm=True). Ensures E[‖e_chars‖²]
|
| 4838 |
# matches E[‖e_tok‖²] regardless of token byte-length distribution.
|
| 4839 |
+
e_chars_vocab = self.char_norm(e_chars_vocab.float()).to(token_embeds.dtype)
|
| 4840 |
|
| 4841 |
# ── Step 4: gather only the tokens present in this batch ────────────
|
| 4842 |
# This is the only B×S operation — a single embedding lookup.
|
|
|
|
| 4878 |
self.token_bytes,
|
| 4879 |
torch.arange(self.MAX_BYTES, device=rope_bytes.device),
|
| 4880 |
].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
|
| 4881 |
+
e_chars_vocab = self.char_norm(e_chars_vocab.float()).to(token_emb_weight.dtype)
|
| 4882 |
return (token_emb_weight + e_chars_vocab) * 0.5
|
| 4883 |
|
| 4884 |
|
|
|
|
| 5436 |
self.vocab_size = config.vocab_size
|
| 5437 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 5438 |
|
| 5439 |
+
if config.use_token_generator:
|
| 5440 |
self._tied_weights_keys = {}
|
| 5441 |
|
| 5442 |
# ── Analysis infrastructure ───────────────────────────────────────
|
|
|
|
| 5503 |
|
| 5504 |
# ── Standard model API ────────────────────────────────────────────────
|
| 5505 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5506 |
def get_input_embeddings(self):
|
| 5507 |
return self.model.get_input_embeddings()
|
| 5508 |
|
|
|
|
| 5582 |
loss = None
|
| 5583 |
if labels is not None:
|
| 5584 |
loss = compute_cce_loss(
|
| 5585 |
+
hidden_states, labels, self.lm_head.weight,
|
| 5586 |
getattr(self.lm_head, "bias", None), self.config.pad_token_id,
|
| 5587 |
)
|
| 5588 |
# Add JTok-M load-balancing auxiliary loss
|
|
|
|
| 5617 |
slice(-logits_to_keep, None)
|
| 5618 |
if isinstance(logits_to_keep, int) else logits_to_keep
|
| 5619 |
)
|
| 5620 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
|
|
|
|
|
| 5621 |
|
| 5622 |
# ── Finalise and store analysis state ─────────────────────────────
|
| 5623 |
if analysis_state is not None:
|