KitsuVp commited on
Commit
d086557
Β·
verified Β·
1 Parent(s): 9fd2b0d

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +278 -60
modeling_neollm.py CHANGED
@@ -5458,7 +5458,13 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
5458
  self.hidden_size = config.hidden_size
5459
  self.layer_idx = layer_idx
5460
  self.use_jtokm = config.use_jtokm
5461
- self.use_seednorm = bool(getattr(config, "use_seednorm", True))
 
 
 
 
 
 
5462
  # Controls only the first pre-attention normalisation applied directly
5463
  # to the embedding stream. Defaults to True for checkpoint/config
5464
  # backward compatibility. When False, layer 0 does not instantiate
@@ -5467,7 +5473,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
5467
  self.use_embedding_input_norm = bool(
5468
  getattr(config, "use_embedding_input_norm", True)
5469
  )
5470
- self.has_input_layernorm = not (
5471
  self.layer_idx == 0 and not self.use_embedding_input_norm
5472
  )
5473
 
@@ -5478,24 +5484,55 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
5478
  else NeoLLMMLP(config)
5479
  )
5480
  self.use_versatile_ffn = getattr(config, "use_versatile_ffn", False)
5481
- self.input_layernorm = (
5482
- _make_norm(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5483
  config.hidden_size,
5484
  eps=config.rms_norm_eps,
5485
  use_seednorm=self.use_seednorm,
5486
  )
5487
- if self.has_input_layernorm
5488
- else None
5489
- )
5490
- self.post_attention_layernorm = _make_norm(
5491
- config.hidden_size,
5492
- eps=config.rms_norm_eps,
5493
- use_seednorm=self.use_seednorm,
5494
- )
5495
- self.lns_attn = LNS(layer_idx)
5496
- self.lns_mlp = LNS(layer_idx)
5497
- self.gpas_attn = GPAS(config.hidden_size)
5498
- self.gpas_mlp = GPAS(config.hidden_size)
5499
  self.current_layer_fan = None
5500
 
5501
  # ── StackMemory / STACKTRANS (Zhang et al., NeurIPS 2025) ────────
@@ -5750,6 +5787,129 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
5750
  # LAUREL-LR (paper eq. 3): f(x) + BAx + x
5751
  return delta + lr_delta + residual
5752
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5753
  def forward(
5754
  self,
5755
  hidden_states: torch.Tensor,
@@ -5830,12 +5990,16 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
5830
  k_normed = _apply_norm(self.input_layernorm, dca_k_input, analysis=None)
5831
  v_normed = _apply_norm(self.input_layernorm, dca_v_input, analysis=None)
5832
 
5833
- h_lns = self.lns_attn(h_normed)
5834
  if dca_k_input is not None:
5835
- k_lns = self.lns_attn(k_normed)
5836
- v_lns = self.lns_attn(v_normed)
 
 
 
 
5837
  dca_key_value_states = (k_lns, v_lns)
5838
- if layer_analysis is not None:
5839
  layer_analysis.lns_attn_output = h_lns.detach()
5840
 
5841
  hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
@@ -5870,7 +6034,11 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
5870
  else:
5871
  attn_aug = residual_attn + hidden_states
5872
  dca_final_residual = hidden_states if self.use_dca else None
5873
- h_tilde = self.gpas_attn(attn_aug, analysis=gpas_attn_a)
 
 
 
 
5874
 
5875
  if layer_analysis is not None:
5876
  layer_analysis.h_tilde = h_tilde.detach()
@@ -5892,8 +6060,8 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
5892
  # ── MLP block ─────────────────────────────────────────────────────
5893
  sn_post = layer_analysis.seednorm_post_attn if layer_analysis is not None else None
5894
  h_normed2 = _apply_norm(self.post_attention_layernorm, h_mlp, analysis=sn_post)
5895
- h_lns2 = self.lns_mlp(h_normed2)
5896
- if layer_analysis is not None:
5897
  layer_analysis.lns_mlp_output = h_lns2.detach()
5898
 
5899
  mlp_a = layer_analysis.mlp if layer_analysis is not None else None
@@ -5934,11 +6102,19 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
5934
  delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
5935
  delta_r = delta_r.reshape(orig_shape)
5936
 
5937
- gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
5938
- hidden_states = self.gpas_mlp(mlp_aug + delta_r, analysis=gpas_mlp_a)
 
 
 
 
5939
  else:
5940
- gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
5941
- hidden_states = self.gpas_mlp(mlp_aug, analysis=gpas_mlp_a)
 
 
 
 
5942
 
5943
  if layer_analysis is not None:
5944
  layer_analysis.hidden_states_output = hidden_states.detach()
@@ -6378,6 +6554,8 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
6378
 
6379
  module.res_weight.data.fill_(1.0)
6380
  elif isinstance(module, NeoLLMDecoderLayer):
 
 
6381
  # AttnRes pseudo-queries: MUST be initialized to zero.
6382
  # Zero initialization ensures uniform attention weights at step 0
6383
  # (softmax of zeros is uniform), making AttnRes equivalent to a
@@ -6468,11 +6646,21 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
6468
  [NeoLLMDecoderLayer(config, layer_idx)
6469
  for layer_idx in range(config.num_hidden_layers)]
6470
  )
6471
- self.norm = _make_norm(
6472
- config.hidden_size,
6473
- eps=config.rms_norm_eps,
6474
- use_seednorm=bool(getattr(config, "use_seednorm", True)),
6475
- )
 
 
 
 
 
 
 
 
 
 
6476
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
6477
  self.gradient_checkpointing = False
6478
  self.first_layer_fan = None if getattr(config, "use_fan_residual", True) else False
@@ -6564,8 +6752,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
6564
  polynorm = PolyNormAnalysis() if not _versatile else None,
6565
  versatile = VersatileFFNAnalysis() if _versatile else None,
6566
  ),
6567
- gpas_attn = GPASAnalysis(),
6568
- gpas_mlp = GPASAnalysis(),
6569
  jtokm = JTokMAnalysis() if cfg.use_jtokm else None,
6570
  dca = DCAAnalysis() if getattr(cfg, "use_dca", False) else None,
6571
  attn_res = AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
@@ -6658,6 +6846,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
6658
  )
6659
 
6660
  hidden_states = inputs_embeds
 
 
6661
  all_hidden_states = () if output_hidden_states else None
6662
  all_attentions = () if output_attentions else None
6663
  all_aux_stats = []
@@ -6770,23 +6960,43 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
6770
 
6771
  dca_layer_sources = self._select_dca_sources(dca_sources) if use_dca else None
6772
 
6773
- layer_outputs = decoder_layer(
6774
- hidden_states,
6775
- position_embeddings=position_embeddings,
6776
- attention_mask=causal_mask,
6777
- first_layer_fan=self.first_layer_fan,
6778
- z_tilde=z_tilde,
6779
- B_vals=B_vals,
6780
- dca_sources=dca_layer_sources,
6781
- attn_res_sources=attn_res_sources,
6782
- attn_res_partial=attn_res_partial if use_attn_res else None,
6783
- layer_analysis=layer_analysis,
6784
- output_attentions=output_attentions,
6785
- repo_rope_args=repo_rope_args,
6786
- position_ids=position_ids,
6787
- **kwargs,
6788
- )
6789
- hidden_states = layer_outputs[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6790
 
6791
  # Update AttnRes partial sum β€” the new partial is the layer output
6792
  if use_attn_res:
@@ -6797,18 +7007,20 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
6797
  dca_sources = dca_sources + [hidden_states]
6798
 
6799
  if output_attentions:
6800
- all_attentions = all_attentions + (layer_outputs[1],)
 
6801
 
6802
- # Collect JTok-M aux stats (last element if present)
6803
- if self.config.use_jtokm and len(layer_outputs) > (2 if output_attentions else 1):
6804
- all_aux_stats.append(layer_outputs[-1])
 
 
 
6805
 
6806
- # Collect VersatileFFN aux stats (second-to-last if jtokm also present,
6807
- # or last if jtokm is absent). Only non-None during training.
6808
  if getattr(self.config, "use_versatile_ffn", False):
6809
- for item in layer_outputs[1:]:
6810
  if isinstance(item, tuple) and len(item) == 3:
6811
- # (p_sum, f_sum, N_tokens) signature
6812
  all_aux_stats.append(("versatile", item))
6813
  break
6814
 
@@ -6817,7 +7029,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
6817
  and hasattr(decoder_layer, "current_layer_fan")):
