KitsuVp commited on
Commit
ff39a46
Β·
verified Β·
1 Parent(s): 031a455

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +5 -47
modeling_neollm.py CHANGED
@@ -4727,22 +4727,6 @@ class SpellingBeeEmbedding(nn.Module):
4727
  missing_keys, unexpected_keys, error_msgs,
4728
  )
4729
 
4730
- def get_char_embeddings_vocab(self) -> torch.Tensor:
4731
- """
4732
- Devuelve e_chars para todo el vocabulario [V, d], sin mezclar con
4733
- token embeddings. Usado por el output path cuando tie_word_embeddings=False
4734
- para aplicar spelling bee en la proyecciΓ³n de salida, siguiendo la
4735
- implementaciΓ³n de referencia (littletrainingloop, spelling_bee_out).
4736
- """
4737
- rope_bytes = self._build_rope_bytes() # [256, MAX_BYTES, d]
4738
- e_chars_vocab = rope_bytes[
4739
- self.token_bytes,
4740
- torch.arange(self.MAX_BYTES, device=rope_bytes.device),
4741
- ].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
4742
- return self.char_norm(
4743
- e_chars_vocab.to(self.char_norm.weight.dtype)
4744
- ).to(self.byte_emb.weight.dtype) # [V, d]
4745
-
4746
  def set_byte_table(self, tokenizer) -> None:
4747
  """
4748
  Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
@@ -4852,7 +4836,7 @@ class SpellingBeeEmbedding(nn.Module):
4852
  # ── Step 3b: LayerNorm over character embeddings (float32 for stability)
4853
  # Mirrors reference impl (character_norm=True). Ensures E[β€–e_charsβ€–Β²]
4854
  # matches E[β€–e_tokβ€–Β²] regardless of token byte-length distribution.
4855
- e_chars_vocab = self.char_norm(e_chars_vocab.to(self.char_norm.weight.dtype)).to(token_embeds.dtype)
4856
 
4857
  # ── Step 4: gather only the tokens present in this batch ────────────
4858
  # This is the only BΓ—S operation β€” a single embedding lookup.
@@ -4894,7 +4878,7 @@ class SpellingBeeEmbedding(nn.Module):
4894
  self.token_bytes,
4895
  torch.arange(self.MAX_BYTES, device=rope_bytes.device),
4896
  ].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
4897
- e_chars_vocab = self.char_norm(e_chars_vocab.to(self.char_norm.weight.dtype)).to(token_emb_weight.dtype)
4898
  return (token_emb_weight + e_chars_vocab) * 0.5
4899
 
4900
 
@@ -5452,7 +5436,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
5452
  self.vocab_size = config.vocab_size
5453
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
5454
 
5455
- if config.use_token_generator or not config.tie_word_embeddings:
5456
  self._tied_weights_keys = {}
5457
 
5458
  # ── Analysis infrastructure ───────────────────────────────────────
@@ -5519,30 +5503,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
5519
 
5520
  # ── Standard model API ────────────────────────────────────────────────
5521
 
5522
- def _get_lm_head_weight(self) -> torch.Tensor:
5523
- """
5524
- Devuelve el weight efectivo del lm_head.
5525
-
5526
- - tie_word_embeddings=True: lm_head.weight == embed_tokens.weight
5527
- (tying estΓ‘ndar, spelling bee solo en entrada).
5528
-
5529
- - tie_word_embeddings=False + use_spelling_bee_embeddings=True:
5530
- blendea lm_head.weight con e_chars del vocabulario completo,
5531
- siguiendo la implementaciΓ³n de referencia del paper
5532
- (littletrainingloop, spelling_bee_out, separate_token_embedding=False).
5533
- El byte_emb compartido recibe gradiente desde ambos paths.
5534
-
5535
- - tie_word_embeddings=False sin spelling bee:
5536
- devuelve lm_head.weight directamente.
5537
- """
5538
- weights = self.lm_head.weight # [V, d]
5539
- if (not self.config.tie_word_embeddings
5540
- and getattr(self.config, "use_spelling_bee_embeddings", False)
5541
- and self.model.spelling_bee is not None):
5542
- e_chars = self.model.spelling_bee.get_char_embeddings_vocab() # [V, d]
5543
- weights = (weights + e_chars.to(weights.dtype)) * 0.5
5544
- return weights
5545
-
5546
  def get_input_embeddings(self):
5547
  return self.model.get_input_embeddings()
5548
 
@@ -5622,7 +5582,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
5622
  loss = None
5623
  if labels is not None:
5624
  loss = compute_cce_loss(
5625
- hidden_states, labels, self._get_lm_head_weight(),
5626
  getattr(self.lm_head, "bias", None), self.config.pad_token_id,
5627
  )
5628
  # Add JTok-M load-balancing auxiliary loss
@@ -5657,9 +5617,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
5657
  slice(-logits_to_keep, None)
5658
  if isinstance(logits_to_keep, int) else logits_to_keep
5659
  )
5660
- logits = torch.nn.functional.linear(
5661
- hidden_states[:, slice_indices, :], self._get_lm_head_weight()
5662
- )
5663
 
5664
  # ── Finalise and store analysis state ─────────────────────────────
5665
  if analysis_state is not None:
 
4727
  missing_keys, unexpected_keys, error_msgs,
4728
  )
4729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4730
  def set_byte_table(self, tokenizer) -> None:
4731
  """
4732
  Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
 
4836
  # ── Step 3b: LayerNorm over character embeddings (float32 for stability)
4837
  # Mirrors reference impl (character_norm=True). Ensures E[β€–e_charsβ€–Β²]
4838
  # matches E[β€–e_tokβ€–Β²] regardless of token byte-length distribution.
4839
+ e_chars_vocab = self.char_norm(e_chars_vocab.float()).to(token_embeds.dtype)
4840
 
4841
  # ── Step 4: gather only the tokens present in this batch ────────────
4842
  # This is the only BΓ—S operation β€” a single embedding lookup.
 
4878
  self.token_bytes,
4879
  torch.arange(self.MAX_BYTES, device=rope_bytes.device),
4880
  ].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d]
4881
+ e_chars_vocab = self.char_norm(e_chars_vocab.float()).to(token_emb_weight.dtype)
4882
  return (token_emb_weight + e_chars_vocab) * 0.5
4883
 
4884
 
 
5436
  self.vocab_size = config.vocab_size
5437
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
5438
 
5439
+ if config.use_token_generator:
5440
  self._tied_weights_keys = {}
5441
 
5442
  # ── Analysis infrastructure ───────────────────────────────────────
 
5503
 
5504
  # ── Standard model API ────────────────────────────────────────────────
5505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5506
  def get_input_embeddings(self):
5507
  return self.model.get_input_embeddings()
5508
 
 
5582
  loss = None
5583
  if labels is not None:
5584
  loss = compute_cce_loss(
5585
+ hidden_states, labels, self.lm_head.weight,
5586
  getattr(self.lm_head, "bias", None), self.config.pad_token_id,
5587
  )
5588
  # Add JTok-M load-balancing auxiliary loss
 
5617
  slice(-logits_to_keep, None)
5618
  if isinstance(logits_to_keep, int) else logits_to_keep
5619
  )
5620
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
 
 
5621
 
5622
  # ── Finalise and store analysis state ─────────────────────────────
5623
  if analysis_state is not None: