Update modeling_neollm.py
Browse files- modeling_neollm.py +278 -60
modeling_neollm.py
CHANGED
|
@@ -5458,7 +5458,13 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 5458 |
self.hidden_size = config.hidden_size
|
| 5459 |
self.layer_idx = layer_idx
|
| 5460 |
self.use_jtokm = config.use_jtokm
|
| 5461 |
-
self.use_seednorm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5462 |
# Controls only the first pre-attention normalisation applied directly
|
| 5463 |
# to the embedding stream. Defaults to True for checkpoint/config
|
| 5464 |
# backward compatibility. When False, layer 0 does not instantiate
|
|
@@ -5467,7 +5473,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 5467 |
self.use_embedding_input_norm = bool(
|
| 5468 |
getattr(config, "use_embedding_input_norm", True)
|
| 5469 |
)
|
| 5470 |
-
self.has_input_layernorm = not (
|
| 5471 |
self.layer_idx == 0 and not self.use_embedding_input_norm
|
| 5472 |
)
|
| 5473 |
|
|
@@ -5478,24 +5484,55 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 5478 |
else NeoLLMMLP(config)
|
| 5479 |
)
|
| 5480 |
self.use_versatile_ffn = getattr(config, "use_versatile_ffn", False)
|
| 5481 |
-
self.
|
| 5482 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5483 |
config.hidden_size,
|
| 5484 |
eps=config.rms_norm_eps,
|
| 5485 |
use_seednorm=self.use_seednorm,
|
| 5486 |
)
|
| 5487 |
-
|
| 5488 |
-
|
| 5489 |
-
|
| 5490 |
-
|
| 5491 |
-
|
| 5492 |
-
|
| 5493 |
-
|
| 5494 |
-
|
| 5495 |
-
self.lns_attn = LNS(layer_idx)
|
| 5496 |
-
self.lns_mlp = LNS(layer_idx)
|
| 5497 |
-
self.gpas_attn = GPAS(config.hidden_size)
|
| 5498 |
-
self.gpas_mlp = GPAS(config.hidden_size)
|
| 5499 |
self.current_layer_fan = None
|
| 5500 |
|
| 5501 |
# ── StackMemory / STACKTRANS (Zhang et al., NeurIPS 2025) ────────
|
|
@@ -5750,6 +5787,129 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 5750 |
# LAUREL-LR (paper eq. 3): f(x) + BAx + x
|
| 5751 |
return delta + lr_delta + residual
|
| 5752 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5753 |
def forward(
|
| 5754 |
self,
|
| 5755 |
hidden_states: torch.Tensor,
|
|
@@ -5830,12 +5990,16 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 5830 |
k_normed = _apply_norm(self.input_layernorm, dca_k_input, analysis=None)
|
| 5831 |
v_normed = _apply_norm(self.input_layernorm, dca_v_input, analysis=None)
|
| 5832 |
|
| 5833 |
-
h_lns
|
| 5834 |
if dca_k_input is not None:
|
| 5835 |
-
|
| 5836 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5837 |
dca_key_value_states = (k_lns, v_lns)
|
| 5838 |
-
if layer_analysis is not None:
|
| 5839 |
layer_analysis.lns_attn_output = h_lns.detach()
|
| 5840 |
|
| 5841 |
hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
|
|
@@ -5870,7 +6034,11 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 5870 |
else:
|
| 5871 |
attn_aug = residual_attn + hidden_states
|
| 5872 |
dca_final_residual = hidden_states if self.use_dca else None
|
| 5873 |
-
h_tilde =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5874 |
|
| 5875 |
if layer_analysis is not None:
|
| 5876 |
layer_analysis.h_tilde = h_tilde.detach()
|
|
@@ -5892,8 +6060,8 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 5892 |
# ── MLP block ─────────────────────────────────────────────────────
|
| 5893 |
sn_post = layer_analysis.seednorm_post_attn if layer_analysis is not None else None
|
| 5894 |
h_normed2 = _apply_norm(self.post_attention_layernorm, h_mlp, analysis=sn_post)
|
| 5895 |
-
h_lns2
|
| 5896 |
-
if layer_analysis is not None:
|
| 5897 |
layer_analysis.lns_mlp_output = h_lns2.detach()
|
| 5898 |
|
| 5899 |
mlp_a = layer_analysis.mlp if layer_analysis is not None else None
|
|
@@ -5934,11 +6102,19 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 5934 |
delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
|
| 5935 |
delta_r = delta_r.reshape(orig_shape)
|
| 5936 |
|
| 5937 |
-
gpas_mlp_a
|
| 5938 |
-
hidden_states =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5939 |
else:
|
| 5940 |
-
gpas_mlp_a
|
| 5941 |
-
hidden_states =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5942 |
|
| 5943 |
if layer_analysis is not None:
|
| 5944 |
layer_analysis.hidden_states_output = hidden_states.detach()
|
|
@@ -6378,6 +6554,8 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
| 6378 |
|
| 6379 |
module.res_weight.data.fill_(1.0)
|
| 6380 |
elif isinstance(module, NeoLLMDecoderLayer):
|
|
|
|
|
|
|
| 6381 |
# AttnRes pseudo-queries: MUST be initialized to zero.
|
| 6382 |
# Zero initialization ensures uniform attention weights at step 0
|
| 6383 |
# (softmax of zeros is uniform), making AttnRes equivalent to a
|
|
@@ -6468,11 +6646,21 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 6468 |
[NeoLLMDecoderLayer(config, layer_idx)
|
| 6469 |
for layer_idx in range(config.num_hidden_layers)]
|
| 6470 |
)
|
| 6471 |
-
self.
|
| 6472 |
-
|
| 6473 |
-
|
| 6474 |
-
|
| 6475 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6476 |
self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
|
| 6477 |
self.gradient_checkpointing = False
|
| 6478 |
self.first_layer_fan = None if getattr(config, "use_fan_residual", True) else False
|
|
@@ -6564,8 +6752,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 6564 |
polynorm = PolyNormAnalysis() if not _versatile else None,
|
| 6565 |
versatile = VersatileFFNAnalysis() if _versatile else None,
|
| 6566 |
),
|
| 6567 |
-
gpas_attn = GPASAnalysis(),
|
| 6568 |
-
gpas_mlp = GPASAnalysis(),
|
| 6569 |
jtokm = JTokMAnalysis() if cfg.use_jtokm else None,
|
| 6570 |
dca = DCAAnalysis() if getattr(cfg, "use_dca", False) else None,
|
| 6571 |
attn_res = AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
|
|
@@ -6658,6 +6846,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 6658 |
)
|
| 6659 |
|
| 6660 |
hidden_states = inputs_embeds
|
|
|
|
|
|
|
| 6661 |
all_hidden_states = () if output_hidden_states else None
|
| 6662 |
all_attentions = () if output_attentions else None
|
| 6663 |
all_aux_stats = []
|
|
@@ -6770,23 +6960,43 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 6770 |
|
| 6771 |
dca_layer_sources = self._select_dca_sources(dca_sources) if use_dca else None
|
| 6772 |
|
| 6773 |
-
|
| 6774 |
-
|
| 6775 |
-
|
| 6776 |
-
|
| 6777 |
-
|
| 6778 |
-
|
| 6779 |
-
|
| 6780 |
-
|
| 6781 |
-
|
| 6782 |
-
|
| 6783 |
-
|
| 6784 |
-
|
| 6785 |
-
|
| 6786 |
-
|
| 6787 |
-
|
| 6788 |
-
|
| 6789 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6790 |
|
| 6791 |
# Update AttnRes partial sum — the new partial is the layer output
|
| 6792 |
if use_attn_res:
|
|
@@ -6797,18 +7007,20 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 6797 |
dca_sources = dca_sources + [hidden_states]
|
| 6798 |
|
| 6799 |
if output_attentions:
|
| 6800 |
-
all_attentions = all_attentions + (layer_outputs[
|
|
|
|
| 6801 |
|
| 6802 |
-
# Collect JTok-M aux stats
|
| 6803 |
-
if self.config.use_jtokm
|
| 6804 |
-
|
|
|
|
|
|
|
|
|
|
| 6805 |
|
| 6806 |
-
# Collect VersatileFFN aux stats
|
| 6807 |
-
# or last if jtokm is absent). Only non-None during training.
|
| 6808 |
if getattr(self.config, "use_versatile_ffn", False):
|
| 6809 |
-
for item in layer_outputs[
|
| 6810 |
if isinstance(item, tuple) and len(item) == 3:
|
| 6811 |
-
# (p_sum, f_sum, N_tokens) signature
|
| 6812 |
all_aux_stats.append(("versatile", item))
|
| 6813 |
break
|
| 6814 |
|
|
@@ -6817,7 +7029,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 6817 |
and hasattr(decoder_layer, "current_layer_fan")):
|
| 6818 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 6819 |
|
| 6820 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6821 |
|
| 6822 |
if output_hidden_states:
|
| 6823 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
| 5458 |
self.hidden_size = config.hidden_size
|
| 5459 |
self.layer_idx = layer_idx
|
| 5460 |
self.use_jtokm = config.use_jtokm
|
| 5461 |
+
self.use_seednorm = bool(getattr(config, "use_seednorm", True))
|
| 5462 |
+
self.use_lns = bool(getattr(config, "use_lns", False))
|
| 5463 |
+
self.use_gpas = bool(getattr(config, "use_gpas", False))
|
| 5464 |
+
self.use_siamesenorm = bool(getattr(config, "use_siamesenorm", False))
|
| 5465 |
+
self.siamese_normalized_input = bool(getattr(config, "siamese_normalized_input", True))
|
| 5466 |
+
self.siamese_depth_scaling = bool(getattr(config, "siamese_depth_scaling", True))
|
| 5467 |
+
self.siamese_attn_x_scale_init = float(getattr(config, "siamese_attn_x_scale_init", 1.0))
|
| 5468 |
# Controls only the first pre-attention normalisation applied directly
|
| 5469 |
# to the embedding stream. Defaults to True for checkpoint/config
|
| 5470 |
# backward compatibility. When False, layer 0 does not instantiate
|
|
|
|
| 5473 |
self.use_embedding_input_norm = bool(
|
| 5474 |
getattr(config, "use_embedding_input_norm", True)
|
| 5475 |
)
|
| 5476 |
+
self.has_input_layernorm = (not self.use_siamesenorm) and not (
|
| 5477 |
self.layer_idx == 0 and not self.use_embedding_input_norm
|
| 5478 |
)
|
| 5479 |
|
|
|
|
| 5484 |
else NeoLLMMLP(config)
|
| 5485 |
)
|
| 5486 |
self.use_versatile_ffn = getattr(config, "use_versatile_ffn", False)
|
| 5487 |
+
if self.use_siamesenorm:
|
| 5488 |
+
self.input_layernorm = None
|
| 5489 |
+
self.post_attention_layernorm = None
|
| 5490 |
+
|
| 5491 |
+
# SiameseNorm is RMS-only by config validation. These modules are
|
| 5492 |
+
# constructed only when the Siamese topology is active, so no
|
| 5493 |
+
# inactive SeeDNorm/RMSNorm pre-norm modules remain in the graph.
|
| 5494 |
+
self.siamese_attn_x_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 5495 |
+
self.siamese_attn_y_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 5496 |
+
self.siamese_mlp_x_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 5497 |
+
self.siamese_mlp_y_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 5498 |
+
self.siamese_attn_input_norm = (
|
| 5499 |
+
nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 5500 |
+
if self.siamese_normalized_input else None
|
| 5501 |
+
)
|
| 5502 |
+
self.siamese_mlp_input_norm = (
|
| 5503 |
+
nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 5504 |
+
if self.siamese_normalized_input else None
|
| 5505 |
+
)
|
| 5506 |
+
self.siamese_attn_x_scale = nn.Parameter(
|
| 5507 |
+
torch.full((config.hidden_size,), self.siamese_attn_x_scale_init, dtype=torch.float32)
|
| 5508 |
+
)
|
| 5509 |
+
else:
|
| 5510 |
+
self.input_layernorm = (
|
| 5511 |
+
_make_norm(
|
| 5512 |
+
config.hidden_size,
|
| 5513 |
+
eps=config.rms_norm_eps,
|
| 5514 |
+
use_seednorm=self.use_seednorm,
|
| 5515 |
+
)
|
| 5516 |
+
if self.has_input_layernorm
|
| 5517 |
+
else None
|
| 5518 |
+
)
|
| 5519 |
+
self.post_attention_layernorm = _make_norm(
|
| 5520 |
config.hidden_size,
|
| 5521 |
eps=config.rms_norm_eps,
|
| 5522 |
use_seednorm=self.use_seednorm,
|
| 5523 |
)
|
| 5524 |
+
self.siamese_attn_x_norm = None
|
| 5525 |
+
self.siamese_attn_y_norm = None
|
| 5526 |
+
self.siamese_mlp_x_norm = None
|
| 5527 |
+
self.siamese_mlp_y_norm = None
|
| 5528 |
+
self.siamese_attn_input_norm = None
|
| 5529 |
+
self.siamese_mlp_input_norm = None
|
| 5530 |
+
self.siamese_attn_x_scale = None
|
| 5531 |
+
|
| 5532 |
+
self.lns_attn = LNS(layer_idx) if self.use_lns else None
|
| 5533 |
+
self.lns_mlp = LNS(layer_idx) if self.use_lns else None
|
| 5534 |
+
self.gpas_attn = GPAS(config.hidden_size) if self.use_gpas else None
|
| 5535 |
+
self.gpas_mlp = GPAS(config.hidden_size) if self.use_gpas else None
|
| 5536 |
self.current_layer_fan = None
|
| 5537 |
|
| 5538 |
# ── StackMemory / STACKTRANS (Zhang et al., NeurIPS 2025) ────────
|
|
|
|
| 5787 |
# LAUREL-LR (paper eq. 3): f(x) + BAx + x
|
| 5788 |
return delta + lr_delta + residual
|
| 5789 |
|
| 5790 |
+
def _siamese_stream_scale(self, ref: torch.Tensor) -> torch.Tensor:
    """Return the scalar multiplier applied to updates of the Siamese x-stream.

    When ``self.siamese_depth_scaling`` is false the multiplier is 1.0.
    Otherwise it is ``1 / sqrt(2 * (layer_idx + 1))``, so it shrinks as
    ``self.layer_idx`` grows.

    Args:
        ref: Tensor whose dtype and device the result inherits (the value
            is created with ``ref.new_tensor``); its contents are not read.

    Returns:
        A 0-dim tensor holding the multiplier.
    """
    if not self.siamese_depth_scaling:
        factor = 1.0
    else:
        # layer_idx is 0-based; +1 keeps the first layer's factor finite
        # (1 / sqrt(2)).
        factor = 1.0 / math.sqrt(2.0 * float(self.layer_idx + 1))
    return ref.new_tensor(factor)
|
| 5794 |
+
|
| 5795 |
+
def forward_siamesenorm(
    self,
    x_states: torch.Tensor,
    y_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor] = None,
    first_layer_fan: Optional[torch.Tensor] = None,
    z_tilde: Optional[torch.Tensor] = None,
    B_vals: Optional[torch.Tensor] = None,
    layer_analysis: Optional["LayerAnalysis"] = None,
    output_attentions: Optional[bool] = False,
    repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
    position_ids: Optional[torch.LongTensor] = None,
    **kwargs: "Unpack[FlashAttentionKwargs]",
) -> Tuple:
    """Run one decoder layer in the two-stream SiameseNorm topology.

    Two residual streams are carried through the layer. For both the
    attention and the MLP sub-block, each stream is normalised by its own
    norm module, the two normalised streams are summed (the x branch of the
    attention input is first multiplied by ``siamese_attn_x_scale``), the
    sum is optionally re-normalised, and the single shared sub-block is
    applied once. The resulting update is added to ``x`` multiplied by
    ``_siamese_stream_scale`` and to ``y`` unscaled. LNS and GPAS are applied
    only when ``use_lns`` / ``use_gpas`` are set; GPAS touches the x stream
    only.

    Side effect: ``self.current_layer_fan`` is overwritten with the third
    value returned by ``self.self_attn``.

    Args:
        x_states: Current x-stream hidden states.
        y_states: Current y-stream hidden states.
        position_embeddings: Forwarded unchanged to ``self.self_attn``.
        attention_mask: Forwarded unchanged to ``self.self_attn``.
        first_layer_fan: Forwarded unchanged to ``self.self_attn``.
        z_tilde: JTok-M input; flattened over all but its last dimension.
            Used only when ``self.use_jtokm`` is set and ``B_vals`` is given.
        B_vals: JTok-M input; flattened over all but its last two
            dimensions. Used only together with ``z_tilde``.
        layer_analysis: Optional recorder. When given, detached copies of
            intermediate tensors are stored on it.
        output_attentions: When truthy, the attention weights are appended
            to the returned tuple.
        repo_rope_args: Forwarded unchanged to ``self.self_attn``.
        position_ids: Forwarded unchanged to ``self.self_attn``.
        **kwargs: Forwarded unchanged to ``self.self_attn``.

    Returns:
        ``(x_next, y_next)`` followed, in this order and only when present,
        by the attention weights (if ``output_attentions``), the JTok-M aux
        stats, and the VersatileFFN aux value.

    Raises:
        RuntimeError: If the layer was built without ``use_siamesenorm``;
            the Siamese norm modules are ``None`` in that configuration.
    """
    # Fail with a clear message instead of "'NoneType' object is not
    # callable" when the Siamese modules were never constructed.
    if not self.use_siamesenorm:
        raise RuntimeError(
            "forward_siamesenorm() requires a layer built with "
            "config.use_siamesenorm=True"
        )

    # SiameseNorm keeps two coupled streams with shared Attention/MLP
    # parameters. All Siamese normalization modules are RMSNorm by
    # construction; SeeDNorm is rejected at config validation time.
    if layer_analysis is not None:
        layer_analysis.hidden_states_input = x_states.detach()

    # ── Attention shared block ────────────────────────────────────────
    sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
    x_attn_norm = self.siamese_attn_x_norm(x_states)
    y_attn_norm = self.siamese_attn_y_norm(y_states)
    if sn_pre is not None:
        sn_pre.output = x_attn_norm.detach()

    # The scale parameter is kept in float32; cast it to the activation
    # dtype/device before mixing the streams.
    x_scale = self.siamese_attn_x_scale.to(
        dtype=x_attn_norm.dtype, device=x_attn_norm.device
    )
    h_attn = x_scale * x_attn_norm + y_attn_norm
    if self.siamese_attn_input_norm is not None:
        h_attn = self.siamese_attn_input_norm(h_attn)

    h_lns = self.lns_attn(h_attn) if self.use_lns else h_attn
    if layer_analysis is not None and self.use_lns:
        layer_analysis.lns_attn_output = h_lns.detach()

    attn_out, attn_weights, self.current_layer_fan = self.self_attn(
        hidden_states=h_lns,
        key_value_states=None,
        attention_mask=attention_mask,
        position_embeddings=position_embeddings,
        first_layer_fan=first_layer_fan,
        attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
        repo_rope_args=repo_rope_args,
        position_ids=position_ids,
        **kwargs,
    )

    if layer_analysis is not None:
        layer_analysis.attn_contribution = attn_out.detach()

    # The same scalar is reused for the MLP update further down.
    stream_scale = self._siamese_stream_scale(attn_out)
    x_after_attn = x_states + stream_scale * attn_out
    y_after_attn = y_states + attn_out

    gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
    if self.use_gpas:
        x_after_attn = self.gpas_attn(x_after_attn, analysis=gpas_attn_a)

    if layer_analysis is not None:
        layer_analysis.h_tilde = x_after_attn.detach()

    # ── MLP shared block ──────────────────────────────────────────────
    sn_post = layer_analysis.seednorm_post_attn if layer_analysis is not None else None
    x_mlp_norm = self.siamese_mlp_x_norm(x_after_attn)
    y_mlp_norm = self.siamese_mlp_y_norm(y_after_attn)
    if sn_post is not None:
        sn_post.output = x_mlp_norm.detach()

    h_mlp = x_mlp_norm + y_mlp_norm
    if self.siamese_mlp_input_norm is not None:
        h_mlp = self.siamese_mlp_input_norm(h_mlp)

    h_lns2 = self.lns_mlp(h_mlp) if self.use_lns else h_mlp
    if layer_analysis is not None and self.use_lns:
        layer_analysis.lns_mlp_output = h_lns2.detach()

    mlp_a = layer_analysis.mlp if layer_analysis is not None else None
    if self.use_versatile_ffn:
        # The VersatileFFN variant returns an extra aux value next to the
        # MLP output.
        delta_m, versatile_aux = self.mlp(h_lns2, analysis=mlp_a)
    else:
        delta_m = self.mlp(h_lns2, analysis=mlp_a)
        versatile_aux = None

    if layer_analysis is not None:
        layer_analysis.mlp_contribution = delta_m.detach()

    shared_update = delta_m
    aux_stats = None
    if self.use_jtokm and z_tilde is not None and B_vals is not None:
        # JTok-M works on flattened inputs; its delta is reshaped back to
        # the x-stream shape and folded into the shared update.
        orig_shape = x_after_attn.shape
        h_flat = x_after_attn.reshape(-1, self.hidden_size)
        z_flat = z_tilde.reshape(-1, z_tilde.shape[-1])
        B_flat = B_vals.reshape(-1, B_vals.shape[-2], B_vals.shape[-1])
        jtokm_a = layer_analysis.jtokm if layer_analysis is not None else None
        delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
        shared_update = shared_update + delta_r.reshape(orig_shape)

    x_next = x_after_attn + stream_scale * shared_update
    y_next = y_after_attn + shared_update

    gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
    if self.use_gpas:
        x_next = self.gpas_mlp(x_next, analysis=gpas_mlp_a)

    if layer_analysis is not None:
        layer_analysis.hidden_states_output = x_next.detach()

    # Optional extras are appended in a fixed order: attention weights,
    # JTok-M aux stats, VersatileFFN aux value.
    outputs = (x_next, y_next)
    if output_attentions:
        outputs += (attn_weights,)
    if aux_stats is not None:
        outputs += (aux_stats,)
    if versatile_aux is not None:
        outputs += (versatile_aux,)
    return outputs