6818
  self.first_layer_fan = decoder_layer.current_layer_fan
6819
 
6820
- hidden_states = self.norm(hidden_states)
 
 
 
 
 
 
6821
 
6822
  if output_hidden_states:
6823
  all_hidden_states = all_hidden_states + (hidden_states,)
 
5458
  self.hidden_size = config.hidden_size
5459
  self.layer_idx = layer_idx
5460
  self.use_jtokm = config.use_jtokm
5461
+ self.use_seednorm = bool(getattr(config, "use_seednorm", True))
5462
+ self.use_lns = bool(getattr(config, "use_lns", False))
5463
+ self.use_gpas = bool(getattr(config, "use_gpas", False))
5464
+ self.use_siamesenorm = bool(getattr(config, "use_siamesenorm", False))
5465
+ self.siamese_normalized_input = bool(getattr(config, "siamese_normalized_input", True))
5466
+ self.siamese_depth_scaling = bool(getattr(config, "siamese_depth_scaling", True))
5467
+ self.siamese_attn_x_scale_init = float(getattr(config, "siamese_attn_x_scale_init", 1.0))
5468
  # Controls only the first pre-attention normalisation applied directly
5469
  # to the embedding stream. Defaults to True for checkpoint/config
5470
  # backward compatibility. When False, layer 0 does not instantiate
 
5473
  self.use_embedding_input_norm = bool(
5474
  getattr(config, "use_embedding_input_norm", True)
5475
  )
5476
+ self.has_input_layernorm = (not self.use_siamesenorm) and not (
5477
  self.layer_idx == 0 and not self.use_embedding_input_norm
5478
  )
5479
 
 
5484
  else NeoLLMMLP(config)
5485
  )
5486
  self.use_versatile_ffn = getattr(config, "use_versatile_ffn", False)
5487
+ if self.use_siamesenorm:
5488
+ self.input_layernorm = None
5489
+ self.post_attention_layernorm = None
5490
+
5491
+ # SiameseNorm is RMS-only by config validation. These modules are
5492
+ # constructed only when the Siamese topology is active, so no
5493
+ # inactive SeeDNorm/RMSNorm pre-norm modules remain in the graph.
5494
+ self.siamese_attn_x_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
5495
+ self.siamese_attn_y_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
5496
+ self.siamese_mlp_x_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
5497
+ self.siamese_mlp_y_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
5498
+ self.siamese_attn_input_norm = (
5499
+ nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
5500
+ if self.siamese_normalized_input else None
5501
+ )
5502
+ self.siamese_mlp_input_norm = (
5503
+ nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
5504
+ if self.siamese_normalized_input else None
5505
+ )
5506
+ self.siamese_attn_x_scale = nn.Parameter(
5507
+ torch.full((config.hidden_size,), self.siamese_attn_x_scale_init, dtype=torch.float32)
5508
+ )
5509
+ else:
5510
+ self.input_layernorm = (
5511
+ _make_norm(
5512
+ config.hidden_size,
5513
+ eps=config.rms_norm_eps,
5514
+ use_seednorm=self.use_seednorm,
5515
+ )
5516
+ if self.has_input_layernorm
5517
+ else None
5518
+ )
5519
+ self.post_attention_layernorm = _make_norm(
5520
  config.hidden_size,
5521
  eps=config.rms_norm_eps,
5522
  use_seednorm=self.use_seednorm,
5523
  )
