KitsuVp commited on
Commit
031a455
·
verified ·
1 Parent(s): e7cfe46

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +45 -3
modeling_neollm.py CHANGED
@@ -4727,6 +4727,22 @@ class SpellingBeeEmbedding(nn.Module):
4727
  missing_keys, unexpected_keys, error_msgs,
4728
  )
4729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4730
  def set_byte_table(self, tokenizer) -> None:
4731
  """
4732
  Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
@@ -5436,7 +5452,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
5436
  self.vocab_size = config.vocab_size
5437
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
5438
 
5439
- if config.use_token_generator:
5440
  self._tied_weights_keys = {}
5441
 
5442
  # ── Analysis infrastructure ───────────────────────────────────────
@@ -5503,6 +5519,30 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
5503
 
5504
  # ── Standard model API ────────────────────────────────────────────────
5505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5506
  def get_input_embeddings(self):
5507
  return self.model.get_input_embeddings()
5508
 
@@ -5582,7 +5622,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
5582
  loss = None
5583
  if labels is not None:
5584
  loss = compute_cce_loss(
5585
- hidden_states, labels, self.lm_head.weight,
5586
  getattr(self.lm_head, "bias", None), self.config.pad_token_id,
5587
  )
5588
  # Add JTok-M load-balancing auxiliary loss
@@ -5617,7 +5657,9 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
5617
  slice(-logits_to_keep, None)
5618
  if isinstance(logits_to_keep, int) else logits_to_keep
5619
  )
5620
- logits = self.lm_head(hidden_states[:, slice_indices, :])
 
 
5621
 
5622
  # ── Finalise and store analysis state ─────────────────────────────
5623
  if analysis_state is not None:
 
4727
  missing_keys, unexpected_keys, error_msgs,
4728
  )
4729
 
4730
+ def get_char_embeddings_vocab(self) -> torch.Tensor:
4731
+ """
4732
+ Devuelve e_chars para todo el vocabulario [V, d], sin mezclar con
4733
+ token embeddings. Usado por el output path cuando tie_word_embeddings=False
4734
+ para aplicar spelling bee en la proyección de salida, siguiendo la
4735
+ implementación de referencia (littletrainingloop, spelling_bee_out).
4736
+ """
4737
+ rope_bytes = self._build_rope_bytes() # [256, MAX_BYTES, d]
4738
+ e_chars_vocab = rope_bytes[
4739
+ self.token_bytes,
4740
+ torch.arange(self.MAX_BYTES, device=rope_bytes.device),
4741
+ ].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
4742
+ return self.char_norm(
4743
+ e_chars_vocab.to(self.char_norm.weight.dtype)
4744
+ ).to(self.byte_emb.weight.dtype) # [V, d]
4745
+
4746
  def set_byte_table(self, tokenizer) -> None:
4747
  """
4748
  Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
 
5452
  self.vocab_size = config.vocab_size
5453
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
5454
 
5455
+ if config.use_token_generator or not config.tie_word_embeddings:
5456
  self._tied_weights_keys = {}
5457
 
5458
  # ── Analysis infrastructure ───────────────────────────────────────
 
5519
 
5520
  # ── Standard model API ────────────────────────────────────────────────
5521
 
5522
+ def _get_lm_head_weight(self) -> torch.Tensor:
5523
+ """
5524
+ Devuelve el weight efectivo del lm_head.
5525
+
5526
+ - tie_word_embeddings=True: lm_head.weight == embed_tokens.weight
5527
+ (tying estándar, spelling bee solo en entrada).
5528
+
5529
+ - tie_word_embeddings=False + use_spelling_bee_embeddings=True:
5530
+ blendea lm_head.weight con e_chars del vocabulario completo,
5531
+ siguiendo la implementación de referencia del paper
5532
+ (littletrainingloop, spelling_bee_out, separate_token_embedding=False).
5533
+ El byte_emb compartido recibe gradiente desde ambos paths.
5534
+
5535
+ - tie_word_embeddings=False sin spelling bee:
5536
+ devuelve lm_head.weight directamente.
5537
+ """
5538
+ weights = self.lm_head.weight # [V, d]
5539
+ if (not self.config.tie_word_embeddings
5540
+ and getattr(self.config, "use_spelling_bee_embeddings", False)
5541
+ and self.model.spelling_bee is not None):
5542
+ e_chars = self.model.spelling_bee.get_char_embeddings_vocab() # [V, d]
5543
+ weights = (weights + e_chars.to(weights.dtype)) * 0.5
5544
+ return weights
5545
+
5546
  def get_input_embeddings(self):
5547
  return self.model.get_input_embeddings()
5548
 
 
5622
  loss = None
5623
  if labels is not None:
5624
  loss = compute_cce_loss(
5625
+ hidden_states, labels, self._get_lm_head_weight(),
5626
  getattr(self.lm_head, "bias", None), self.config.pad_token_id,
5627
  )
5628
  # Add JTok-M load-balancing auxiliary loss
 
5657
  slice(-logits_to_keep, None)
5658
  if isinstance(logits_to_keep, int) else logits_to_keep
5659
  )
5660
+ logits = torch.nn.functional.linear(
5661
+ hidden_states[:, slice_indices, :], self._get_lm_head_weight()
5662
+ )
5663
 
5664
  # ── Finalise and store analysis state ─────────────────────────────
5665
  if analysis_state is not None: