Update modeling_neollm.py
Browse files- modeling_neollm.py +45 -3
modeling_neollm.py
CHANGED
|
@@ -4727,6 +4727,22 @@ class SpellingBeeEmbedding(nn.Module):
|
|
| 4727 |
missing_keys, unexpected_keys, error_msgs,
|
| 4728 |
)
|
| 4729 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4730 |
def set_byte_table(self, tokenizer) -> None:
|
| 4731 |
"""
|
| 4732 |
Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
|
|
@@ -5436,7 +5452,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 5436 |
self.vocab_size = config.vocab_size
|
| 5437 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 5438 |
|
| 5439 |
-
if config.use_token_generator:
|
| 5440 |
self._tied_weights_keys = {}
|
| 5441 |
|
| 5442 |
# ── Analysis infrastructure ───────────────────────────────────────
|
|
@@ -5503,6 +5519,30 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 5503 |
|
| 5504 |
# ── Standard model API ────────────────────────────────────────────────
|
| 5505 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5506 |
def get_input_embeddings(self):
|
| 5507 |
return self.model.get_input_embeddings()
|
| 5508 |
|
|
@@ -5582,7 +5622,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 5582 |
loss = None
|
| 5583 |
if labels is not None:
|
| 5584 |
loss = compute_cce_loss(
|
| 5585 |
-
hidden_states, labels, self.
|
| 5586 |
getattr(self.lm_head, "bias", None), self.config.pad_token_id,
|
| 5587 |
)
|
| 5588 |
# Add JTok-M load-balancing auxiliary loss
|
|
@@ -5617,7 +5657,9 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 5617 |
slice(-logits_to_keep, None)
|
| 5618 |
if isinstance(logits_to_keep, int) else logits_to_keep
|
| 5619 |
)
|
| 5620 |
-
logits =
|
|
|
|
|
|
|
| 5621 |
|
| 5622 |
# ── Finalise and store analysis state ─────────────────────────────
|
| 5623 |
if analysis_state is not None:
|
|
|
|
| 4727 |
missing_keys, unexpected_keys, error_msgs,
|
| 4728 |
)
|
| 4729 |
|
| 4730 |
+
def get_char_embeddings_vocab(self) -> torch.Tensor:
|
| 4731 |
+
"""
|
| 4732 |
+
Devuelve e_chars para todo el vocabulario [V, d], sin mezclar con
|
| 4733 |
+
token embeddings. Usado por el output path cuando tie_word_embeddings=False
|
| 4734 |
+
para aplicar spelling bee en la proyección de salida, siguiendo la
|
| 4735 |
+
implementación de referencia (littletrainingloop, spelling_bee_out).
|
| 4736 |
+
"""
|
| 4737 |
+
rope_bytes = self._build_rope_bytes() # [256, MAX_BYTES, d]
|
| 4738 |
+
e_chars_vocab = rope_bytes[
|
| 4739 |
+
self.token_bytes,
|
| 4740 |
+
torch.arange(self.MAX_BYTES, device=rope_bytes.device),
|
| 4741 |
+
].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
|
| 4742 |
+
return self.char_norm(
|
| 4743 |
+
e_chars_vocab.to(self.char_norm.weight.dtype)
|
| 4744 |
+
).to(self.byte_emb.weight.dtype) # [V, d]
|
| 4745 |
+
|
| 4746 |
def set_byte_table(self, tokenizer) -> None:
|
| 4747 |
"""
|
| 4748 |
Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
|
|
|
|
| 5452 |
self.vocab_size = config.vocab_size
|
| 5453 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 5454 |
|
| 5455 |
+
if config.use_token_generator or not config.tie_word_embeddings:
|
| 5456 |
self._tied_weights_keys = {}
|
| 5457 |
|
| 5458 |
# ── Analysis infrastructure ───────────────────────────────────────
|
|
|
|
| 5519 |
|
| 5520 |
# ── Standard model API ────────────────────────────────────────────────
|
| 5521 |
|
| 5522 |
+
def _get_lm_head_weight(self) -> torch.Tensor:
|
| 5523 |
+
"""
|
| 5524 |
+
Devuelve el weight efectivo del lm_head.
|
| 5525 |
+
|
| 5526 |
+
- tie_word_embeddings=True: lm_head.weight == embed_tokens.weight
|
| 5527 |
+
(tying estándar, spelling bee solo en entrada).
|
| 5528 |
+
|
| 5529 |
+
- tie_word_embeddings=False + use_spelling_bee_embeddings=True:
|
| 5530 |
+
blendea lm_head.weight con e_chars del vocabulario completo,
|
| 5531 |
+
siguiendo la implementación de referencia del paper
|
| 5532 |
+
(littletrainingloop, spelling_bee_out, separate_token_embedding=False).
|
| 5533 |
+
El byte_emb compartido recibe gradiente desde ambos paths.
|
| 5534 |
+
|
| 5535 |
+
- tie_word_embeddings=False sin spelling bee:
|
| 5536 |
+
devuelve lm_head.weight directamente.
|
| 5537 |
+
"""
|
| 5538 |
+
weights = self.lm_head.weight # [V, d]
|
| 5539 |
+
if (not self.config.tie_word_embeddings
|
| 5540 |
+
and getattr(self.config, "use_spelling_bee_embeddings", False)
|
| 5541 |
+
and self.model.spelling_bee is not None):
|
| 5542 |
+
e_chars = self.model.spelling_bee.get_char_embeddings_vocab() # [V, d]
|
| 5543 |
+
weights = (weights + e_chars.to(weights.dtype)) * 0.5
|
| 5544 |
+
return weights
|
| 5545 |
+
|
| 5546 |
def get_input_embeddings(self):
|
| 5547 |
return self.model.get_input_embeddings()
|
| 5548 |
|
|
|
|
| 5622 |
loss = None
|
| 5623 |
if labels is not None:
|
| 5624 |
loss = compute_cce_loss(
|
| 5625 |
+
hidden_states, labels, self._get_lm_head_weight(),
|
| 5626 |
getattr(self.lm_head, "bias", None), self.config.pad_token_id,
|
| 5627 |
)
|
| 5628 |
# Add JTok-M load-balancing auxiliary loss
|
|
|
|
| 5657 |
slice(-logits_to_keep, None)
|
| 5658 |
if isinstance(logits_to_keep, int) else logits_to_keep
|
| 5659 |
)
|
| 5660 |
+
logits = torch.nn.functional.linear(
|
| 5661 |
+
hidden_states[:, slice_indices, :], self._get_lm_head_weight()
|
| 5662 |
+
)
|
| 5663 |
|
| 5664 |
# ── Finalise and store analysis state ─────────────────────────────
|
| 5665 |
if analysis_state is not None:
|