5524
+ self.siamese_attn_x_norm = None
5525
+ self.siamese_attn_y_norm = None
5526
+ self.siamese_mlp_x_norm = None
5527
+ self.siamese_mlp_y_norm = None
5528
+ self.siamese_attn_input_norm = None
5529
+ self.siamese_mlp_input_norm = None
5530
+ self.siamese_attn_x_scale = None
5531
+
5532
+ self.lns_attn = LNS(layer_idx) if self.use_lns else None
5533
+ self.lns_mlp = LNS(layer_idx) if self.use_lns else None
5534
+ self.gpas_attn = GPAS(config.hidden_size) if self.use_gpas else None
5535
+ self.gpas_mlp = GPAS(config.hidden_size) if self.use_gpas else None
5536
  self.current_layer_fan = None
5537
 
5538
  # ── StackMemory / STACKTRANS (Zhang et al., NeurIPS 2025) ────────
 
5787
  # LAUREL-LR (paper eq. 3): f(x) + BAx + x
5788
  return delta + lr_delta + residual
5789
 
5790
+ def _siamese_stream_scale(self, ref: torch.Tensor) -> torch.Tensor:
5791
+ if not self.siamese_depth_scaling:
5792
+ return ref.new_tensor(1.0)
5793
+ return ref.new_tensor(1.0 / math.sqrt(2.0 * float(self.layer_idx + 1)))
5794
+
5795
+ def forward_siamesenorm(
5796
+ self,
5797
+ x_states: torch.Tensor,
5798
+ y_states: torch.Tensor,
5799
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
5800
+ attention_mask: Optional[torch.Tensor] = None,
5801
+ first_layer_fan: Optional[torch.Tensor] = None,
5802
+ z_tilde: Optional[torch.Tensor] = None,
5803
+ B_vals: Optional[torch.Tensor] = None,
5804
+ layer_analysis: Optional[LayerAnalysis] = None,
5805
+ output_attentions: Optional[bool] = False,
5806
+ repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
5807
+ position_ids: Optional[torch.LongTensor] = None,
5808
+ **kwargs: Unpack[FlashAttentionKwargs],
5809
+ ) -> Tuple:
5810
+ # SiameseNorm keeps two coupled streams with shared Attention/MLP
5811
+ # parameters. All Siamese normalization modules are RMSNorm by
5812
+ # construction; SeeDNorm is rejected at config validation time.
5813
+ if layer_analysis is not None:
5814
+ layer_analysis.hidden_states_input = x_states.detach()
5815
+
5816
+ # ── Attention shared block ────────────────────────────────────────
5817
+ sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
5818
+ x_attn_norm = self.siamese_attn_x_norm(x_states)
5819
+ y_attn_norm = self.siamese_attn_y_norm(y_states)
5820
+ if sn_pre is not None:
5821
+ sn_pre.output = x_attn_norm.detach()
5822
+
5823
+ x_scale = self.siamese_attn_x_scale.to(dtype=x_attn_norm.dtype, device=x_attn_norm.device)
5824
+ h_attn = x_scale * x_attn_norm + y_attn_norm
5825
+ if self.siamese_attn_input_norm is not None:
5826
+ h_attn = self.siamese_attn_input_norm(h_attn)
5827
+
5828
+ h_lns = self.lns_attn(h_attn) if self.use_lns else h_attn
5829
+ if layer_analysis is not None and self.use_lns:
5830
+ layer_analysis.lns_attn_output = h_lns.detach()
5831
+
5832
+ attn_out, attn_weights, self.current_layer_fan = self.self_attn(
5833
+ hidden_states=h_lns,
5834
+ key_value_states=None,
5835
+ attention_mask=attention_mask,
5836
+ position_embeddings=position_embeddings,
5837
+ first_layer_fan=first_layer_fan,
5838
+ attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
5839
+ repo_rope_args=repo_rope_args,
5840
+ position_ids=position_ids,
5841
+ **kwargs,
5842
+ )
5843
+
5844
+ if layer_analysis is not None:
5845
+ layer_analysis.attn_contribution = attn_out.detach()
5846
+
5847
+ stream_scale = self._siamese_stream_scale(attn_out)
5848
+ x_after_attn = x_states + stream_scale * attn_out
5849
+ y_after_attn = y_states + attn_out
5850
+
5851
+ gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
5852
+ if self.use_gpas:
5853
+ x_after_attn = self.gpas_attn(x_after_attn, analysis=gpas_attn_a)
5854
+
5855
+ if layer_analysis is not None:
5856
+ layer_analysis.h_tilde = x_after_attn.detach()
5857
+
5858
+ # ── MLP shared block ──────────────────────────────────────────────
5859
+ sn_post = layer_analysis.seednorm_post_attn if layer_analysis is not None else None
5860
+ x_mlp_norm = self.siamese_mlp_x_norm(x_after_attn)
5861
+ y_mlp_norm = self.siamese_mlp_y_norm(y_after_attn)
5862
+ if sn_post is not None:
5863
+ sn_post.output = x_mlp_norm.detach()
5864
+
5865
+ h_mlp = x_mlp_norm + y_mlp_norm
5866
+ if self.siamese_mlp_input_norm is not None:
5867
+ h_mlp = self.siamese_mlp_input_norm(h_mlp)
5868
+
5869
+ h_lns2 = self.lns_mlp(h_mlp) if self.use_lns else h_mlp
5870
+ if layer_analysis is not None and self.use_lns:
5871
+ layer_analysis.lns_mlp_output = h_lns2.detach()
5872
+
5873
+ mlp_a = layer_analysis.mlp if layer_analysis is not None else None
5874
+ if self.use_versatile_ffn:
5875
+ delta_m, versatile_aux = self.mlp(h_lns2, analysis=mlp_a)
5876
+ else:
5877
+ delta_m = self.mlp(h_lns2, analysis=mlp_a)
5878
+ versatile_aux = None
5879
+
5880
+ if layer_analysis is not None:
5881
+ layer_analysis.mlp_contribution = delta_m.detach()
5882
+
5883
+ shared_update = delta_m
5884
+ aux_stats = None
5885
+ if self.use_jtokm and z_tilde is not None and B_vals is not None:
5886
+ orig_shape = x_after_attn.shape
5887
+ h_flat = x_after_attn.reshape(-1, self.hidden_size)
5888
+ z_flat = z_tilde.reshape(-1, z_tilde.shape[-1])
5889
+ B_flat = B_vals.reshape(-1, B_vals.shape[-2], B_vals.shape[-1])
5890
+ jtokm_a = layer_analysis.jtokm if layer_analysis is not None else None
5891
+ delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
5892
+ shared_update = shared_update + delta_r.reshape(orig_shape)
5893
+
5894
+ x_next = x_after_attn + stream_scale * shared_update
5895
+ y_next = y_after_attn + shared_update
5896
+
5897
+ gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
5898
+ if self.use_gpas:
5899
+ x_next = self.gpas_mlp(x_next, analysis=gpas_mlp_a)
5900
+
5901
+ if layer_analysis is not None:
5902
+ layer_analysis.hidden_states_output = x_next.detach()
5903
+
5904
+ outputs = (x_next, y_next)
5905
+ if output_attentions:
5906
+ outputs += (attn_weights,)
5907
+ if aux_stats is not None:
5908
+ outputs += (aux_stats,)
5909
+ if versatile_aux is not None:
5910
+ outputs += (versatile_aux,)
5911
+ return outputs
5912
+
5913
  def forward(
5914
  self,
5915
  hidden_states: torch.Tensor,
 
5990
  k_normed = _apply_norm(self.input_layernorm, dca_k_input, analysis=None)
5991
  v_normed = _apply_norm(self.input_layernorm, dca_v_input, analysis=None)
5992
 
5993
+ h_lns = self.lns_attn(h_normed) if self.use_lns else h_normed
5994
  if dca_k_input is not None:
5995
+ if self.use_lns:
5996
+ k_lns = self.lns_attn(k_normed)
5997
+ v_lns = self.lns_attn(v_normed)
5998
+ else:
5999
+ k_lns = k_normed
6000
+ v_lns = v_normed
6001
  dca_key_value_states = (k_lns, v_lns)
6002
+ if layer_analysis is not None and self.use_lns:
6003
  layer_analysis.lns_attn_output = h_lns.detach()
6004
 
6005
  hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
 
6034
  else:
6035
  attn_aug = residual_attn + hidden_states
6036
  dca_final_residual = hidden_states if self.use_dca else None
6037
+ h_tilde = (
6038
+ self.gpas_attn(attn_aug, analysis=gpas_attn_a)
6039
+ if self.use_gpas
6040
+ else attn_aug
6041
+ )
6042
 
6043
  if layer_analysis is not None:
6044
  layer_analysis.h_tilde = h_tilde.detach()
 
6060
  # ── MLP block ─────────────────────────────────────────────────────
6061
  sn_post = layer_analysis.seednorm_post_attn if layer_analysis is not None else None
6062
  h_normed2 = _apply_norm(self.post_attention_layernorm, h_mlp, analysis=sn_post)
6063
+ h_lns2 = self.lns_mlp(h_normed2) if self.use_lns else h_normed2
6064
+ if layer_analysis is not None and self.use_lns:
6065
  layer_analysis.lns_mlp_output = h_lns2.detach()
6066
 
6067
  mlp_a = layer_analysis.mlp if layer_analysis is not None else None
 
6102
  delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
6103
  delta_r = delta_r.reshape(orig_shape)
6104
 
6105
+ gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
6106
+ hidden_states = (
6107
+ self.gpas_mlp(mlp_aug + delta_r, analysis=gpas_mlp_a)
6108
+ if self.use_gpas
6109
+ else mlp_aug + delta_r
6110
+ )
6111
  else:
6112
+ gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
6113
+ hidden_states = (
6114
+ self.gpas_mlp(mlp_aug, analysis=gpas_mlp_a)
6115
+ if self.use_gpas
6116
+ else mlp_aug
6117
+ )
6118
 
6119
  if layer_analysis is not None:
6120
  layer_analysis.hidden_states_output = hidden_states.detach()
 
6554
 
6555
  module.res_weight.data.fill_(1.0)
6556
  elif isinstance(module, NeoLLMDecoderLayer):
6557
+ if hasattr(module, "siamese_attn_x_scale") and module.siamese_attn_x_scale is not None:
6558
+ module.siamese_attn_x_scale.data.fill_(module.siamese_attn_x_scale_init)
6559
  # AttnRes pseudo-queries: MUST be initialized to zero.
6560
  # Zero initialization ensures uniform attention weights at step 0
6561
  # (softmax of zeros is uniform), making AttnRes equivalent to a
 
6646
  [NeoLLMDecoderLayer(config, layer_idx)
6647
  for layer_idx in range(config.num_hidden_layers)]
6648
  )