|
| 5912 |
+
|
| 5913 |
def forward(
|
| 5914 |
self,
|
| 5915 |
hidden_states: torch.Tensor,
|
|
|
|
| 5990 |
k_normed = _apply_norm(self.input_layernorm, dca_k_input, analysis=None)
|
| 5991 |
v_normed = _apply_norm(self.input_layernorm, dca_v_input, analysis=None)
|
| 5992 |
|
| 5993 |
+
h_lns = self.lns_attn(h_normed) if self.use_lns else h_normed
|
| 5994 |
if dca_k_input is not None:
|
| 5995 |
+
if self.use_lns:
|
| 5996 |
+
k_lns = self.lns_attn(k_normed)
|
| 5997 |
+
v_lns = self.lns_attn(v_normed)
|
| 5998 |
+
else:
|
| 5999 |
+
k_lns = k_normed
|
| 6000 |
+
v_lns = v_normed
|
| 6001 |
dca_key_value_states = (k_lns, v_lns)
|
| 6002 |
+
if layer_analysis is not None and self.use_lns:
|
| 6003 |
layer_analysis.lns_attn_output = h_lns.detach()
|
| 6004 |
|
| 6005 |
hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
|
|
|
|
| 6034 |
else:
|
| 6035 |
attn_aug = residual_attn + hidden_states
|
| 6036 |
dca_final_residual = hidden_states if self.use_dca else None
|
| 6037 |
+
h_tilde = (
|
| 6038 |
+
self.gpas_attn(attn_aug, analysis=gpas_attn_a)
|
| 6039 |
+
if self.use_gpas
|
| 6040 |
+
else attn_aug
|
| 6041 |
+
)
|
| 6042 |
|
| 6043 |
if layer_analysis is not None:
|
| 6044 |
layer_analysis.h_tilde = h_tilde.detach()
|
|
|
|
| 6060 |
# ── MLP block ─────────────────────────────────────────────────────
|
| 6061 |
sn_post = layer_analysis.seednorm_post_attn if layer_analysis is not None else None
|
| 6062 |
h_normed2 = _apply_norm(self.post_attention_layernorm, h_mlp, analysis=sn_post)
|
| 6063 |
+
h_lns2 = self.lns_mlp(h_normed2) if self.use_lns else h_normed2
|
| 6064 |
+
if layer_analysis is not None and self.use_lns:
|
| 6065 |
layer_analysis.lns_mlp_output = h_lns2.detach()
|
| 6066 |
|
| 6067 |
mlp_a = layer_analysis.mlp if layer_analysis is not None else None
|
|
|
|
| 6102 |
delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
|
| 6103 |
delta_r = delta_r.reshape(orig_shape)
|
| 6104 |
|
| 6105 |
+
gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
|
| 6106 |
+
hidden_states = (
|
| 6107 |
+
self.gpas_mlp(mlp_aug + delta_r, analysis=gpas_mlp_a)
|
| 6108 |
+
if self.use_gpas
|
| 6109 |
+
else mlp_aug + delta_r
|
| 6110 |
+
)
|
| 6111 |
else:
|
| 6112 |
+
gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
|
| 6113 |
+
hidden_states = (
|
| 6114 |
+
self.gpas_mlp(mlp_aug, analysis=gpas_mlp_a)
|
| 6115 |
+
if self.use_gpas
|
| 6116 |
+
else mlp_aug
|
| 6117 |
+
)
|
| 6118 |
|
| 6119 |
if layer_analysis is not None:
|
| 6120 |
layer_analysis.hidden_states_output = hidden_states.detach()
|
|
|
|
| 6554 |
|
| 6555 |
module.res_weight.data.fill_(1.0)
|
| 6556 |
elif isinstance(module, NeoLLMDecoderLayer):
|
| 6557 |
+
if hasattr(module, "siamese_attn_x_scale") and module.siamese_attn_x_scale is not None:
|
| 6558 |
+
module.siamese_attn_x_scale.data.fill_(module.siamese_attn_x_scale_init)
|
| 6559 |
# AttnRes pseudo-queries: MUST be initialized to zero.
|
| 6560 |
# Zero initialization ensures uniform attention weights at step 0
|
| 6561 |
# (softmax of zeros is uniform), making AttnRes equivalent to a
|
|
|
|
| 6646 |
[NeoLLMDecoderLayer(config, layer_idx)
|
| 6647 |
for layer_idx in range(config.num_hidden_layers)]
|
| 6648 |
)
|
| 6649 |
+
self.use_siamesenorm = bool(getattr(config, "use_siamesenorm", False))
|
| 6650 |
+
if self.use_siamesenorm:
|
| 6651 |
+
self.norm = None
|
| 6652 |
+
self.siamese_x_final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 6653 |
+
self.siamese_y_final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 6654 |
+
self.siamese_final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 6655 |
+
else:
|
| 6656 |
+
self.norm = _make_norm(
|
| 6657 |
+
config.hidden_size,
|
| 6658 |
+
eps=config.rms_norm_eps,
|
| 6659 |
+
use_seednorm=bool(getattr(config, "use_seednorm", True)),
|
| 6660 |
+
)
|
| 6661 |
+
self.siamese_x_final_norm = None
|
| 6662 |
+
self.siamese_y_final_norm = None
|
| 6663 |
+
self.siamese_final_norm = None
|
| 6664 |
self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
|
| 6665 |
self.gradient_checkpointing = False
|
| 6666 |
self.first_layer_fan = None if getattr(config, "use_fan_residual", True) else False
|
|
|
|
| 6752 |
polynorm = PolyNormAnalysis() if not _versatile else None,
|
| 6753 |
versatile = VersatileFFNAnalysis() if _versatile else None,
|
| 6754 |
),
|
| 6755 |
+
gpas_attn = GPASAnalysis() if getattr(cfg, "use_gpas", False) else None,
|
| 6756 |
+
gpas_mlp = GPASAnalysis() if getattr(cfg, "use_gpas", False) else None,
|
| 6757 |
jtokm = JTokMAnalysis() if cfg.use_jtokm else None,
|
| 6758 |
dca = DCAAnalysis() if getattr(cfg, "use_dca", False) else None,
|
| 6759 |
attn_res = AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
|
|
|
|
| 6846 |
)
|
| 6847 |
|
| 6848 |
hidden_states = inputs_embeds
|
| 6849 |
+
use_siamesenorm = bool(getattr(self.config, "use_siamesenorm", False))
|
| 6850 |
+
siamese_y_states = inputs_embeds if use_siamesenorm else None
|
| 6851 |
all_hidden_states = () if output_hidden_states else None
|
| 6852 |
all_attentions = () if output_attentions else None
|
| 6853 |
all_aux_stats = []
|
|
|
|
| 6960 |
|
| 6961 |
dca_layer_sources = self._select_dca_sources(dca_sources) if use_dca else None
|
| 6962 |
|
| 6963 |
+
if use_siamesenorm:
|
| 6964 |
+
layer_outputs = decoder_layer.forward_siamesenorm(
|
| 6965 |
+
hidden_states,
|
| 6966 |
+
siamese_y_states,
|
| 6967 |
+
position_embeddings=position_embeddings,
|
| 6968 |
+
attention_mask=causal_mask,
|
| 6969 |
+
first_layer_fan=self.first_layer_fan,
|
| 6970 |
+
z_tilde=z_tilde,
|
| 6971 |
+
B_vals=B_vals,
|
| 6972 |
+
layer_analysis=layer_analysis,
|
| 6973 |
+
output_attentions=output_attentions,
|
| 6974 |
+
repo_rope_args=repo_rope_args,
|
| 6975 |
+
position_ids=position_ids,
|
| 6976 |
+
**kwargs,
|
| 6977 |
+
)
|
| 6978 |
+
hidden_states = layer_outputs[0]
|
| 6979 |
+
siamese_y_states = layer_outputs[1]
|
| 6980 |
+
extras_start = 2
|
| 6981 |
+
else:
|
| 6982 |
+
layer_outputs = decoder_layer(
|
| 6983 |
+
hidden_states,
|
| 6984 |
+
position_embeddings=position_embeddings,
|
| 6985 |
+
attention_mask=causal_mask,
|
| 6986 |
+
first_layer_fan=self.first_layer_fan,
|
| 6987 |
+
z_tilde=z_tilde,
|
| 6988 |
+
B_vals=B_vals,
|
| 6989 |
+
dca_sources=dca_layer_sources,
|
| 6990 |
+
attn_res_sources=attn_res_sources,
|
| 6991 |
+
attn_res_partial=attn_res_partial if use_attn_res else None,
|
| 6992 |
+
layer_analysis=layer_analysis,
|
| 6993 |
+
output_attentions=output_attentions,
|
| 6994 |
+
repo_rope_args=repo_rope_args,
|
| 6995 |
+
position_ids=position_ids,
|
| 6996 |
+
**kwargs,
|
| 6997 |
+
)
|
| 6998 |
+
hidden_states = layer_outputs[0]
|
| 6999 |
+
extras_start = 1
|
| 7000 |
|
| 7001 |
# Update AttnRes partial sum — the new partial is the layer output
|
| 7002 |
if use_attn_res:
|
|
|
|
| 7007 |
dca_sources = dca_sources + [hidden_states]
|
| 7008 |
|
| 7009 |
if output_attentions:
|
| 7010 |
+
all_attentions = all_attentions + (layer_outputs[extras_start],)
|
| 7011 |
+
extras_start += 1
|
| 7012 |
|
| 7013 |
+
# Collect JTok-M aux stats.
|
| 7014 |
+
if self.config.use_jtokm:
|
| 7015 |
+
for item in layer_outputs[extras_start:]:
|
| 7016 |
+
if isinstance(item, tuple) and len(item) == 3:
|
| 7017 |
+
all_aux_stats.append(item)
|
| 7018 |
+
break
|
| 7019 |
|
| 7020 |
+
# Collect VersatileFFN aux stats. Only non-None during training.
|
|
|
|
| 7021 |
if getattr(self.config, "use_versatile_ffn", False):
|
| 7022 |
+
for item in layer_outputs[extras_start:]:
|
| 7023 |
if isinstance(item, tuple) and len(item) == 3:
|
|
|
|
| 7024 |
all_aux_stats.append(("versatile", item))
|
| 7025 |
break
|
| 7026 |
|
|
|
|
| 7029 |
and hasattr(decoder_layer, "current_layer_fan")):
|
| 7030 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 7031 |
|
| 7032 |
+
if use_siamesenorm:
|
| 7033 |
+
hidden_states = self.siamese_final_norm(
|
| 7034 |
+
self.siamese_x_final_norm(hidden_states)
|
| 7035 |
+
+ self.siamese_y_final_norm(siamese_y_states)
|
| 7036 |
+
)
|
| 7037 |
+
else:
|
| 7038 |
+
hidden_states = self.norm(hidden_states)
|
| 7039 |
|
| 7040 |
if output_hidden_states:
|
| 7041 |
all_hidden_states = all_hidden_states + (hidden_states,)
|