6649
+ self.use_siamesenorm = bool(getattr(config, "use_siamesenorm", False))
6650
+ if self.use_siamesenorm:
6651
+ self.norm = None
6652
+ self.siamese_x_final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
6653
+ self.siamese_y_final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
6654
+ self.siamese_final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
6655
+ else:
6656
+ self.norm = _make_norm(
6657
+ config.hidden_size,
6658
+ eps=config.rms_norm_eps,
6659
+ use_seednorm=bool(getattr(config, "use_seednorm", True)),
6660
+ )
6661
+ self.siamese_x_final_norm = None
6662
+ self.siamese_y_final_norm = None
6663
+ self.siamese_final_norm = None
6664
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
6665
  self.gradient_checkpointing = False
6666
  self.first_layer_fan = None if getattr(config, "use_fan_residual", True) else False
 
6752
  polynorm = PolyNormAnalysis() if not _versatile else None,
6753
  versatile = VersatileFFNAnalysis() if _versatile else None,
6754
  ),
6755
+ gpas_attn = GPASAnalysis() if getattr(cfg, "use_gpas", False) else None,
6756
+ gpas_mlp = GPASAnalysis() if getattr(cfg, "use_gpas", False) else None,
6757
  jtokm = JTokMAnalysis() if cfg.use_jtokm else None,
6758
  dca = DCAAnalysis() if getattr(cfg, "use_dca", False) else None,
6759
  attn_res = AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
 
6846
  )
6847
 
6848
  hidden_states = inputs_embeds
6849
+ use_siamesenorm = bool(getattr(self.config, "use_siamesenorm", False))
6850
+ siamese_y_states = inputs_embeds if use_siamesenorm else None
6851
  all_hidden_states = () if output_hidden_states else None
6852
  all_attentions = () if output_attentions else None
6853
  all_aux_stats = []
 
6960
 
6961
  dca_layer_sources = self._select_dca_sources(dca_sources) if use_dca else None
6962
 
6963
+ if use_siamesenorm:
6964
+ layer_outputs = decoder_layer.forward_siamesenorm(
6965
+ hidden_states,
6966
+ siamese_y_states,
6967
+ position_embeddings=position_embeddings,
6968
+ attention_mask=causal_mask,
6969
+ first_layer_fan=self.first_layer_fan,
6970
+ z_tilde=z_tilde,
6971
+ B_vals=B_vals,
6972
+ layer_analysis=layer_analysis,
6973
+ output_attentions=output_attentions,
6974
+ repo_rope_args=repo_rope_args,
6975
+ position_ids=position_ids,
6976
+ **kwargs,
6977
+ )
6978
+ hidden_states = layer_outputs[0]
6979
+ siamese_y_states = layer_outputs[1]
6980
+ extras_start = 2
6981
+ else:
6982
+ layer_outputs = decoder_layer(
6983
+ hidden_states,
6984
+ position_embeddings=position_embeddings,
6985
+ attention_mask=causal_mask,
6986
+ first_layer_fan=self.first_layer_fan,
6987
+ z_tilde=z_tilde,
6988
+ B_vals=B_vals,
6989
+ dca_sources=dca_layer_sources,
6990
+ attn_res_sources=attn_res_sources,
6991
+ attn_res_partial=attn_res_partial if use_attn_res else None,
6992
+ layer_analysis=layer_analysis,
6993
+ output_attentions=output_attentions,
6994
+ repo_rope_args=repo_rope_args,
6995
+ position_ids=position_ids,
6996
+ **kwargs,
6997
+ )
6998
+ hidden_states = layer_outputs[0]
6999
+ extras_start = 1
7000
 
7001
  # Update AttnRes partial sum β€” the new partial is the layer output
7002
  if use_attn_res:
 
7007
  dca_sources = dca_sources + [hidden_states]
7008
 
7009
  if output_attentions:
7010
+ all_attentions = all_attentions + (layer_outputs[extras_start],)
7011
+ extras_start += 1
7012
 
7013
+ # Collect JTok-M aux stats.
7014
+ if self.config.use_jtokm:
7015
+ for item in layer_outputs[extras_start:]:
7016
+ if isinstance(item, tuple) and len(item) == 3:
7017
+ all_aux_stats.append(item)
7018
+ break
7019
 
7020
+ # Collect VersatileFFN aux stats. Only non-None during training.
 
7021
  if getattr(self.config, "use_versatile_ffn", False):
7022
+ for item in layer_outputs[extras_start:]:
7023
  if isinstance(item, tuple) and len(item) == 3:
 
7024
  all_aux_stats.append(("versatile", item))
7025
  break
7026
 
 
7029
  and hasattr(decoder_layer, "current_layer_fan")):
7030
  self.first_layer_fan = decoder_layer.current_layer_fan
7031
 
7032
+ if use_siamesenorm:
7033
+ hidden_states = self.siamese_final_norm(
7034
+ self.siamese_x_final_norm(hidden_states)
7035
+ + self.siamese_y_final_norm(siamese_y_states)
7036
+ )
7037
+ else:
7038
+ hidden_states = self.norm(hidden_states)
7039
 
7040
  if output_hidden_states:
7041
  all_hidden_states = all_hidden_states + (hidden_states,)