Update modeling_neollm.py
Browse files- modeling_neollm.py +1647 -56
modeling_neollm.py
CHANGED
|
@@ -77,10 +77,34 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_u
|
|
| 77 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 78 |
from transformers.processing_utils import Unpack
|
| 79 |
from transformers.utils import TransformersKwargs, logging
|
| 80 |
-
from
|
| 81 |
|
| 82 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
logger = logging.get_logger(__name__)
|
| 85 |
|
| 86 |
|
|
@@ -339,6 +363,22 @@ class JTokMAnalysis:
|
|
| 339 |
lns_scale: Optional[float] = None # 1/√(2ℓ) scaling factor
|
| 340 |
|
| 341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
@dataclass
|
| 343 |
class AttnResAnalysis:
|
| 344 |
"""
|
|
@@ -350,6 +390,44 @@ class AttnResAnalysis:
|
|
| 350 |
sources_count: Optional[int] = None # number of sources including partial
|
| 351 |
|
| 352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
@dataclass
|
| 354 |
class LayerAnalysis:
|
| 355 |
"""
|
|
@@ -378,8 +456,12 @@ class LayerAnalysis:
|
|
| 378 |
gpas_mlp: Optional[GPASAnalysis] = None # GPAS after MLP residual
|
| 379 |
|
| 380 |
# Optional components (None when inactive)
|
| 381 |
-
jtokm:
|
| 382 |
-
attn_res:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
|
| 384 |
|
| 385 |
@dataclass
|
|
@@ -444,6 +526,7 @@ class AnalysisState:
|
|
| 444 |
layers: Optional[List[LayerAnalysis]] = None
|
| 445 |
jtokm_aux_stats: Optional[list] = None
|
| 446 |
attn_res_sources_final: Optional[list] = None
|
|
|
|
| 447 |
logits: Optional[torch.Tensor] = None
|
| 448 |
|
| 449 |
class ScalarMultiplier(nn.Module):
|
|
@@ -2363,6 +2446,8 @@ class NeoLLMAttention(nn.Module):
|
|
| 2363 |
first_layer_fan: Optional[torch.Tensor] = None,
|
| 2364 |
attn_analysis: Optional[AttentionAnalysis] = None,
|
| 2365 |
repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
|
|
|
|
|
|
|
| 2366 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 2367 |
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 2368 |
input_shape = hidden_states.shape[:-1]
|
|
@@ -2373,6 +2458,14 @@ class NeoLLMAttention(nn.Module):
|
|
| 2373 |
h_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan
|
| 2374 |
current_layer_fan = h_fan.clone()
|
| 2375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2376 |
query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim)
|
| 2377 |
kv_shape = (*input_shape, self.num_mea_component_heads, self.head_dim)
|
| 2378 |
|
|
@@ -2387,8 +2480,8 @@ class NeoLLMAttention(nn.Module):
|
|
| 2387 |
attn_analysis.gate_raw = gate.detach()
|
| 2388 |
|
| 2389 |
q = self.q_norm(q_raw.view(query_shape)).transpose(1, 2)
|
| 2390 |
-
k = self.k_norm(self.k_proj(
|
| 2391 |
-
v = self.v_proj(
|
| 2392 |
|
| 2393 |
if attn_analysis is not None:
|
| 2394 |
attn_analysis.q_post_norm = q.detach()
|
|
@@ -3065,6 +3158,1036 @@ class NeoLLMMLP(nn.Module):
|
|
| 3065 |
return result
|
| 3066 |
|
| 3067 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3068 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 3069 |
"""
|
| 3070 |
Decoder layer with standard residual connections, optional JTok-M injection.
|
|
@@ -3087,7 +4210,23 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3087 |
self.layer_idx = layer_idx
|
| 3088 |
self.use_jtokm = config.use_jtokm
|
| 3089 |
|
| 3090 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3091 |
self.mlp = (
|
| 3092 |
VersatileFFN(config)
|
| 3093 |
if getattr(config, "use_versatile_ffn", False)
|
|
@@ -3120,10 +4259,78 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3120 |
self.attn_res_query_attn = nn.Parameter(torch.zeros(config.hidden_size))
|
| 3121 |
self.attn_res_query_mlp = nn.Parameter(torch.zeros(config.hidden_size))
|
| 3122 |
self.attn_res_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3123 |
else:
|
| 3124 |
self.attn_res_query_attn = None
|
| 3125 |
self.attn_res_query_mlp = None
|
| 3126 |
self.attn_res_norm = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3127 |
|
| 3128 |
def _attn_res(
|
| 3129 |
self,
|
|
@@ -3173,6 +4380,10 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3173 |
B_vals: Optional[torch.Tensor] = None,
|
| 3174 |
attn_res_sources: Optional[list] = None,
|
| 3175 |
attn_res_partial: Optional[torch.Tensor] = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3176 |
layer_analysis: Optional[LayerAnalysis] = None,
|
| 3177 |
output_attentions: Optional[bool] = False,
|
| 3178 |
repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
|
|
@@ -3182,6 +4393,63 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3182 |
if layer_analysis is not None:
|
| 3183 |
layer_analysis.hidden_states_input = hidden_states.detach()
|
| 3184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3185 |
# ── Attention Residuals: compute pre-attention input ──────────────
|
| 3186 |
# When active, the input to the attention sublayer is no longer the
|
| 3187 |
# raw hidden_states (accumulated residual) but a softmax-weighted
|
|
@@ -3195,10 +4463,19 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3195 |
attn_res_sources, attn_res_partial, self.attn_res_query_attn,
|
| 3196 |
ar_analysis, "attn",
|
| 3197 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3198 |
residual_attn = attn_res_partial
|
| 3199 |
else:
|
| 3200 |
-
h_attn =
|
| 3201 |
-
|
|
|
|
| 3202 |
|
| 3203 |
# ββ Attention block βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3204 |
sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
|
|
@@ -3207,21 +4484,49 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3207 |
if layer_analysis is not None:
|
| 3208 |
layer_analysis.lns_attn_output = h_lns.detach()
|
| 3209 |
|
| 3210 |
-
|
| 3211 |
-
|
| 3212 |
-
|
| 3213 |
-
|
| 3214 |
-
|
| 3215 |
-
|
| 3216 |
-
|
| 3217 |
-
|
| 3218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3219 |
|
| 3220 |
if layer_analysis is not None:
|
| 3221 |
layer_analysis.attn_contribution = hidden_states.detach()
|
| 3222 |
|
| 3223 |
gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
|
| 3224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3225 |
|
| 3226 |
if layer_analysis is not None:
|
| 3227 |
layer_analysis.h_tilde = h_tilde.detach()
|
|
@@ -3257,8 +4562,11 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3257 |
if layer_analysis is not None:
|
| 3258 |
layer_analysis.mlp_contribution = delta_m.detach()
|
| 3259 |
|
| 3260 |
-
# ββ
|
| 3261 |
-
|
|
|
|
|
|
|
|
|
|
| 3262 |
if self.use_jtokm and z_tilde is not None and B_vals is not None:
|
| 3263 |
orig_shape = h_tilde.shape
|
| 3264 |
h_flat = h_tilde.reshape(-1, self.hidden_size)
|
|
@@ -3269,11 +4577,21 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3269 |
delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
|
| 3270 |
delta_r = delta_r.reshape(orig_shape)
|
| 3271 |
|
| 3272 |
-
|
| 3273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3274 |
else:
|
| 3275 |
-
|
| 3276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3277 |
|
| 3278 |
if layer_analysis is not None:
|
| 3279 |
layer_analysis.hidden_states_output = hidden_states.detach()
|
|
@@ -3285,6 +4603,9 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3285 |
outputs += (aux_stats,)
|
| 3286 |
if versatile_aux is not None:
|
| 3287 |
outputs += (versatile_aux,)
|
|
|
|
|
|
|
|
|
|
| 3288 |
return outputs
|
| 3289 |
|
| 3290 |
|
|
@@ -3612,6 +4933,9 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
| 3612 |
if hasattr(module, "alpha_ma"):
|
| 3613 |
module.alpha_ma.zero_()
|
| 3614 |
|
|
|
|
|
|
|
|
|
|
| 3615 |
elif isinstance(module, GPAS):
|
| 3616 |
module.alpha.data.fill_(0.0)
|
| 3617 |
|
|
@@ -3668,6 +4992,45 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
| 3668 |
module.attn_res_query_attn.data.zero_()
|
| 3669 |
module.attn_res_query_mlp.data.zero_()
|
| 3670 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3671 |
elif isinstance(module, SpellingBeeEmbedding):
|
| 3672 |
# byte_emb initialised identically to token embeddings: std=1/√d.
|
| 3673 |
# Ensures E[‖e_byte‖²] ≈ 1 at init, matching etok, so the
|
|
@@ -3737,8 +5100,99 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3737 |
self.gradient_checkpointing = False
|
| 3738 |
self.first_layer_fan = None
|
| 3739 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3740 |
self.post_init()
|
| 3741 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3742 |
def get_input_embeddings(self):
|
| 3743 |
if self.config.use_token_generator:
|
| 3744 |
return self.token_generator
|
|
@@ -3764,7 +5218,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3764 |
getattr(cfg, "use_repo", False)
|
| 3765 |
and layer_idx >= getattr(cfg, "repo_start_layer", cfg.num_hidden_layers // 3)
|
| 3766 |
)
|
| 3767 |
-
_versatile
|
|
|
|
|
|
|
| 3768 |
return LayerAnalysis(
|
| 3769 |
seednorm_pre_attn = SeeDNormAnalysis(),
|
| 3770 |
seednorm_post_attn = SeeDNormAnalysis(),
|
|
@@ -3778,10 +5234,14 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3778 |
polynorm = PolyNormAnalysis() if not _versatile else None,
|
| 3779 |
versatile = VersatileFFNAnalysis() if _versatile else None,
|
| 3780 |
),
|
| 3781 |
-
gpas_attn
|
| 3782 |
-
gpas_mlp
|
| 3783 |
-
jtokm
|
| 3784 |
-
attn_res
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3785 |
)
|
| 3786 |
|
| 3787 |
def forward(
|
|
@@ -3883,6 +5343,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3883 |
if getattr(self.config, "use_repo", False) else None
|
| 3884 |
)
|
| 3885 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3886 |
# ββ Attention Residuals state ββββββββββββββββββββββββββββββββββββββ
|
| 3887 |
# Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
|
| 3888 |
# decoder layer — all previous outputs are kept, max N=num_layers+1.
|
|
@@ -3897,13 +5361,57 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3897 |
if use_attn_res:
|
| 3898 |
attn_res_sources = [hidden_states] # b_0 = token embedding
|
| 3899 |
attn_res_partial = hidden_states # initial partial sum
|
| 3900 |
-
|
| 3901 |
-
|
| 3902 |
-
|
| 3903 |
-
|
| 3904 |
-
|
| 3905 |
-
|
| 3906 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3907 |
|
| 3908 |
# Pre-allocate per-layer analysis list when analysis is active
|
| 3909 |
if analysis_state is not None:
|
|
@@ -3913,17 +5421,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3913 |
if output_hidden_states:
|
| 3914 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 3915 |
|
| 3916 |
-
# ββ
|
| 3917 |
-
#
|
| 3918 |
-
#
|
| 3919 |
-
|
| 3920 |
-
|
| 3921 |
-
|
| 3922 |
-
|
| 3923 |
-
# re-seeded from the previous hidden_states below.
|
| 3924 |
-
if use_attn_res and layer_idx > 0 and layer_idx % block_size == 0:
|
| 3925 |
-
attn_res_sources = attn_res_sources + [attn_res_partial]
|
| 3926 |
-
attn_res_partial = hidden_states # start new block from current output
|
| 3927 |
|
| 3928 |
# Build per-layer analysis container (only in eval + analysis mode)
|
| 3929 |
layer_analysis = None
|
|
@@ -3932,15 +5436,26 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3932 |
layer_analysis.layer_idx = layer_idx
|
| 3933 |
analysis_state.layers.append(layer_analysis)
|
| 3934 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3935 |
layer_outputs = decoder_layer(
|
| 3936 |
hidden_states,
|
| 3937 |
position_embeddings=position_embeddings,
|
| 3938 |
-
attention_mask=
|
| 3939 |
first_layer_fan=self.first_layer_fan,
|
| 3940 |
z_tilde=z_tilde,
|
| 3941 |
B_vals=B_vals,
|
| 3942 |
attn_res_sources=attn_res_sources,
|
| 3943 |
attn_res_partial=attn_res_partial if use_attn_res else None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3944 |
layer_analysis=layer_analysis,
|
| 3945 |
output_attentions=output_attentions,
|
| 3946 |
repo_rope_args=repo_rope_args,
|
|
@@ -3948,23 +5463,76 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3948 |
)
|
| 3949 |
hidden_states = layer_outputs[0]
|
| 3950 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3951 |
# Update AttnRes partial sum — the new partial is the layer output
|
| 3952 |
if use_attn_res:
|
| 3953 |
attn_res_partial = hidden_states
|
| 3954 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3955 |
if output_attentions:
|
| 3956 |
all_attentions = all_attentions + (layer_outputs[1],)
|
| 3957 |
|
| 3958 |
-
# Collect JTok-M
|
| 3959 |
-
|
| 3960 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3961 |
|
| 3962 |
-
# Collect VersatileFFN aux stats (second-to-last if jtokm also present,
|
| 3963 |
-
# or last if jtokm is absent). Only non-None during training.
|
| 3964 |
if getattr(self.config, "use_versatile_ffn", False):
|
| 3965 |
-
for item in
|
| 3966 |
if isinstance(item, tuple) and len(item) == 3:
|
| 3967 |
-
# (p_sum, f_sum, N_tokens) signature
|
| 3968 |
all_aux_stats.append(("versatile", item))
|
| 3969 |
break
|
| 3970 |
|
|
@@ -3972,6 +5540,16 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3972 |
and hasattr(decoder_layer, "current_layer_fan")):
|
| 3973 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 3974 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3975 |
hidden_states = self.norm(hidden_states)
|
| 3976 |
|
| 3977 |
if output_hidden_states:
|
|
@@ -3984,6 +5562,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 3984 |
analysis_state.attn_res_sources_final = (
|
| 3985 |
attn_res_sources if use_attn_res else None
|
| 3986 |
)
|
|
|
|
|
|
|
|
|
|
| 3987 |
|
| 3988 |
if not return_dict:
|
| 3989 |
return tuple(
|
|
@@ -4124,6 +5705,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 4124 |
layers = None, # filled by NeoLLMModel.forward
|
| 4125 |
jtokm_aux_stats = [] if cfg.use_jtokm else None,
|
| 4126 |
attn_res_sources_final = [] if getattr(cfg, "use_attn_res", False) else None,
|
|
|
|
| 4127 |
)
|
| 4128 |
|
| 4129 |
# ββ Standard model API ββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -4261,6 +5843,12 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 4261 |
# ==================== AUTOMODEL REGISTRATION ====================
|
| 4262 |
|
| 4263 |
__all__ = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4264 |
"NeoLLMForCausalLM",
|
| 4265 |
"NeoLLMModel",
|
| 4266 |
"NeoLLMPreTrainedModel",
|
|
@@ -4278,7 +5866,7 @@ __all__ = [
|
|
| 4278 |
"REPOModule",
|
| 4279 |
"VersatileFFN",
|
| 4280 |
"compute_versatile_aux_loss",
|
| 4281 |
-
# Analysis dataclasses
|
| 4282 |
"AnalysisState",
|
| 4283 |
"LayerAnalysis",
|
| 4284 |
"AttentionAnalysis",
|
|
@@ -4292,6 +5880,9 @@ __all__ = [
|
|
| 4292 |
"VersatileFFNAnalysis",
|
| 4293 |
"JTokMAnalysis",
|
| 4294 |
"AttnResAnalysis",
|
|
|
|
|
|
|
|
|
|
| 4295 |
"GeneratorAnalysis",
|
| 4296 |
]
|
| 4297 |
|
|
|
|
| 77 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 78 |
from transformers.processing_utils import Unpack
|
| 79 |
from transformers.utils import TransformersKwargs, logging
|
| 80 |
+
from configuration_neollm import NeoLLMConfig
|
| 81 |
|
| 82 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 83 |
|
| 84 |
+
# ── Optional fast-path dependencies (GatedDeltaNet linear attention) ─────────
|
| 85 |
+
try:
|
| 86 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update as _causal_conv1d_update
|
| 87 |
+
_causal_conv1d_available = True
|
| 88 |
+
except ImportError:
|
| 89 |
+
causal_conv1d_fn = None
|
| 90 |
+
_causal_conv1d_update = None
|
| 91 |
+
_causal_conv1d_available = False
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
from fla.modules import FusedRMSNormGated
|
| 95 |
+
from fla.ops.gated_delta_rule import (
|
| 96 |
+
chunk_gated_delta_rule,
|
| 97 |
+
fused_recurrent_gated_delta_rule,
|
| 98 |
+
)
|
| 99 |
+
_fla_available = True
|
| 100 |
+
except ImportError:
|
| 101 |
+
FusedRMSNormGated = None
|
| 102 |
+
chunk_gated_delta_rule = None
|
| 103 |
+
fused_recurrent_gated_delta_rule = None
|
| 104 |
+
_fla_available = False
|
| 105 |
+
|
| 106 |
+
is_linear_attn_fast_path = _causal_conv1d_available and _fla_available
|
| 107 |
+
|
| 108 |
logger = logging.get_logger(__name__)
|
| 109 |
|
| 110 |
|
|
|
|
| 363 |
lns_scale: Optional[float] = None # 1/√(2ℓ) scaling factor
|
| 364 |
|
| 365 |
|
| 366 |
+
@dataclass
|
| 367 |
+
class DCAAnalysis:
|
| 368 |
+
"""
|
| 369 |
+
GRN-v3 depth-wise aggregate weights from a DeepCrossAttention layer.
|
| 370 |
+
Only populated when use_dca=True.
|
| 371 |
+
|
| 372 |
+
grn_depth_weights: softmax-free aggregate scalars used to weight each
|
| 373 |
+
source layer, shape [3, y, B, S] where 3 = Q/K/V streams,
|
| 374 |
+
y = selected stack depth (at most 2*dca_k), B = batch, S = seq.
|
| 375 |
+
These are the per-position, per-layer scalars *before* adding the
|
| 376 |
+
static bias — useful to see which layers the dynamic component
|
| 377 |
+
selectively suppresses (ReLU zeros out negative entries).
|
| 378 |
+
"""
|
| 379 |
+
grn_depth_weights: Optional[torch.Tensor] = None # [3, y, B, S]
|
| 380 |
+
|
| 381 |
+
|
| 382 |
@dataclass
|
| 383 |
class AttnResAnalysis:
|
| 384 |
"""
|
|
|
|
| 390 |
sources_count: Optional[int] = None # number of sources including partial
|
| 391 |
|
| 392 |
|
| 393 |
+
@dataclass
|
| 394 |
+
class StackMemoryAnalysis:
|
| 395 |
+
"""
|
| 396 |
+
Internals of a StackMemory forward pass.
|
| 397 |
+
Only populated when use_stacktrans=True AND model is in eval + analysis mode.
|
| 398 |
+
|
| 399 |
+
Reference: Zhang, K. et al. (NeurIPS 2025). "Recursive Transformer:
|
| 400 |
+
Boosting Reasoning Ability with State Stack."
|
| 401 |
+
|
| 402 |
+
action_probs: softmax distribution [push, pop, no-op] per head and
|
| 403 |
+
token position. Shape [B, S, H, 3]. Visualising this
|
| 404 |
+
across layers reveals the push-heavy early layers and
|
| 405 |
+
pop-heavy later layers described in the paper (§B.2).
|
| 406 |
+
stack_in: stack state entering this layer (the output of the
|
| 407 |
+
previous layer's StackMemory). Shape [B, H, slots, ds].
|
| 408 |
+
None for layer 0 (starts as all-zeros).
|
| 409 |
+
stack_out: updated stack state after processing this sequence.
|
| 410 |
+
Shape [B, H, slots, ds]. This is new_stack[:, -1] —
|
| 411 |
+
the stack at the final sequence position, passed to
|
| 412 |
+
the next layer as stack_in.
|
| 413 |
+
mask_out: validity mask for stack_out. Shape [B, H, slots].
|
| 414 |
+
Values near 1 indicate active slots; near 0 = empty.
|
| 415 |
+
gate_weights: softmax attention weights used for global reading.
|
| 416 |
+
Shape [B, S, H, slots]. High weight on slot i at
|
| 417 |
+
position t means the model retrieved from slot i there.
|
| 418 |
+
memory_output: weighted stack readout before up_proj.
|
| 419 |
+
Shape [B, S, stack_d_model].
|
| 420 |
+
residual_scale: value of the learnable res_weight scalar at this step.
|
| 421 |
+
"""
|
| 422 |
+
action_probs: Optional[torch.Tensor] = None # [B,S,H,3]
|
| 423 |
+
stack_in: Optional[torch.Tensor] = None # [B,H,slots,ds] entering layer
|
| 424 |
+
stack_out: Optional[torch.Tensor] = None # [B,H,slots,ds] leaving layer
|
| 425 |
+
mask_out: Optional[torch.Tensor] = None # [B,H,slots]
|
| 426 |
+
gate_weights: Optional[torch.Tensor] = None # [B,S,H,slots]
|
| 427 |
+
memory_output: Optional[torch.Tensor] = None # [B,S,stack_d_model]
|
| 428 |
+
residual_scale: Optional[float] = None # res_weight scalar
|
| 429 |
+
|
| 430 |
+
|
| 431 |
@dataclass
|
| 432 |
class LayerAnalysis:
|
| 433 |
"""
|
|
|
|
| 456 |
gpas_mlp: Optional[GPASAnalysis] = None # GPAS after MLP residual
|
| 457 |
|
| 458 |
# Optional components (None when inactive)
|
| 459 |
+
jtokm: Optional[JTokMAnalysis] = None # if use_jtokm
|
| 460 |
+
attn_res: Optional[AttnResAnalysis] = None # if use_attn_res
|
| 461 |
+
dca: Optional[DCAAnalysis] = None # if use_dca
|
| 462 |
+
stack: Optional[StackMemoryAnalysis] = None # if use_stacktrans
|
| 463 |
+
laurel_attn: Optional["LAuReLAnalysis"] = None # if use_laurel (attention residual)
|
| 464 |
+
laurel_mlp: Optional["LAuReLAnalysis"] = None # if use_laurel (MLP residual)
|
| 465 |
|
| 466 |
|
| 467 |
@dataclass
|
|
|
|
| 526 |
layers: Optional[List[LayerAnalysis]] = None
|
| 527 |
jtokm_aux_stats: Optional[list] = None
|
| 528 |
attn_res_sources_final: Optional[list] = None
|
| 529 |
+
dca_all_tokens_final: Optional[list] = None
|
| 530 |
logits: Optional[torch.Tensor] = None
|
| 531 |
|
| 532 |
class ScalarMultiplier(nn.Module):
|
|
|
|
| 2446 |
first_layer_fan: Optional[torch.Tensor] = None,
|
| 2447 |
attn_analysis: Optional[AttentionAnalysis] = None,
|
| 2448 |
repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
|
| 2449 |
+
mudd_xk: Optional[torch.Tensor] = None,
|
| 2450 |
+
mudd_xv: Optional[torch.Tensor] = None,
|
| 2451 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 2452 |
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 2453 |
input_shape = hidden_states.shape[:-1]
|
|
|
|
| 2458 |
h_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan
|
| 2459 |
current_layer_fan = h_fan.clone()
|
| 2460 |
|
| 2461 |
+
# ── MUDD: separate K/V FAN paths ─────────────────────────────────
|
| 2462 |
+
# When mudd_xk/mudd_xv are provided (MUDD qkvr mode), they have already
|
| 2463 |
+
# been normalized by the decoder layer's K/V norm chain. Here they go
|
| 2464 |
+
# through their own FAN transform before k_proj/v_proj, keeping the
|
| 2465 |
+
# FANformer periodicity modeling orthogonally intact per stream.
|
| 2466 |
+
h_fan_k = self.fan_layer(mudd_xk) if mudd_xk is not None else h_fan
|
| 2467 |
+
h_fan_v = self.fan_layer(mudd_xv) if mudd_xv is not None else h_fan
|
| 2468 |
+
|
| 2469 |
query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim)
|
| 2470 |
kv_shape = (*input_shape, self.num_mea_component_heads, self.head_dim)
|
| 2471 |
|
|
|
|
| 2480 |
attn_analysis.gate_raw = gate.detach()
|
| 2481 |
|
| 2482 |
q = self.q_norm(q_raw.view(query_shape)).transpose(1, 2)
|
| 2483 |
+
k = self.k_norm(self.k_proj(h_fan_k).view(kv_shape)).transpose(1, 2)
|
| 2484 |
+
v = self.v_proj(h_fan_v).view(kv_shape).transpose(1, 2)
|
| 2485 |
|
| 2486 |
if attn_analysis is not None:
|
| 2487 |
attn_analysis.q_post_norm = q.detach()
|
|
|
|
| 3158 |
return result
|
| 3159 |
|
| 3160 |
|
| 3161 |
+
class NeoLLMMUDDModule(nn.Module):
|
| 3162 |
+
"""
|
| 3163 |
+
Multiway Dynamic Dense (MUDD) Depth-wise Aggregate module.
|
| 3164 |
+
|
| 3165 |
+
Generates per-position, per-stream connection weights over all preceding
|
| 3166 |
+
layer outputs (and the token embedding) and produces up to C=4 aggregated
|
| 3167 |
+
streams (Q, K, V, R) for the next Transformer block.
|
| 3168 |
+
|
| 3169 |
+
Architecture (Xiao et al., 2025, arXiv:2502.12170):
|
| 3170 |
+
dw = GELU(RMSNorm(x) @ W1) @ W2 + a # [B, T, C*(lidx+2)]
|
| 3171 |
+
dw = reshape to [C, B, T, (lidx+2)]
|
| 3172 |
+
stream_c = Σ_j dw[c, :, :, j] * hiddens[j] for c in range(C)
|
| 3173 |
+
|
| 3174 |
+
W1 ~ N(0, 1/D), W2 = 0, a = identity on last index → reduces to standard
|
| 3175 |
+
Transformer at init (dynamic part is zero, static bias selects Xi).
|
| 3176 |
+
|
| 3177 |
+
Args:
|
| 3178 |
+
hidden_size: model dimension D
|
| 3179 |
+
lidx: layer index (0-based); history has lidx+2 entries
|
| 3180 |
+
num_ways: C, number of output streams (4 for "qkvr", 1 for "l")
|
| 3181 |
+
is_last: whether this is the last layer (controls expand_last)
|
| 3182 |
+
expand_last: multiply hid_dim by 4 for the final layer's DA module
|
| 3183 |
+
round64: round hid_dim up to the next multiple of 64 (an exact multiple is still increased by 64)
|
| 3184 |
+
"""
|
| 3185 |
+
|
| 3186 |
+
def __init__(
|
| 3187 |
+
self,
|
| 3188 |
+
hidden_size: int,
|
| 3189 |
+
lidx: int,
|
| 3190 |
+
num_ways: int = 4,
|
| 3191 |
+
is_last: bool = False,
|
| 3192 |
+
expand_last: bool = False,
|
| 3193 |
+
round64: bool = False,
|
| 3194 |
+
) -> None:
|
| 3195 |
+
super().__init__()
|
| 3196 |
+
self.lidx = lidx
|
| 3197 |
+
self.num_ways = num_ways
|
| 3198 |
+
l = lidx + 2 # history length: embedding + lidx layers
|
| 3199 |
+
hid_dim = l * num_ways
|
| 3200 |
+
out_dim = l * num_ways
|
| 3201 |
+
if is_last and expand_last:
|
| 3202 |
+
hid_dim *= 4
|
| 3203 |
+
if round64:
|
| 3204 |
+
hid_dim = (hid_dim // 64 + 1) * 64
|
| 3205 |
+
# RMSNorm without learnable scale (paper uses RMSnormNoscale)
|
| 3206 |
+
self.norm = nn.RMSNorm(hidden_size, elementwise_affine=False,
|
| 3207 |
+
eps=1e-6)
|
| 3208 |
+
self.w1 = nn.Linear(hidden_size, hid_dim, bias=False)
|
| 3209 |
+
self.act = nn.GELU()
|
| 3210 |
+
self.w2 = nn.Linear(hid_dim, out_dim, bias=False)
|
| 3211 |
+
self._reset_mudd_parameters(hidden_size)
|
| 3212 |
+
|
| 3213 |
+
def _reset_mudd_parameters(self, D: int) -> None:
|
| 3214 |
+
# W1 ~ N(0, 1/D); W2 = 0 β dynamic part starts at zero
|
| 3215 |
+
nn.init.normal_(self.w1.weight, mean=0.0, std=1.0 / D)
|
| 3216 |
+
nn.init.zeros_(self.w2.weight)
|
| 3217 |
+
|
| 3218 |
+
def forward(
|
| 3219 |
+
self,
|
| 3220 |
+
x: torch.Tensor, # [B, T, D] β current layer output (Xi)
|
| 3221 |
+
hiddens: list, # list of lidx+2 tensors [B, T, D]
|
| 3222 |
+
static_bias: torch.Tensor, # [C, lidx+2] β learnable static prior
|
| 3223 |
+
) -> tuple:
|
| 3224 |
+
"""
|
| 3225 |
+
Returns:
|
| 3226 |
+
Tuple of num_ways tensors, each [B, T, D] β the aggregated streams.
|
| 3227 |
+
"""
|
| 3228 |
+
B, T, D = x.shape
|
| 3229 |
+
# Dynamic weight generation: [B, T, C*(lidx+2)]
|
| 3230 |
+
dw = self.w2(self.act(self.w1(self.norm(x))))
|
| 3231 |
+
# Add static bias (broadcast over B and T)
|
| 3232 |
+
# static_bias: [C, L] β [1, 1, C*L] via reshape
|
| 3233 |
+
C, L = static_bias.shape
|
| 3234 |
+
dw = dw + static_bias.reshape(1, 1, C * L).to(dw.dtype)
|
| 3235 |
+
# Reshape to [C, B, T, L]
|
| 3236 |
+
dw = dw.view(B, T, C, L).permute(2, 0, 1, 3) # [C, B, T, L]
|
| 3237 |
+
# Stack history: [L, B, T, D]
|
| 3238 |
+
stacked = torch.stack(hiddens, dim=0) # [L, B, T, D]
|
| 3239 |
+
# Aggregate: Ξ£_j dw[c, :, :, j] * hiddens[j]
|
| 3240 |
+
# einsum "cbtl, lbtd -> cbtd"
|
| 3241 |
+
streams = torch.einsum('cbtl,lbtd->cbtd', dw, stacked) # [C, B, T, D]
|
| 3242 |
+
return tuple(streams[c] for c in range(C))
|
| 3243 |
+
|
| 3244 |
+
|
| 3245 |
+
def dca_select_layers(stacked: torch.Tensor, k: int) -> torch.Tensor:
    """
    k-DCA layer selection (Heddes et al., 2025, §3.1).

    Keeps only the first k and last k tensors of the depth stack, capping
    memory at 2k layer representations regardless of depth. When the stack
    has at most 2k entries it is returned unchanged (same tensor object).

    Args:
        stacked: [y, B, S, D] -- stack of all layer outputs so far.
        k:       number of first/last layers to retain; must be >= 0.

    Returns:
        Tensor of shape [min(y, 2k), B, S, D].

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    y = stacked.shape[0]
    if y <= 2 * k:
        return stacked

    # Address the tail with an explicit start index: `stacked[-k:]` would
    # select the *entire* stack for k == 0 because -0 == 0, which breaks the
    # documented min(y, 2k) size of the result.
    return torch.cat([stacked[:k], stacked[y - k:]], dim=0)
|
| 3263 |
+
|
| 3264 |
+
|
| 3265 |
+
class NeoLLMGRN(nn.Module):
|
| 3266 |
+
"""
|
| 3267 |
+
Generalized Residual Network v3 (GRN-v3) from DeepCrossAttention
|
| 3268 |
+
(Heddes et al., 2025, arXiv:2502.06785, Β§3.1).
|
| 3269 |
+
|
| 3270 |
+
Produces `num_outputs` aggregated streams from a depth-wise stack of
|
| 3271 |
+
layer representations. Weights are simultaneously:
|
| 3272 |
+
|
| 3273 |
+
- **Input-dependent** (dynamic): a two-layer mapping
|
| 3274 |
+
``wΜ = ReLU(RMSNorm(G) @ W)`` produces one scalar per
|
| 3275 |
+
(output-stream, depth-position, batch-token). ``W`` is initialized
|
| 3276 |
+
to zero so the dynamic contribution starts neutral.
|
| 3277 |
+
- **Dimension-dependent** (static): a learnable bias ``b`` of shape
|
| 3278 |
+
``[num_outputs, num_stack_layers, hidden_size]`` initialized to ones
|
| 3279 |
+
provides a per-dimension, per-layer prior. At initialization the
|
| 3280 |
+
dynamic part is zero and the static bias sums to an equal-weight
|
| 3281 |
+
average over all stack entries, reducing to a standard residual mean.
|
| 3282 |
+
|
| 3283 |
+
The combined weight for output stream ``o``, stack position ``y``,
|
| 3284 |
+
batch ``b``, token ``n``, feature ``d`` is::
|
| 3285 |
+
|
| 3286 |
+
weight[o, y, b, n, d] = ReLU(dynamic[y, b, n, o]) + bias[o, y, d]
|
| 3287 |
+
|
| 3288 |
+
Output ``o`` is then the weighted sum over depth::
|
| 3289 |
+
|
| 3290 |
+
out[o, b, n, d] = Ξ£_y stack[y, b, n, d] * weight[o, y, b, n, d]
|
| 3291 |
+
|
| 3292 |
+
Reference:
|
| 3293 |
+
Heddes, M. et al. (2025). *DeepCrossAttention: Supercharging
|
| 3294 |
+
Transformer Residual Connections.* arXiv:2502.06785.
|
| 3295 |
+
|
| 3296 |
+
Args:
|
| 3297 |
+
hidden_size: model dimension D.
|
| 3298 |
+
num_stack_layers: number of depth entries this GRN will receive
|
| 3299 |
+
(= min(layer_idx+1, 2*dca_k)).
|
| 3300 |
+
num_outputs: number of output streams (3 for DCA Q/K/V,
|
| 3301 |
+
1 for the final aggregation GRN).
|
| 3302 |
+
eps: epsilon for the internal RMSNorm.
|
| 3303 |
+
"""
|
| 3304 |
+
|
| 3305 |
+
def __init__(
|
| 3306 |
+
self,
|
| 3307 |
+
hidden_size: int,
|
| 3308 |
+
num_stack_layers: int,
|
| 3309 |
+
num_outputs: int = 3,
|
| 3310 |
+
eps: float = 1e-6,
|
| 3311 |
+
) -> None:
|
| 3312 |
+
super().__init__()
|
| 3313 |
+
self.num_outputs = num_outputs
|
| 3314 |
+
self.num_stack_layers = num_stack_layers
|
| 3315 |
+
|
| 3316 |
+
# Dynamic component: RMSNorm(no scale) β Linear β ReLU
|
| 3317 |
+
# Linear maps D β num_outputs; init zeros so dynamic part = 0 at step 0.
|
| 3318 |
+
_linear = nn.Linear(hidden_size, num_outputs, bias=False)
|
| 3319 |
+
nn.init.zeros_(_linear.weight)
|
| 3320 |
+
self.norm_noscale = nn.RMSNorm(
|
| 3321 |
+
hidden_size, eps=eps, elementwise_affine=False
|
| 3322 |
+
)
|
| 3323 |
+
self.to_dynamic = nn.Sequential(_linear, nn.ReLU())
|
| 3324 |
+
|
| 3325 |
+
# Static bias: [num_outputs, num_stack_layers, hidden_size], init ones.
|
| 3326 |
+
# At init: weight = 0 + bias = 1 per entry β equal-weight average β residual.
|
| 3327 |
+
self.bias = nn.Parameter(
|
| 3328 |
+
torch.ones(num_outputs, num_stack_layers, hidden_size)
|
| 3329 |
+
)
|
| 3330 |
+
|
| 3331 |
+
def forward(
|
| 3332 |
+
self,
|
| 3333 |
+
stack: torch.Tensor,
|
| 3334 |
+
analysis: Optional["DCAAnalysis"] = None,
|
| 3335 |
+
) -> tuple:
|
| 3336 |
+
"""
|
| 3337 |
+
Args:
|
| 3338 |
+
stack: [y, B, S, D] β selected depth stack (y β€ 2*dca_k).
|
| 3339 |
+
analysis: optional DCAAnalysis to deposit grn_depth_weights.
|
| 3340 |
+
Returns:
|
| 3341 |
+
Tuple of num_outputs tensors each [B, S, D].
|
| 3342 |
+
When num_outputs=1 returns a single [B, S, D] tensor directly.
|
| 3343 |
+
"""
|
| 3344 |
+
y, B, S, D = stack.shape
|
| 3345 |
+
assert y == self.num_stack_layers, (
|
| 3346 |
+
f"NeoLLMGRN expected stack depth {self.num_stack_layers}, got {y}"
|
| 3347 |
+
)
|
| 3348 |
+
|
| 3349 |
+
# Dynamic aggregate: [y, B, S, D] β norm β [y, B, S, D]
|
| 3350 |
+
# β to_dynamic β [y, B, S, num_outputs]
|
| 3351 |
+
# β permute β [num_outputs, y, B, S]
|
| 3352 |
+
normed = self.norm_noscale(stack) # [y, B, S, D]
|
| 3353 |
+
dynamic = self.to_dynamic(normed) # [y, B, S, num_outputs]
|
| 3354 |
+
dynamic = dynamic.permute(3, 0, 1, 2) # [o, y, B, S]
|
| 3355 |
+
|
| 3356 |
+
if analysis is not None:
|
| 3357 |
+
analysis.grn_depth_weights = dynamic.detach()
|
| 3358 |
+
|
| 3359 |
+
# Combined weight: dynamic scalar + static bias per dimension
|
| 3360 |
+
# dynamic: [o, y, B, S] β [o, y, B, S, 1]
|
| 3361 |
+
# bias: [o, y, D] β [o, y, 1, 1, D]
|
| 3362 |
+
weights = dynamic.unsqueeze(-1) + self.bias.unsqueeze(2).unsqueeze(3)
|
| 3363 |
+
# weights: [o, y, B, S, D]
|
| 3364 |
+
|
| 3365 |
+
# Weighted depth-sum: Ξ£_y stack[y] * weights[o, y]
|
| 3366 |
+
# stack: [y, B, S, D] β [1, y, B, S, D]
|
| 3367 |
+
output = (stack.unsqueeze(0) * weights).sum(dim=1) # [o, B, S, D]
|
| 3368 |
+
|
| 3369 |
+
if self.num_outputs == 1:
|
| 3370 |
+
return output.squeeze(0) # [B, S, D]
|
| 3371 |
+
return tuple(output[i] for i in range(self.num_outputs))
|
| 3372 |
+
|
| 3373 |
+
|
| 3374 |
+
class StackMemory(nn.Module):
    """
    Differentiable multi-head hidden-state stack for NeoLLM.

    Implements the StackTrans module from Zhang et al. (NeurIPS 2025):
    "Recursive Transformer: Boosting Reasoning Ability with State Stack."

    Architecture (one forward call, covering the full sequence in parallel):

      1. down_proj  : [B,S,D] → [B,S,stack_d_model]
      2. action_head: → [B,S,H,3] softmax (push / pop / no-op)
      3. k_values   : reshape to [B,S,H,ds]
      4. _vectorized_update: applies soft push/pop/no-op to each
         (batch, head) stack in parallel across the sequence dim.
         Every token sees the *same* initial stack, so strict temporal
         ordering within a sequence is given up in exchange for full
         data-parallelism (the §3.3 training-parallelism approximation).
      5. gate_proj  : global read – softmax over all stack slots, offset by
         the (soft) validity mask; returns the weighted sum of the stack.
      6. up_proj    : [B,S,stack_d_model] → [B,S,D]
      7. residual   : output = up_proj_out * res_weight + hidden_states

    Vertical passing (layer-to-layer):
      ``forward`` returns new_stack[:, -1] and new_mask[:, -1] – the stack
      state at the last sequence position – intended to be used as the
      initial stack of the next decoder layer.

    Generation cache:
      When ``enable_cache`` is True, ``forward`` and ``step`` record the
      per-token push values and action probabilities in ``k_cache`` /
      ``action_cache``; ``step`` prepends that history before running the
      stack update (see the review note inside ``step``).

    Args:
        config: NeoLLMConfig instance. Reads:
            hidden_size                (D, model dimension)
            stacktrans_num_heads       (H, number of stack heads)
            stacktrans_stack_slots     (number of slots per stack)
            stacktrans_stack_d_model   (H×ds, low-rank dimension)
            stacktrans_forward_bs      (batch rows of the cache buffers,
                                        default 1)
            cache_size                 (token capacity of the cache buffers,
                                        default 2048)

    Raises:
        ValueError: if ``stacktrans_stack_d_model`` is not divisible by
            ``stacktrans_num_heads``.
    """

    def __init__(self, config: NeoLLMConfig):
        super().__init__()
        self.num_stack_heads = config.stacktrans_num_heads
        self.stack_slots = config.stacktrans_stack_slots
        self.stack_d_model = config.stacktrans_stack_d_model
        # Fail early with a clear message: a non-divisible width would only
        # surface later as an opaque `.view` shape error in forward().
        if self.stack_d_model % self.num_stack_heads != 0:
            raise ValueError(
                f"stacktrans_stack_d_model ({self.stack_d_model}) must be divisible "
                f"by stacktrans_num_heads ({self.num_stack_heads})"
            )
        self.head_dim = self.stack_d_model // self.num_stack_heads

        # Dimension reduction / expansion (plain nn.Linear layers).
        self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=True)
        self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=True)

        # Action prediction: push / pop / no-op logits, one triple per head.
        self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)

        # Global read: one scalar score per stack slot per head.
        self.gate_proj = nn.Linear(self.head_dim, 1, bias=True)

        # Learnable scalar applied to the projected memory read before it is
        # added to the incoming hidden states.
        self.res_weight = nn.Parameter(torch.ones(1))

        # ── Autoregressive generation cache ──────────────────────────────
        # k_cache and action_cache hold per-token values from previous
        # calls. They are only written when enable_cache is True.
        _fbs = getattr(config, "stacktrans_forward_bs", 1)
        _cs = getattr(config, "cache_size", 2048)
        self.register_buffer(
            "k_cache",
            torch.zeros(_fbs, _cs, self.num_stack_heads, self.head_dim),
        )
        self.register_buffer(
            "action_cache",
            torch.zeros(_fbs, _cs, self.num_stack_heads, 3),
        )
        self.cache_position = 0
        self.enable_cache = False

    # ── Cache helpers ─────────────────────────────────────────────────────

    def reset_cache(self) -> None:
        """Forget the cached history (buffer contents are left in place)."""
        self.cache_position = 0

    def _update_cache(
        self,
        k_values: torch.Tensor,   # [B,S,H,ds] detached
        actions: torch.Tensor,    # [B,S,H,3]  detached
    ) -> None:
        """
        Append ``k_values`` / ``actions`` to the cache at ``cache_position``.

        Only the first ``B`` batch rows of the buffers are written, matching
        the ``[:batch_size]`` slicing used when ``step`` reads them back, so
        a batch smaller than ``stacktrans_forward_bs`` is accepted.

        If the new tokens do not fit into the remaining capacity, the cache
        position is reset to 0 and the incoming values are NOT stored.
        """
        batch, seq_len = k_values.shape[:2]
        start = self.cache_position
        end = start + seq_len

        if end > self.k_cache.shape[1]:
            # Out of capacity: drop the history instead of raising.
            self.reset_cache()
            return

        self.k_cache[:batch, start:end] = k_values
        self.action_cache[:batch, start:end] = actions
        self.cache_position = end

    # ── Core stack update ─────────────────────────────────────────────────

    def _vectorized_update(
        self,
        stack: torch.Tensor,     # [B, H, slots, ds] (4-D) or [B,S,H,slots,ds] (5-D)
        mask: torch.Tensor,      # [B, H, slots]     (3-D) or [B,S,H,slots]    (4-D)
        actions: torch.Tensor,   # [B, S, H, 3]
        k_values: torch.Tensor,  # [B, S, H, ds]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Vectorized soft push/pop/no-op stack update.

        A 4-D ``stack`` (and its 3-D ``mask``) is broadcast along the
        sequence dimension, so every token position starts from the *same*
        initial stack and the three candidate operations are applied
        independently per position. The result at position t therefore
        depends only on the initial stack, ``actions[:, t]`` and
        ``k_values[:, t]``.

        Returns:
            new_stack  [B, S, H, slots, ds]
            new_mask   [B, S, H, slots]
        """
        seq_len = actions.shape[1]

        # Broadcast a 4-D initial state along the sequence dimension.
        if stack.dim() == 4:
            stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1)
            mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1)

        # Push: new value at slot 0, everything shifts by one, last slot is
        # discarded.
        push_stack = torch.cat([k_values.unsqueeze(3), stack[:, :, :, :-1]], dim=3)
        push_mask = torch.cat(
            [torch.ones_like(mask[:, :, :, :1]), mask[:, :, :, :-1]], dim=3
        )

        # Pop: drop slot 0, everything shifts back, zeros fill the last slot.
        pop_stack = torch.cat(
            [stack[:, :, :, 1:], torch.zeros_like(stack[:, :, :, :1])], dim=3
        )
        pop_mask = torch.cat(
            [mask[:, :, :, 1:], torch.zeros_like(mask[:, :, :, :1])], dim=3
        )

        # Convex combination of the three candidates, weighted by the action
        # probabilities (order: push, pop, no-op).
        aw = actions.unsqueeze(-1).unsqueeze(-1)                          # [B,S,H,3,1,1]
        stacks = torch.stack([push_stack, pop_stack, stack], dim=3)       # [B,S,H,3,slots,ds]
        masks = torch.stack([push_mask, pop_mask, mask], dim=3)           # [B,S,H,3,slots]

        new_stack = (stacks * aw).sum(dim=3)                              # [B,S,H,slots,ds]
        new_mask = (masks * aw.squeeze(-1)).sum(dim=3)                    # [B,S,H,slots]
        return new_stack, new_mask

    # ── Training forward (full sequence) ─────────────────────────────────

    def forward(
        self,
        hidden_states: torch.Tensor,
        stack: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        analysis: Optional[StackMemoryAnalysis] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Full-sequence forward pass.

        Args:
            hidden_states: [B, S, D]
            stack:    [B, H, slots, ds] – incoming stack state, or None
                      (then initialised to zeros).
            mask:     [B, H, slots] – validity mask for ``stack``, or None
                      (then initialised to zeros).
            analysis: optional container; when given, intermediate tensors
                      are stored on it (detached).

        Returns:
            (output, new_stack, new_mask)
            output    [B, S, D]
            new_stack [B, H, slots, ds] – stack at the final sequence position
            new_mask  [B, H, slots]
        """
        batch_size, seq_len, _ = hidden_states.shape
        device = hidden_states.device

        # Record the incoming stack before it is replaced by a default.
        if analysis is not None:
            analysis.stack_in = stack.detach() if stack is not None else None

        # Empty stack / mask when the caller supplies none.
        if stack is None:
            stack = torch.zeros(
                batch_size, self.num_stack_heads, self.stack_slots, self.head_dim,
                device=device, dtype=hidden_states.dtype,
            )
        if mask is None:
            mask = torch.zeros(
                batch_size, self.num_stack_heads, self.stack_slots,
                device=device, dtype=hidden_states.dtype,
            )

        # 1. Project down
        h_proj = self.down_proj(hidden_states)                  # [B,S,stack_d_model]

        # 2. Action probabilities (logits scaled by 1/sqrt(head_dim))
        action_logits = self.action_head(h_proj) / math.sqrt(self.head_dim)
        actions = F.softmax(
            action_logits.view(batch_size, seq_len, self.num_stack_heads, 3), dim=-1
        )                                                        # [B,S,H,3]

        # 3. Values to push: the down-projection split per head
        k_values = h_proj.view(batch_size, seq_len, self.num_stack_heads, self.head_dim)

        # 4. Vectorized stack update
        new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
        # new_stack: [B,S,H,slots,ds], new_mask: [B,S,H,slots]

        # 5. Global read: slot scores are pushed towards -1e9 in proportion
        #    to (1 - mask) before the softmax over slots.
        gate_scores = self.gate_proj(new_stack).squeeze(-1)     # [B,S,H,slots]
        gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
        memory_out = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
        # memory_out: [B,S,H,ds] → [B,S,stack_d_model]
        memory_out = memory_out.view(batch_size, seq_len, self.stack_d_model)

        # 6. Project back up
        memory_out_proj = self.up_proj(memory_out)              # [B,S,D]

        # 7. Residual
        output = memory_out_proj * self.res_weight + hidden_states

        # 8. Record this call in the generation cache when enabled.
        if self.enable_cache:
            self._update_cache(k_values.detach(), actions.detach())

        # Populate analysis fields
        if analysis is not None:
            analysis.action_probs = actions.detach()
            analysis.stack_out = new_stack[:, -1].detach()
            analysis.mask_out = new_mask[:, -1].detach()
            analysis.gate_weights = gate_weights.detach()
            analysis.memory_output = memory_out.detach()
            analysis.residual_scale = self.res_weight.item()

        # Output plus the last-position stack state for the next consumer.
        return output, new_stack[:, -1], new_mask[:, -1]

    # ── Autoregressive single-token forward ──────────────────────────────

    def step(
        self,
        hidden_state: torch.Tensor,  # [B, D]
        stack: torch.Tensor,         # [B, H, slots, ds]
        mask: torch.Tensor,          # [B, H, slots]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Single-token forward for autoregressive generation.

        When enable_cache=False:
            Calls forward() with a length-1 sequence and squeezes the
            sequence dimension of the output.

        When enable_cache=True:
            Computes the current token's push value and action
            probabilities, prepends the cached values of earlier tokens,
            runs the vectorized update over that sequence and keeps the last
            position. The current token is then appended to the cache.

        Returns:
            (output, new_stack, new_mask)
            output    [B, D]
            new_stack [B, H, slots, ds]
            new_mask  [B, H, slots]
        """
        if not self.enable_cache:
            # Simple path: forward with seq_len=1, squeeze the sequence dim.
            out, new_stack, new_mask = self.forward(
                hidden_state.unsqueeze(1), stack, mask
            )
            return out.squeeze(1), new_stack, new_mask

        batch_size = hidden_state.shape[0]

        # Features of the current token
        h_proj = self.down_proj(hidden_state)                    # [B, stack_d_model]
        a_logits = self.action_head(h_proj) / math.sqrt(self.head_dim)
        cur_act = F.softmax(
            a_logits.view(batch_size, 1, self.num_stack_heads, 3), dim=-1
        )                                                         # [B,1,H,3]
        cur_k = h_proj.view(batch_size, 1, self.num_stack_heads, self.head_dim)

        # Prepend the cached history of earlier tokens, if any.
        if self.cache_position > 0:
            k_values = torch.cat(
                [self.k_cache[:batch_size, :self.cache_position], cur_k], dim=1
            )
            actions = torch.cat(
                [self.action_cache[:batch_size, :self.cache_position], cur_act], dim=1
            )
        else:
            k_values = cur_k
            actions = cur_act

        # NOTE(review): _vectorized_update treats every position
        # independently (each starts from the same broadcast `stack`), so
        # the last position selected below depends only on `stack`, `cur_k`
        # and `cur_act`. The prepended history costs O(T) work without
        # influencing the result -- confirm whether a sequential replay was
        # intended here.
        new_stack_seq, new_mask_seq = self._vectorized_update(stack, mask, actions, k_values)
        new_stack = new_stack_seq[:, -1]                          # [B,H,slots,ds]
        new_mask = new_mask_seq[:, -1]                            # [B,H,slots]

        # Global read on the new stack state
        gate_scores = self.gate_proj(new_stack).squeeze(-1)      # [B,H,slots]
        gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
        memory_out = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=2)
        memory_out = memory_out.view(batch_size, self.stack_d_model)

        memory_out_proj = self.up_proj(memory_out)               # [B,D]
        output = memory_out_proj * self.res_weight + hidden_state

        self._update_cache(cur_k, cur_act)
        return output, new_stack, new_mask
|
| 3685 |
+
|
| 3686 |
+
|
| 3687 |
+
@dataclass
class LAuReLAnalysis:
    """
    Record of one LAuReL residual-connection forward pass.

    Filled in by ``LAuReLLayer.forward`` when a container is passed as its
    ``analysis`` argument; every field stays ``None`` until then.

    Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
    *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.

    Math (combined RW+LR, both sub-variants active):

        x_{i+1} = α · f(x_i) + β · (A·(B·x_i) + x_i)

    where [α, β] = softmax([a, b]) with learnable scalars a, b (RW
    component), B ∈ ℝ^{r×D} is the down-projection and A ∈ ℝ^{D×r} the
    up-projection (LR component).

    Fields:
        alpha_rw: softmax weight on f(x_i), as a Python float.
                  None when the RW component is disabled.
        beta_rw:  softmax weight on the (possibly enriched) residual, as a
                  Python float. None when the RW component is disabled.
        lr_term:  A·(B·x_res), the low-rank augmentation, shape [B, S, D].
                  None when the LR component is disabled.
        output:   the combined tensor returned by the layer, shape [B, S, D].
    """
    alpha_rw: Optional[float] = None        # weight on the sub-layer output f(x)
    beta_rw: Optional[float] = None         # weight on the residual branch g(x)
    lr_term: Optional[torch.Tensor] = None  # low-rank term A(Bx), [B,S,D]
    output: Optional[torch.Tensor] = None   # combined result, [B,S,D]
|
| 3719 |
+
|
| 3720 |
+
|
| 3721 |
+
class LAuReLLayer(nn.Module):
    """
    LAuReL: Learned Augmented Residual Layer.

    A lightweight replacement for the canonical residual connection that
    learns how to blend the sub-layer output f(x) with the residual x,
    optionally enriching the residual with a low-rank linear term.

    Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
    *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.

    ── Sub-variants ────────────────────────────────────────────────────
    Selected by config flags (both default to True when absent):

    **RW only** (use_laurel_rw=True, use_laurel_lr=False):

        x_{i+1} = α · f(x_i) + β · x_i
        [α, β] = softmax([a, b]),  a, b learnable scalars   (2 params)

    **LR only** (use_laurel_rw=False, use_laurel_lr=True):

        x_{i+1} = f(x_i) + A·(B·x_i) + x_i
        B ∈ ℝ^{r×D}  down-projection
        A ∈ ℝ^{D×r}  up-projection, zero-initialised
        Params: 2·r·D per layer.

    **RW + LR** (both True):

        x_{i+1} = α · f(x_i) + β · (A·(B·x_i) + x_i)

    **Neither** (both False): plain residual sum f(x_i) + x_i.

    ── Initialisation ──────────────────────────────────────────────────
    RW: raw logits [a, b] = [0, 0]  →  α = β = 0.5 at step 0.
    LR: A (up) is zeroed in the constructor  →  lr_term = 0 at step 0.
    Consequently LR-only starts as the standard residual f(x) + x, and
    RW (+LR) starts as 0.5·f(x) + 0.5·x.

    NOTE(review): the down-projection B keeps nn.Linear's default
    initialisation here. The column-orthogonal init recommended by the
    paper is not applied in this constructor -- confirm whether the
    model-level weight initialisation takes care of it.

    Args:
        config: NeoLLMConfig. Reads hidden_size and, when present,
                use_laurel_rw, use_laurel_lr and laurel_lr_rank
                (default rank 32).
    """

    def __init__(self, config: NeoLLMConfig):
        super().__init__()
        self.use_rw = getattr(config, "use_laurel_rw", True)
        self.use_lr = getattr(config, "use_laurel_lr", True)
        D = config.hidden_size
        r = getattr(config, "laurel_lr_rank", 32)

        if self.use_rw:
            # Raw logits for softmax([α, β]), kept in one 2-vector so the
            # softmax is a single op. Zeros give α = β = 0.5 at step 0.
            self.rw_logits = nn.Parameter(torch.zeros(2))

        if self.use_lr:
            # down: B (D → r);  up: A (r → D)
            self.lr_down = nn.Linear(D, r, bias=False)
            self.lr_up = nn.Linear(r, D, bias=False)
            # Zero the up-projection so the low-rank term vanishes at
            # step 0 (LoRA-style). Without this the default nn.Linear init
            # would inject a random perturbation into every residual path
            # from the very first step.
            nn.init.zeros_(self.lr_up.weight)

    def forward(
        self,
        f_out: torch.Tensor,    # output of f(x): attn or MLP   [B,S,D]
        x_res: torch.Tensor,    # residual (skip connection)    [B,S,D]
        analysis: Optional[LAuReLAnalysis] = None,
    ) -> torch.Tensor:
        """
        Combine a sub-layer output with its residual.

        Args:
            f_out:    output of the sub-layer f(x).
            x_res:    residual tensor entering the junction.
            analysis: optional container; when given, the mixing weights
                      (as Python floats), the detached low-rank term and
                      the detached result are stored on it.

        Returns:
            Combined tensor with the same shape as the inputs.
        """
        # ── LR component: enrich the residual with A·(B·x_res) ──────────
        if self.use_lr:
            lr_term = self.lr_up(self.lr_down(x_res))       # [B,S,D]
            g_res = lr_term + x_res
        else:
            lr_term = None
            g_res = x_res

        # ── RW component: α·f + β·g (or a plain sum without RW) ─────────
        if self.use_rw:
            weights = torch.softmax(self.rw_logits, dim=0)  # [2]
            alpha, beta = weights[0], weights[1]
            out = alpha * f_out + beta * g_res
        else:
            out = f_out + g_res

        if analysis is not None:
            if self.use_rw:
                analysis.alpha_rw = alpha.item()
                analysis.beta_rw = beta.item()
            if self.use_lr:
                analysis.lr_term = lr_term.detach()
            analysis.output = out.detach()

        return out
|
| 3837 |
+
|
| 3838 |
+
|
| 3839 |
+
# ==================== GATED DELTA NET (LINEAR ATTENTION) ====================
|
| 3840 |
+
# Active when use_linear_attention=True. Replaces NeoLLMAttention every
|
| 3841 |
+
# `linear_attention_every_n` layers (pattern 0-indexed: layers 2, 5, 8, ...).
|
| 3842 |
+
#
|
| 3843 |
+
# References:
|
| 3844 |
+
# Yang et al. (2024). "Gated Delta Networks." arXiv:2412.06464.
|
| 3845 |
+
# Li et al. (2026). "REPO." arXiv:2512.14391.
|
| 3846 |
+
|
| 3847 |
+
|
| 3848 |
+
def _apply_mask_to_padding_states(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Multiply ``hidden_states`` by a padding mask so masked positions become zero.

    The mask is only applied when it is given and both of its first two
    dimensions are larger than one; in every other case the input tensor is
    returned untouched (same object).

    Args:
        hidden_states: Tensor to mask; the mask is broadcast over its dim 2.
        attention_mask: Optional mask indexed as ``attention_mask[:, :, None]``.

    Returns:
        The masked tensor, cast back to the dtype of ``hidden_states``.
    """
    if attention_mask is None:
        return hidden_states

    # A mask with a single row or a single column is skipped.
    mask_is_trivial = attention_mask.shape[1] <= 1 or attention_mask.shape[0] <= 1
    if mask_is_trivial:
        return hidden_states

    masked = hidden_states * attention_mask[:, :, None]
    return masked.to(hidden_states.dtype)
|
| 3861 |
+
|
| 3862 |
+
|
| 3863 |
+
def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    """
    Scale ``x`` to (approximately) unit L2 norm along ``dim``.

    ``eps`` is added to the sum of squares *before* the square root, so an
    all-zero slice maps to zeros instead of producing NaNs.
    """
    sum_of_squares = torch.sum(x * x, dim=dim, keepdim=True)
    return x / torch.sqrt(sum_of_squares + eps)
|
| 3865 |
+
|
| 3866 |
+
|
| 3867 |
+
def _torch_causal_conv1d_update(
|
| 3868 |
+
hidden_states, conv_state, weight, bias=None, activation=None
|
| 3869 |
+
):
|
| 3870 |
+
_, hidden_size, seq_len = hidden_states.shape
|
| 3871 |
+
state_len = conv_state.shape[-1]
|
| 3872 |
+
combined = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
|
| 3873 |
+
conv_state.copy_(combined[:, :, -state_len:])
|
| 3874 |
+
out = F.conv1d(combined, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
|
| 3875 |
+
return F.silu(out[:, :, -seq_len:]).to(hidden_states.dtype)
|
| 3876 |
+
|
| 3877 |
+
|
| 3878 |
+
def _torch_chunk_gated_delta_rule(
    query, key, value, g, beta,
    chunk_size=64, initial_state=None, output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """
    Pure-PyTorch chunked gated delta rule (fallback for the fused kernel).

    Dims 1 and 2 of every input are swapped on entry; the axis that ends up at
    dim 2 (i.e. dim 1 of the inputs) is zero-padded to a multiple of
    ``chunk_size`` and processed chunk by chunk while a running state of shape
    ``[dim0, dim2, key_width, value_width]`` is carried between chunks. All
    arithmetic runs in float32; the output is cast back to ``query``'s dtype.

    Args:
        query, key: 4-D tensors; the last axis is the key width. ``query`` is
            scaled by ``key_width ** -0.5``.
        value: 4-D tensor; the last axis is the value width.
        g: 3-D log-decay; it is cumulatively summed inside each chunk and
            exponentiated.
        beta: 3-D write strength multiplied into keys and values.
        chunk_size: Length of each chunk along the scanned axis.
        initial_state: Optional starting state; zeros when ``None``.
        output_final_state: When false, ``None`` is returned in place of the
            final state.
        use_qk_l2norm_in_kernel: L2-normalise ``query`` and ``key`` first.

    Returns:
        ``(output, final_state_or_None)``; ``output`` has the layout of
        ``value`` (padding removed).
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query, key = _l2norm(query), _l2norm(key)

    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]
    # After the transpose the scanned axis sits at dim 2.
    batch_size, num_heads, seq_len, k_head_dim = key.shape
    v_head_dim = value.shape[-1]

    # Zero-pad the scanned axis up to a whole number of chunks.
    pad_len = (chunk_size - seq_len % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_len))
    key = F.pad(key, (0, 0, 0, pad_len))
    value = F.pad(value, (0, 0, 0, pad_len))
    beta = F.pad(beta, (0, pad_len))
    g = F.pad(g, (0, pad_len))
    padded_len = seq_len + pad_len

    query = query * (query.shape[-1] ** -0.5)
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)

    # Split the scanned axis into [num_chunks, chunk_size].
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)

    diag_and_above = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), 0
    )
    g = g.cumsum(-1)
    # Pairwise decay between positions of the same chunk (lower triangle only).
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp()).tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(diag_and_above, 0)
    # Row-by-row forward substitution; each row depends on the rows above it,
    # so the statement order here must not change.
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))

    if initial_state is None:
        state = torch.zeros(
            batch_size, num_heads, k_head_dim, v_head_dim,
            dtype=value.dtype, device=value.device,
        )
    else:
        state = initial_state.to(value)
    out = torch.zeros_like(value)
    above_diag = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), 1
    )

    for i in range(padded_len // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        intra = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(above_diag, 0)
        v_prime = k_cumdecay[:, :, i] @ state
        v_new = v_i - v_prime
        # Contribution of the carried state plus the within-chunk part.
        out[:, :, i] = (q_i * g[:, :, i, :, None].exp()) @ state + intra @ v_new
        state = (
            state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
        )

    if not output_final_state:
        state = None
    # Merge chunks back and drop the padded tail.
    out = out.reshape(out.shape[0], out.shape[1], -1, out.shape[-1])[:, :, :seq_len]
    return out.transpose(1, 2).contiguous().to(initial_dtype), state
|
| 3934 |
+
|
| 3935 |
+
|
| 3936 |
+
def _torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state,
    use_qk_l2norm_in_kernel=False,
):
    """
    Pure-PyTorch step-by-step gated delta rule (fallback for the fused kernel).

    Dims 1 and 2 of every input are swapped on entry and the axis that ends up
    at dim 2 is walked one position at a time. At each step the state is
    decayed by ``exp(g)``, the value currently stored under the key is
    subtracted from the incoming value, the difference scaled by ``beta`` is
    written back, and the updated state is read with the scaled query.
    Arithmetic is float32; the output is cast back to ``query``'s dtype.

    Returns:
        ``(output, final_state)``; the state is ``None`` unless
        ``output_final_state`` is true.
    """
    input_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = _l2norm(query)
        key = _l2norm(key)

    query, key, value, beta, g = (
        t.transpose(1, 2).contiguous().to(torch.float32)
        for t in (query, key, value, beta, g)
    )
    dim0, dim1, steps, k_width = key.shape
    v_width = value.shape[-1]
    query = query * (query.shape[-1] ** -0.5)

    outputs = torch.zeros(
        dim0, dim1, steps, v_width, dtype=value.dtype, device=value.device
    )
    if initial_state is None:
        state = torch.zeros(
            dim0, dim1, k_width, v_width, dtype=value.dtype, device=value.device
        )
    else:
        state = initial_state.to(value)

    for step in range(steps):
        q_t = query[:, :, step]
        k_t = key[:, :, step]
        v_t = value[:, :, step]
        decay = g[:, :, step].exp().unsqueeze(-1).unsqueeze(-1)
        write_gate = beta[:, :, step].unsqueeze(-1)

        state = state * decay
        # What the state currently returns for this key.
        recalled = (state * k_t.unsqueeze(-1)).sum(-2)
        delta = (v_t - recalled) * write_gate
        state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        outputs[:, :, step] = (state * q_t.unsqueeze(-1)).sum(-2)

    if not output_final_state:
        state = None
    return outputs.transpose(1, 2).contiguous().to(input_dtype), state
|
| 3961 |
+
|
| 3962 |
+
|
| 3963 |
+
class _NeoLLMRMSNormGated(nn.Module):
    """
    Gated RMSNorm used when the fused FLA implementation is unavailable.

    Computes ``weight * rmsnorm(x) * silu(gate)``. The normalisation and the
    gate activation run in float32; the result is returned in the dtype of
    ``x``. Extra keyword arguments are accepted and ignored so the constructor
    can be called like the fused class.
    """

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x, gate):
        input_dtype = x.dtype
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        normed = x_fp32 * torch.rsqrt(mean_square + self.eps)
        # Normalise first, then apply the learned scale and the SiLU gate.
        scaled = self.weight * normed.to(input_dtype)
        gated = scaled * F.silu(gate.float())
        return gated.to(input_dtype)
|
| 3974 |
+
|
| 3975 |
+
|
| 3976 |
+
class NeoLLMGatedDeltaNet(nn.Module):
    """
    GatedDeltaNet linear attention with FANformer integration.

    Replaces NeoLLMAttention on every ``linear_attention_every_n``-th layer
    (0-indexed: layers 2, 5, 8 … for every_n=3).

    Forward pipeline: padding mask → FANLayer → fused QKVZ / BA projections →
    depthwise causal conv + SiLU over the concatenated Q/K/V → key/query head
    expansion → optional REPO rotation of Q/K → chunked gated delta rule →
    gated RMSNorm (gate = Z) → output projection → dropout.

    REPO (use_repo=True AND use_repo_in_linear_attn=True):
        Applies continuous per-head positions to Q and K via _apply_repo_rope,
        matching the full-attention REPO path identically.

    Without REPO the gated delta rule operates without explicit positional
    encoding (its recurrent state is implicitly position-aware).

    References:
        Yang et al. (2024). arXiv:2412.06464.
        Li et al. (2026). arXiv:2512.14391.
    """

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        """
        Build projections, conv, gating parameters, norm and optional REPO.

        Raises:
            ValueError: if ``linear_num_value_heads`` is not a positive
                multiple of ``linear_num_key_heads``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx

        # _fix_qkvz views the fused projections as [num_k_heads, ...] groups
        # holding ``num_v_heads // num_k_heads`` value heads each, so the
        # division must be exact. Fail here rather than with an opaque
        # view-size error on the first forward pass.
        if self.num_k_heads <= 0 or self.num_v_heads % self.num_k_heads != 0:
            raise ValueError(
                "NeoLLMGatedDeltaNet: linear_num_value_heads "
                f"({self.num_v_heads}) must be a positive multiple of "
                f"linear_num_key_heads ({self.num_k_heads})."
            )

        # ── FANformer (same ratio as full-attention layers) ───────────────
        fan_ratio = getattr(config, "fan_ratio", 0.125)
        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=fan_ratio,
        )
        _fan_dim = config.hidden_size + int(config.hidden_size * fan_ratio)

        # ── Causal conv1d on concatenated Q/K/V ───────────────────────────
        # Depthwise (groups == channels); left context comes from the
        # symmetric padding, the surplus right-hand outputs are cut in forward.
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            self.conv_dim, self.conv_dim, bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=self.conv_kernel_size - 1,
        )

        # ── QKVz + ba projections (all from FAN-transformed features) ─────
        self.in_proj_qkvz = nn.Linear(
            _fan_dim, self.key_dim * 2 + self.value_dim * 2, bias=False
        )
        self.in_proj_ba = nn.Linear(
            _fan_dim, self.num_v_heads * 2, bias=False
        )

        # ── Delta-rule gating parameters ──────────────────────────────────
        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))

        # ── Output normalisation ──────────────────────────────────────────
        if FusedRMSNormGated is not None:
            self.norm = FusedRMSNormGated(
                self.head_v_dim,
                eps=config.rms_norm_eps,
                activation="silu",
                device=torch.cuda.current_device(),
                dtype=getattr(config, "dtype", None) or torch.get_default_dtype(),
            )
        else:
            self.norm = _NeoLLMRMSNormGated(self.head_v_dim, eps=config.rms_norm_eps)
        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

        # ── Kernel dispatch (fast → fallback) ─────────────────────────────
        self._conv1d_fn = causal_conv1d_fn  # None if not installed
        self._chunk_fn = (
            chunk_gated_delta_rule
            if chunk_gated_delta_rule is not None
            else _torch_chunk_gated_delta_rule
        )
        self._recur_fn = (
            fused_recurrent_gated_delta_rule
            if fused_recurrent_gated_delta_rule is not None
            else _torch_recurrent_gated_delta_rule
        )

        if not is_linear_attn_fast_path:
            logger.warning_once(
                "NeoLLMGatedDeltaNet: causal_conv1d / flash-linear-attention "
                "not installed — using pure-PyTorch fallbacks. "
                "Install both packages for full performance."
            )

        # ── REPO: continuous per-head positions on Q and K ────────────────
        # Controlled by use_repo AND use_repo_in_linear_attn flags.
        # Only active for layers at or above repo_start_layer.
        self.use_repo = (
            getattr(config, "use_repo", False)
            and getattr(config, "use_repo_in_linear_attn", False)
            and layer_idx >= getattr(config, "repo_start_layer",
                                     config.num_hidden_layers // 3)
        )
        if self.use_repo:
            _d_p = getattr(config, "repo_d_p", config.hidden_size // 8)
            self.repo_module = REPOModule(
                hidden_size=config.hidden_size,
                d_p=_d_p,
                num_heads=self.num_v_heads,
            )
        else:
            self.repo_module = None

    def _fix_qkvz(
        self,
        mixed_qkvz: torch.Tensor,
        mixed_ba: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Split the fused projections into ``(q, k, v, z, b, a)``.

        Both inputs are viewed as ``num_k_heads`` groups on a new axis 2. Each
        group holds one query head, one key head and ``ratio`` value/gate
        heads (``ratio = num_v_heads // num_k_heads``).

        Returns:
            q, k: ``[..., num_k_heads, head_k_dim]``
            v, z: ``[dim0, dim1, num_v_heads, head_v_dim]``
            b, a: ``[dim0, dim1, num_v_heads]``
        """
        ratio = self.num_v_heads // self.num_k_heads
        mixed_qkvz = mixed_qkvz.view(
            *mixed_qkvz.shape[:-1],
            self.num_k_heads,
            2 * self.head_k_dim + 2 * ratio * self.head_v_dim,
        )
        mixed_ba = mixed_ba.view(
            *mixed_ba.shape[:-1],
            self.num_k_heads,
            2 * ratio,
        )
        q, k, v, z = torch.split(
            mixed_qkvz,
            [self.head_k_dim, self.head_k_dim,
             ratio * self.head_v_dim, ratio * self.head_v_dim],
            dim=3,
        )
        b, a = torch.split(mixed_ba, ratio, dim=3)
        # Unfold the ``ratio`` value heads packed inside each key-head group.
        v = v.reshape(v.shape[0], v.shape[1], -1, self.head_v_dim)
        z = z.reshape(z.shape[0], z.shape[1], -1, self.head_v_dim)
        b = b.reshape(b.shape[0], b.shape[1], self.num_v_heads)
        a = a.reshape(a.shape[0], a.shape[1], self.num_v_heads)
        return q, k, v, z, b, a

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
    ) -> torch.Tensor:
        """
        Run the linear-attention token mixer.

        Args:
            hidden_states: ``[B, S, hidden]`` input (unpacked as three dims).
            attention_mask: Optional padding mask applied to the input.
            position_embeddings: Accepted for call-site compatibility; not
                read by this module.
            repo_rope_args: ``(inv_freq, attn_scaling)`` forwarded to
                ``_apply_repo_rope``; REPO is skipped when ``None``.

        Returns:
            Tensor produced by ``out_proj`` followed by dropout.
        """
        hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
        B, S, _ = hidden_states.shape

        # ── FANformer ─────────────────────────────────────────────────────
        h_fan = self.fan_layer(hidden_states)

        # ── QKVz and ba projections ───────────────────────────────────────
        q, k, v, z, b, a = self._fix_qkvz(
            self.in_proj_qkvz(h_fan), self.in_proj_ba(h_fan)
        )

        # ── Causal conv1d on flattened Q/K/V ──────────────────────────────
        qkv = torch.cat(
            [q.reshape(B, S, -1), k.reshape(B, S, -1), v.reshape(B, S, -1)], dim=-1
        ).transpose(1, 2)  # [B, conv_dim, S]

        if self._conv1d_fn is not None:
            qkv = self._conv1d_fn(
                x=qkv, weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias, activation="silu", seq_idx=None,
            )
        else:
            # The padded conv emits S + kernel - 1 steps; the first S are causal.
            qkv = F.silu(self.conv1d(qkv)[:, :, :S])
        qkv = qkv.transpose(1, 2)  # [B, S, conv_dim]

        q_f, k_f, v_f = torch.split(
            qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1
        )
        q = q_f.reshape(B, S, -1, self.head_k_dim)
        k = k_f.reshape(B, S, -1, self.head_k_dim)
        v = v_f.reshape(B, S, -1, self.head_v_dim)

        # ── GQA-like head expansion ───────────────────────────────────────
        # Done BEFORE REPO: repo_module is built with num_heads=num_v_heads,
        # so Q/K must already carry num_v_heads heads when the per-head
        # positions are applied. No-op when num_v_heads == num_k_heads.
        # NOTE(review): assumes REPOModule emits one position track per head
        # ([B, num_v_heads, S]) — confirm against REPOModule.
        ratio = self.num_v_heads // self.num_k_heads
        if ratio > 1:
            q = q.repeat_interleave(ratio, dim=2)
            k = k.repeat_interleave(ratio, dim=2)

        # ── REPO: continuous per-head positions ───────────────────────────
        # Transpose to [B, H, S, dk] for _apply_repo_rope, then back.
        if self.use_repo and self.repo_module is not None and repo_rope_args is not None:
            inv_freq, attn_scaling = repo_rope_args
            z_pos = self.repo_module(hidden_states)  # [B, H, S]
            q_t, k_t = q.transpose(1, 2), k.transpose(1, 2)
            q_t, k_t = _apply_repo_rope(q_t, k_t, z_pos, inv_freq, attn_scaling)
            q, k = q_t.transpose(1, 2), k_t.transpose(1, 2)

        beta = b.sigmoid()
        # Log-decay is always <= 0: -exp(A_log) * softplus(a + dt_bias).
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

        # ── Chunk gated delta rule (fused or fallback) ────────────────────
        core_out, _ = self._chunk_fn(
            q, k, v, g=g, beta=beta,
            initial_state=None, output_final_state=False,
            use_qk_l2norm_in_kernel=True,
        )

        # ── Gated RMSNorm + output projection ─────────────────────────────
        z_shape = z.shape
        core_out = core_out.reshape(-1, core_out.shape[-1])
        core_out = self.norm(core_out, z.reshape(-1, z.shape[-1]))
        core_out = core_out.reshape(z_shape).reshape(B, S, -1)
        return self.dropout(self.out_proj(core_out))
|
| 4186 |
+
|
| 4187 |
+
|
| 4188 |
+
# ==================== DECODER LAYER ==========================================
|
| 4189 |
+
|
| 4190 |
+
|
| 4191 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 4192 |
"""
|
| 4193 |
Decoder layer with standard residual connections, optional JTok-M injection.
|
|
|
|
| 4210 |
self.layer_idx = layer_idx
|
| 4211 |
self.use_jtokm = config.use_jtokm
|
| 4212 |
|
| 4213 |
+
# ββ Token-mixer selection βββββββββββββββββββββββββββββββββββββββββ
|
| 4214 |
+
# use_linear_attention=True: replace full attention every
|
| 4215 |
+
# `linear_attention_every_n` layers (0-indexed pattern:
|
| 4216 |
+
# e.g. every_n=3 → layers 2, 5, 8, 11 …).
|
| 4217 |
+
# All other layers keep NeoLLMAttention unchanged.
|
| 4218 |
+
_every_n = getattr(config, "linear_attention_every_n", 3)
|
| 4219 |
+
self.is_linear_attn = (
|
| 4220 |
+
getattr(config, "use_linear_attention", False)
|
| 4221 |
+
and (layer_idx + 1) % _every_n == 0
|
| 4222 |
+
)
|
| 4223 |
+
if self.is_linear_attn:
|
| 4224 |
+
self.linear_attn = NeoLLMGatedDeltaNet(config, layer_idx)
|
| 4225 |
+
self.self_attn = None
|
| 4226 |
+
else:
|
| 4227 |
+
self.self_attn = NeoLLMAttention(config, layer_idx)
|
| 4228 |
+
self.linear_attn = None
|
| 4229 |
+
|
| 4230 |
self.mlp = (
|
| 4231 |
VersatileFFN(config)
|
| 4232 |
if getattr(config, "use_versatile_ffn", False)
|
|
|
|
| 4259 |
self.attn_res_query_attn = nn.Parameter(torch.zeros(config.hidden_size))
|
| 4260 |
self.attn_res_query_mlp = nn.Parameter(torch.zeros(config.hidden_size))
|
| 4261 |
self.attn_res_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 4262 |
+
_num_blocks = getattr(config, 'attn_res_num_blocks', 0)
|
| 4263 |
+
self.attn_res_block_size = (
|
| 4264 |
+
max(config.num_hidden_layers // _num_blocks, 1) if _num_blocks > 0 else 1
|
| 4265 |
+
)
|
| 4266 |
else:
|
| 4267 |
self.attn_res_query_attn = None
|
| 4268 |
self.attn_res_query_mlp = None
|
| 4269 |
self.attn_res_norm = None
|
| 4270 |
+
self.attn_res_block_size = None
|
| 4271 |
+
|
| 4272 |
+
# ββ MUDD: separate K/V LayerNorms for qkvr+sepln mode ββββββββββββββ
|
| 4273 |
+
# Only instantiated when both mudd_dense_type='qkvr' AND mudd_sepln=True.
|
| 4274 |
+
# The existing input_layernorm handles the Q stream (unchanged).
|
| 4275 |
+
# Separate norms for K and V allow each stream to rescale independently.
|
| 4276 |
+
_use_mudd = getattr(config, 'use_mudd', False)
|
| 4277 |
+
_mudd_qkvr = getattr(config, 'mudd_dense_type', 'qkvr') == 'qkvr'
|
| 4278 |
+
_mudd_sepln = getattr(config, 'mudd_sepln', False)
|
| 4279 |
+
if _use_mudd and _mudd_qkvr and _mudd_sepln:
|
| 4280 |
+
self.mudd_k_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 4281 |
+
self.mudd_v_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 4282 |
+
else:
|
| 4283 |
+
self.mudd_k_norm = None
|
| 4284 |
+
self.mudd_v_norm = None
|
| 4285 |
+
|
| 4286 |
+
# ββ DCA (Heddes et al., 2025, arXiv:2502.06785) βββββββββββββββββββ
|
| 4287 |
+
# GRN-v3 module that aggregates the k-selected depth stack into 3
|
| 4288 |
+
# independent streams (Q, K, V). Each stream has its own dimension-
|
| 4289 |
+
# and input-dependent weights, enabling richer cross-layer interactions.
|
| 4290 |
+
# K and V get their own SeeDNorm + LNS norm chain (same scheme as
|
| 4291 |
+
# MUDD sepln) since they now arrive from a different aggregation path.
|
| 4292 |
+
# The residual connection uses the Q stream output (xq) as its base,
|
| 4293 |
+
# matching the DCA paper's decoder block design (residual = q_input).
|
| 4294 |
+
self.use_dca = getattr(config, 'use_dca', False)
|
| 4295 |
+
if self.use_dca:
|
| 4296 |
+
_dca_k = getattr(config, 'dca_k', 2)
|
| 4297 |
+
_num_stack = min(layer_idx + 1, 2 * _dca_k)
|
| 4298 |
+
self.dca_grn = NeoLLMGRN(
|
| 4299 |
+
hidden_size = config.hidden_size,
|
| 4300 |
+
num_stack_layers = _num_stack,
|
| 4301 |
+
num_outputs = 3,
|
| 4302 |
+
eps = config.rms_norm_eps,
|
| 4303 |
+
)
|
| 4304 |
+
self.dca_k_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 4305 |
+
self.dca_v_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 4306 |
+
else:
|
| 4307 |
+
self.dca_grn = None
|
| 4308 |
+
self.dca_k_norm = None
|
| 4309 |
+
self.dca_v_norm = None
|
| 4310 |
+
|
| 4311 |
+
# ββ StackTrans (Zhang et al., NeurIPS 2025) βββββββββββββββββββββββ
|
| 4312 |
+
# Differentiable multi-head hidden-state stack inserted at the very
|
| 4313 |
+
# beginning of the layer forward, before the attention sublayer.
|
| 4314 |
+
# Mutually exclusive with use_attn_res, use_mudd, use_dca.
|
| 4315 |
+
self.use_stacktrans = getattr(config, 'use_stacktrans', False)
|
| 4316 |
+
if self.use_stacktrans:
|
| 4317 |
+
self.stack_memory = StackMemory(config)
|
| 4318 |
+
else:
|
| 4319 |
+
self.stack_memory = None
|
| 4320 |
+
|
| 4321 |
+
# ββ LAuReL (Menghani, Kumar & Kumar, ICML 2025) βββββββββββββββββββ
|
| 4322 |
+
# Learned augmented residual connection replacing f(x)+x at both
|
| 4323 |
+
# the attention and MLP residual sums. Applied immediately before
|
| 4324 |
+
# GPAS, so GPAS still controls magnitude via stop-gradient scaling.
|
| 4325 |
+
# Two independent instances per layer (attention and MLP).
|
| 4326 |
+
# Compatible with use_stacktrans. Incompatible with MUDD/DCA/AttnRes.
|
| 4327 |
+
self.use_laurel = getattr(config, 'use_laurel', False)
|
| 4328 |
+
if self.use_laurel:
|
| 4329 |
+
self.laurel_attn = LAuReLLayer(config)
|
| 4330 |
+
self.laurel_mlp = LAuReLLayer(config)
|
| 4331 |
+
else:
|
| 4332 |
+
self.laurel_attn = None
|
| 4333 |
+
self.laurel_mlp = None
|
| 4334 |
|
| 4335 |
def _attn_res(
|
| 4336 |
self,
|
|
|
|
| 4380 |
B_vals: Optional[torch.Tensor] = None,
|
| 4381 |
attn_res_sources: Optional[list] = None,
|
| 4382 |
attn_res_partial: Optional[torch.Tensor] = None,
|
| 4383 |
+
mudd_streams: Optional[tuple] = None,
|
| 4384 |
+
dca_stack: Optional[torch.Tensor] = None,
|
| 4385 |
+
stack_state: Optional[torch.Tensor] = None,
|
| 4386 |
+
stack_mask: Optional[torch.Tensor] = None,
|
| 4387 |
layer_analysis: Optional[LayerAnalysis] = None,
|
| 4388 |
output_attentions: Optional[bool] = False,
|
| 4389 |
repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
|
|
|
|
| 4393 |
if layer_analysis is not None:
|
| 4394 |
layer_analysis.hidden_states_input = hidden_states.detach()
|
| 4395 |
|
| 4396 |
+
# ββ StackTrans: hidden-state stack (pre-attention, pre-norm) βββββ
|
| 4397 |
+
# Executed first so attention sees the stack-enriched representation.
|
| 4398 |
+
# stack_state / stack_mask carry the stack from the previous layer;
|
| 4399 |
+
# both are None for layer 0 (StackMemory initialises to zeros then).
|
| 4400 |
+
# Mutually exclusive with MUDD / DCA / AttnRes — those branches are
|
| 4401 |
+
# all skipped when use_stacktrans=True (enforced in NeoLLMConfig).
|
| 4402 |
+
if self.use_stacktrans and self.stack_memory is not None:
|
| 4403 |
+
st_analysis = layer_analysis.stack if layer_analysis is not None else None
|
| 4404 |
+
hidden_states, stack_state, stack_mask = self.stack_memory(
|
| 4405 |
+
hidden_states, stack_state, stack_mask, analysis=st_analysis
|
| 4406 |
+
)
|
| 4407 |
+
|
| 4408 |
+
# ββ MUDD: unpack streams for Q/K/V/R (layer > 0 only) ββββββββββββ
|
| 4409 |
+
# mudd_streams is a 4-tuple (xq, xk, xv, xr) when use_mudd=True and
|
| 4410 |
+
# layer_idx > 0; None for layer 0 (standard residual there).
|
| 4411 |
+
# xr replaces hidden_states as the residual throughout this layer.
|
| 4412 |
+
# xq/xk/xv are the aggregated inputs for Q, K, V projections.
|
| 4413 |
+
# When mudd_dense_type='l' (single stream), all four are equal.
|
| 4414 |
+
# When mudd_sepln=True each stream has its own norm applied below.
|
| 4415 |
+
mudd_xk = None
|
| 4416 |
+
mudd_xv = None
|
| 4417 |
+
if mudd_streams is not None:
|
| 4418 |
+
xq_mudd, xk_mudd, xv_mudd, xr_mudd = mudd_streams
|
| 4419 |
+
# Replace hidden_states with xr for residual connections
|
| 4420 |
+
hidden_states = xr_mudd
|
| 4421 |
+
# Norm K and V streams β use separate SeeDNorm if sepln, else
|
| 4422 |
+
# they will share the main input_layernorm path via h_attn below
|
| 4423 |
+
if self.mudd_k_norm is not None:
|
| 4424 |
+
mudd_xk = self.lns_attn(self.mudd_k_norm(xk_mudd))
|
| 4425 |
+
mudd_xv = self.lns_attn(self.mudd_v_norm(xv_mudd))
|
| 4426 |
+
else:
|
| 4427 |
+
# No sepln: K/V also go through the Q-path norm chain
|
| 4428 |
+
mudd_xk = self.lns_attn(self.input_layernorm(xk_mudd))
|
| 4429 |
+
mudd_xv = self.lns_attn(self.input_layernorm(xv_mudd))
|
| 4430 |
+
# Override hidden_states for the Q path
|
| 4431 |
+
hidden_states_for_attn = xq_mudd
|
| 4432 |
+
else:
|
| 4433 |
+
hidden_states_for_attn = hidden_states
|
| 4434 |
+
|
| 4435 |
+
# ββ DCA: GRN-v3 depth-wise aggregation βββββββββββββββββββββββββββ
|
| 4436 |
+
# When active, runs the per-layer GRN on the k-selected depth stack
|
| 4437 |
+
# to produce three independent aggregated streams (Q, K, V).
|
| 4438 |
+
# xq replaces hidden_states as both the Q projection input AND the
|
| 4439 |
+
# post-attention residual (DCA paper: residual = q_input).
|
| 4440 |
+
# xk and xv go through separate SeeDNorm+LNS chains and are injected
|
| 4441 |
+
# into NeoLLMAttention via the existing mudd_xk/mudd_xv parameters.
|
| 4442 |
+
dca_residual = None
|
| 4443 |
+
dca_a = layer_analysis.dca if layer_analysis is not None else None
|
| 4444 |
+
if self.use_dca and dca_stack is not None:
|
| 4445 |
+
xq, xk, xv = self.dca_grn(dca_stack, analysis=dca_a)
|
| 4446 |
+
dca_residual = xq
|
| 4447 |
+
hidden_states_for_attn = xq
|
| 4448 |
+
# K and V streams: SeeDNorm + LNS before k_proj / v_proj
|
| 4449 |
+
# (reuses the mudd_xk/mudd_xv injection path in NeoLLMAttention)
|
| 4450 |
+
mudd_xk = self.lns_attn(self.dca_k_norm(xk))
|
| 4451 |
+
mudd_xv = self.lns_attn(self.dca_v_norm(xv))
|
| 4452 |
+
|
| 4453 |
# ββ Attention Residuals: compute pre-attention input ββββββββββββββ
|
| 4454 |
# When active, the input to the attention sublayer is no longer the
|
| 4455 |
# raw hidden_states (accumulated residual) but a softmax-weighted
|
|
|
|
| 4463 |
attn_res_sources, attn_res_partial, self.attn_res_query_attn,
|
| 4464 |
ar_analysis, "attn",
|
| 4465 |
)
|
| 4466 |
+
# ββ Block boundary fires HERE β after pre-attn, before attn sublayer ββ
|
| 4467 |
+
# Paper pseudocode (Fig. 2) timing: the completed partial of the previous
|
| 4468 |
+
# block is pushed to sources AFTER the pre-attn AttnRes call, so the first
|
| 4469 |
+
# layer of a new block still sees the old partial as an intra-block source
|
| 4470 |
+
# (no duplicate) and the new intra-block accumulation starts from zeros.
|
| 4471 |
+
if self.layer_idx > 0 and self.layer_idx % self.attn_res_block_size == 0:
|
| 4472 |
+
attn_res_sources.append(attn_res_partial) # in-place; outer loop sees this
|
| 4473 |
+
attn_res_partial = torch.zeros_like(attn_res_partial) # fresh delta start
|
| 4474 |
residual_attn = attn_res_partial
|
| 4475 |
else:
|
| 4476 |
+
h_attn = hidden_states_for_attn # MUDD/DCA: xq stream or unchanged
|
| 4477 |
+
# DCA: residual is xq (the GRN Q-stream output), not raw hidden_states
|
| 4478 |
+
residual_attn = dca_residual if dca_residual is not None else hidden_states
|
| 4479 |
|
| 4480 |
# ββ Attention block βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 4481 |
sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
|
|
|
|
| 4484 |
if layer_analysis is not None:
|
| 4485 |
layer_analysis.lns_attn_output = h_lns.detach()
|
| 4486 |
|
| 4487 |
+
if self.is_linear_attn:
|
| 4488 |
+
# ββ GatedDeltaNet linear attention path βββββββββββββββββββββββ
|
| 4489 |
+
# Does not use: first_layer_fan, mudd_xk/xv, attn_analysis.
|
| 4490 |
+
# attention_mask here is already the linear_attn_mask (no causal
|
| 4491 |
+
# bias, just padding) — NeoLLMModel.forward selects it per layer.
|
| 4492 |
+
hidden_states = self.linear_attn(
|
| 4493 |
+
hidden_states=h_lns,
|
| 4494 |
+
attention_mask=attention_mask,
|
| 4495 |
+
position_embeddings=position_embeddings,
|
| 4496 |
+
repo_rope_args=repo_rope_args,
|
| 4497 |
+
)
|
| 4498 |
+
attn_weights = None
|
| 4499 |
+
self.current_layer_fan = None
|
| 4500 |
+
else:
|
| 4501 |
+
# ββ Standard full attention path ββββββββββββββββββββββββββββββ
|
| 4502 |
+
hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
|
| 4503 |
+
hidden_states=h_lns,
|
| 4504 |
+
attention_mask=attention_mask,
|
| 4505 |
+
position_embeddings=position_embeddings,
|
| 4506 |
+
first_layer_fan=first_layer_fan,
|
| 4507 |
+
attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
|
| 4508 |
+
repo_rope_args=repo_rope_args,
|
| 4509 |
+
mudd_xk=mudd_xk,
|
| 4510 |
+
mudd_xv=mudd_xv,
|
| 4511 |
+
**kwargs,
|
| 4512 |
+
)
|
| 4513 |
|
| 4514 |
if layer_analysis is not None:
|
| 4515 |
layer_analysis.attn_contribution = hidden_states.detach()
|
| 4516 |
|
| 4517 |
gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
|
| 4518 |
+
|
| 4519 |
+
# ββ Attention residual sum ββββββββββββββββββββββββββββββββββββββββ
|
| 4520 |
+
# Standard: GPAS(residual_attn + hidden_states)
|
| 4521 |
+
# LAuReL: GPAS(LAuReL(f_out=hidden_states, x_res=residual_attn))
|
| 4522 |
+
# Both paths feed into GPAS which applies stop-gradient scaling.
|
| 4523 |
+
if self.use_laurel and self.laurel_attn is not None:
|
| 4524 |
+
la_attn_a = layer_analysis.laurel_attn if layer_analysis is not None else None
|
| 4525 |
+
combined_attn = self.laurel_attn(hidden_states, residual_attn, analysis=la_attn_a)
|
| 4526 |
+
else:
|
| 4527 |
+
combined_attn = residual_attn + hidden_states
|
| 4528 |
+
|
| 4529 |
+
h_tilde = self.gpas_attn(combined_attn, analysis=gpas_attn_a)
|
| 4530 |
|
| 4531 |
if layer_analysis is not None:
|
| 4532 |
layer_analysis.h_tilde = h_tilde.detach()
|
|
|
|
| 4562 |
if layer_analysis is not None:
|
| 4563 |
layer_analysis.mlp_contribution = delta_m.detach()
|
| 4564 |
|
| 4565 |
+
# ββ MLP residual sum ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 4566 |
+
# LAuReL treats f(x) = delta_m [+ delta_r when JTok-M active] and
|
| 4567 |
+
# x_res = residual_mlp. JTok-M delta_r is additive alongside delta_m,
|
| 4568 |
+
# so the nonlinear component is delta_m + delta_r in that path.
|
| 4569 |
+
gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
|
| 4570 |
if self.use_jtokm and z_tilde is not None and B_vals is not None:
|
| 4571 |
orig_shape = h_tilde.shape
|
| 4572 |
h_flat = h_tilde.reshape(-1, self.hidden_size)
|
|
|
|
| 4577 |
delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
|
| 4578 |
delta_r = delta_r.reshape(orig_shape)
|
| 4579 |
|
| 4580 |
+
f_mlp = delta_m + delta_r # combined nonlinear term
|
| 4581 |
+
if self.use_laurel and self.laurel_mlp is not None:
|
| 4582 |
+
la_mlp_a = layer_analysis.laurel_mlp if layer_analysis is not None else None
|
| 4583 |
+
combined_mlp = self.laurel_mlp(f_mlp, residual_mlp, analysis=la_mlp_a)
|
| 4584 |
+
else:
|
| 4585 |
+
combined_mlp = residual_mlp + f_mlp
|
| 4586 |
+
hidden_states = self.gpas_mlp(combined_mlp, analysis=gpas_mlp_a)
|
| 4587 |
else:
|
| 4588 |
+
aux_stats = None
|
| 4589 |
+
if self.use_laurel and self.laurel_mlp is not None:
|
| 4590 |
+
la_mlp_a = layer_analysis.laurel_mlp if layer_analysis is not None else None
|
| 4591 |
+
combined_mlp = self.laurel_mlp(delta_m, residual_mlp, analysis=la_mlp_a)
|
| 4592 |
+
else:
|
| 4593 |
+
combined_mlp = residual_mlp + delta_m
|
| 4594 |
+
hidden_states = self.gpas_mlp(combined_mlp, analysis=gpas_mlp_a)
|
| 4595 |
|
| 4596 |
if layer_analysis is not None:
|
| 4597 |
layer_analysis.hidden_states_output = hidden_states.detach()
|
|
|
|
| 4603 |
outputs += (aux_stats,)
|
| 4604 |
if versatile_aux is not None:
|
| 4605 |
outputs += (versatile_aux,)
|
| 4606 |
+
# StackTrans: always append stack state (None, None when inactive)
|
| 4607 |
+
# so NeoLLMModel.forward can extract them by position -2 and -1.
|
| 4608 |
+
outputs += (stack_state, stack_mask)
|
| 4609 |
return outputs
|
| 4610 |
|
| 4611 |
|
|
|
|
| 4933 |
if hasattr(module, "alpha_ma"):
|
| 4934 |
module.alpha_ma.zero_()
|
| 4935 |
|
| 4936 |
+
elif isinstance(module, NeoLLMGatedDeltaNet):
|
| 4937 |
+
module.dt_bias.data.fill_(1.0)
|
| 4938 |
+
module.A_log.data.uniform_(0, 16).log_()
|
| 4939 |
elif isinstance(module, GPAS):
|
| 4940 |
module.alpha.data.fill_(0.0)
|
| 4941 |
|
|
|
|
| 4992 |
module.attn_res_query_attn.data.zero_()
|
| 4993 |
module.attn_res_query_mlp.data.zero_()
|
| 4994 |
|
| 4995 |
+
elif isinstance(module, StackMemory):
|
| 4996 |
+
# Truncated-normal for all Linear weights (matches NeoLLM convention).
|
| 4997 |
+
# Biases zeroed. res_weight starts at 1.0 so the stack readout
|
| 4998 |
+
# contributes equally to the residual from step 0.
|
| 4999 |
+
std = getattr(self.config, "initializer_range", 0.02)
|
| 5000 |
+
cutoff = getattr(self.config, "init_cutoff_factor", 3.0) * std
|
| 5001 |
+
for attr in ("down_proj", "up_proj", "action_head", "gate_proj"):
|
| 5002 |
+
layer = getattr(module, attr, None)
|
| 5003 |
+
if layer is not None and hasattr(layer, "weight"):
|
| 5004 |
+
nn.init.trunc_normal_(
|
| 5005 |
+
layer.weight, mean=0.0, std=std, a=-cutoff, b=cutoff
|
| 5006 |
+
)
|
| 5007 |
+
if layer.bias is not None:
|
| 5008 |
+
nn.init.zeros_(layer.bias)
|
| 5009 |
+
if hasattr(module, "res_weight"):
|
| 5010 |
+
module.res_weight.data.fill_(1.0)
|
| 5011 |
+
|
| 5012 |
+
elif isinstance(module, LAuReLLayer):
|
| 5013 |
+
# RW: raw logits initialised to zero → softmax([0,0]) = [0.5, 0.5].
|
| 5014 |
+
# The model quickly learns the optimal α,β weighting.
|
| 5015 |
+
# LR: lr_down (B, down-projection) → column orthogonal init,
|
| 5016 |
+
# as recommended by the LAuReL paper §3.3 for LLMs.
|
| 5017 |
+
# Column orthogonal preserves the L2 norm of the projected
|
| 5018 |
+
# representation, ensuring stable gradient magnitudes
|
| 5019 |
+
# through the low-rank bottleneck at init.
|
| 5020 |
+
# lr_up (A, up-projection) → zero init → lr_term = A·Bx = 0
|
| 5021 |
+
# at step 0, so the module starts as a standard residual.
|
| 5022 |
+
# Gradient flows back through lr_down immediately via
|
| 5023 |
+
# chain rule; A learns from step 1 onward.
|
| 5024 |
+
if hasattr(module, "rw_logits"):
|
| 5025 |
+
nn.init.zeros_(module.rw_logits)
|
| 5026 |
+
if hasattr(module, "lr_down"):
|
| 5027 |
+
# Column-orthogonal: each column of weight^T is orthonormal.
|
| 5028 |
+
# nn.init.orthogonal_ produces a row-orthogonal matrix (rows
|
| 5029 |
+
# are orthonormal) when rows <= cols; weight^T is then column-orthogonal.
|
| 5030 |
+
nn.init.orthogonal_(module.lr_down.weight)
|
| 5031 |
+
if hasattr(module, "lr_up"):
|
| 5032 |
+
nn.init.zeros_(module.lr_up.weight)
|
| 5033 |
+
|
| 5034 |
elif isinstance(module, SpellingBeeEmbedding):
|
| 5035 |
# byte_emb initialised identically to token embeddings: std=1/√d.
|
| 5036 |
# Ensures E[‖e_byte‖²] ≈ 1 at init, matching etok, so the
|
|
|
|
| 5100 |
self.gradient_checkpointing = False
|
| 5101 |
self.first_layer_fan = None
|
| 5102 |
|
| 5103 |
+
# ── StackTrans state flag ─────────────────────────────────────────
|
| 5104 |
+
self.use_stacktrans = getattr(config, 'use_stacktrans', False)
|
| 5105 |
+
|
| 5106 |
+
# ── Residual-replacement mutex ────────────────────────────────────
|
| 5107 |
+
# AttnRes, MUDD, and DCA all replace the residual aggregation
|
| 5108 |
+
# mechanism — at most one can be active at a time.
|
| 5109 |
+
_use_mudd = getattr(config, 'use_mudd', False)
|
| 5110 |
+
_use_attn_res = getattr(config, 'use_attn_res', False)
|
| 5111 |
+
_use_dca = getattr(config, 'use_dca', False)
|
| 5112 |
+
_active_count = sum([_use_mudd, _use_attn_res, _use_dca])
|
| 5113 |
+
if _active_count > 1:
|
| 5114 |
+
active = [n for n, f in [('use_mudd', _use_mudd),
|
| 5115 |
+
('use_attn_res', _use_attn_res),
|
| 5116 |
+
('use_dca', _use_dca)] if f]
|
| 5117 |
+
raise ValueError(
|
| 5118 |
+
f"use_mudd, use_attn_res, and use_dca are mutually exclusive β "
|
| 5119 |
+
f"got {active} simultaneously active. Set exactly one to True."
|
| 5120 |
+
)
|
| 5121 |
+
if _use_mudd:
|
| 5122 |
+
_mudd_dense_type = getattr(config, 'mudd_dense_type', 'qkvr')
|
| 5123 |
+
_mudd_dynamic = getattr(config, 'mudd_dynamic_dense', True)
|
| 5124 |
+
_mudd_round64 = getattr(config, 'mudd_round64', False)
|
| 5125 |
+
_mudd_expand_last = getattr(config, 'mudd_expand_last', False)
|
| 5126 |
+
_C = 4 if _mudd_dense_type == 'qkvr' else 1
|
| 5127 |
+
|
| 5128 |
+
# Static bias: one [C, lidx+2] parameter per layer.
|
| 5129 |
+
# Initialized with 1 at index [c, lidx+1] (identity on Xi) so that
|
| 5130 |
+
# at init (W2=0) each DA output = Xi — reducing to standard Transformer.
|
| 5131 |
+
_static_list = []
|
| 5132 |
+
for lidx in range(config.num_hidden_layers):
|
| 5133 |
+
# Last layer always uses C=1: its DA output is the final
|
| 5134 |
+
# model representation fed to the norm and lm_head, collapsing
|
| 5135 |
+
# all history into a single stream (paper code, both files).
|
| 5136 |
+
_c = 1 if lidx == config.num_hidden_layers - 1 else _C
|
| 5137 |
+
a = torch.zeros(_c, lidx + 2)
|
| 5138 |
+
a[:, lidx + 1] = 1.0 # last entry = current layer = identity
|
| 5139 |
+
_static_list.append(nn.Parameter(a))
|
| 5140 |
+
self.mudd_static = nn.ParameterList(_static_list)
|
| 5141 |
+
|
| 5142 |
+
# Dynamic DA modules (one per layer)
|
| 5143 |
+
if _mudd_dynamic:
|
| 5144 |
+
self.mudd_dynamic = nn.ModuleList([
|
| 5145 |
+
NeoLLMMUDDModule(
|
| 5146 |
+
hidden_size = config.hidden_size,
|
| 5147 |
+
lidx = lidx,
|
| 5148 |
+
# Last layer: C=1 → collapses to single final repr
|
| 5149 |
+
num_ways = 1 if lidx == config.num_hidden_layers - 1 else _C,
|
| 5150 |
+
is_last = (lidx == config.num_hidden_layers - 1),
|
| 5151 |
+
expand_last = _mudd_expand_last,
|
| 5152 |
+
round64 = _mudd_round64,
|
| 5153 |
+
)
|
| 5154 |
+
for lidx in range(config.num_hidden_layers)
|
| 5155 |
+
])
|
| 5156 |
+
else:
|
| 5157 |
+
self.mudd_dynamic = None
|
| 5158 |
+
else:
|
| 5159 |
+
self.mudd_static = None
|
| 5160 |
+
self.mudd_dynamic = None
|
| 5161 |
+
|
| 5162 |
+
# ── DCA final GRN (Heddes et al., 2025) ───────────────────────────
|
| 5163 |
+
# Applied once after all decoder layers to aggregate the full depth
|
| 5164 |
+
# stack into the final hidden representation before the output norm.
|
| 5165 |
+
# num_stack_layers = min(2*k, L+1) — same cap as per-layer GRNs.
|
| 5166 |
+
# num_outputs=1 collapses to a single [B, S, D] tensor.
|
| 5167 |
+
if _use_dca and getattr(config, 'dca_use_final_grn', True):
|
| 5168 |
+
_dca_k = getattr(config, 'dca_k', 2)
|
| 5169 |
+
_dca_eps = getattr(config, 'dca_grn_eps', config.rms_norm_eps)
|
| 5170 |
+
self.dca_final_grn = NeoLLMGRN(
|
| 5171 |
+
hidden_size = config.hidden_size,
|
| 5172 |
+
num_stack_layers = min(2 * _dca_k, config.num_hidden_layers + 1),
|
| 5173 |
+
num_outputs = 1,
|
| 5174 |
+
eps = _dca_eps,
|
| 5175 |
+
)
|
| 5176 |
+
else:
|
| 5177 |
+
self.dca_final_grn = None
|
| 5178 |
+
|
| 5179 |
self.post_init()
|
| 5180 |
|
| 5181 |
+
def _update_linear_attn_mask(
|
| 5182 |
+
self,
|
| 5183 |
+
attention_mask: Optional[torch.Tensor],
|
| 5184 |
+
) -> Optional[torch.Tensor]:
|
| 5185 |
+
"""
|
| 5186 |
+
Return mask for GatedDeltaNet layers (no causal bias, padding only).
|
| 5187 |
+
Returns None when all tokens are valid (GatedDeltaNet handles via
|
| 5188 |
+
_apply_mask_to_padding_states internally).
|
| 5189 |
+
"""
|
| 5190 |
+
if attention_mask is None:
|
| 5191 |
+
return None
|
| 5192 |
+
if torch.all(attention_mask == 1):
|
| 5193 |
+
return None
|
| 5194 |
+
return attention_mask
|
| 5195 |
+
|
| 5196 |
def get_input_embeddings(self):
|
| 5197 |
if self.config.use_token_generator:
|
| 5198 |
return self.token_generator
|
|
|
|
| 5218 |
getattr(cfg, "use_repo", False)
|
| 5219 |
and layer_idx >= getattr(cfg, "repo_start_layer", cfg.num_hidden_layers // 3)
|
| 5220 |
)
|
| 5221 |
+
_versatile = getattr(cfg, "use_versatile_ffn", False)
|
| 5222 |
+
_use_stacktrans = getattr(cfg, "use_stacktrans", False)
|
| 5223 |
+
_use_laurel = getattr(cfg, "use_laurel", False)
|
| 5224 |
return LayerAnalysis(
|
| 5225 |
seednorm_pre_attn = SeeDNormAnalysis(),
|
| 5226 |
seednorm_post_attn = SeeDNormAnalysis(),
|
|
|
|
| 5234 |
polynorm = PolyNormAnalysis() if not _versatile else None,
|
| 5235 |
versatile = VersatileFFNAnalysis() if _versatile else None,
|
| 5236 |
),
|
| 5237 |
+
gpas_attn = GPASAnalysis(),
|
| 5238 |
+
gpas_mlp = GPASAnalysis(),
|
| 5239 |
+
jtokm = JTokMAnalysis() if cfg.use_jtokm else None,
|
| 5240 |
+
attn_res = AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
|
| 5241 |
+
dca = DCAAnalysis() if getattr(cfg, "use_dca", False) else None,
|
| 5242 |
+
stack = StackMemoryAnalysis() if _use_stacktrans else None,
|
| 5243 |
+
laurel_attn = LAuReLAnalysis() if _use_laurel else None,
|
| 5244 |
+
laurel_mlp = LAuReLAnalysis() if _use_laurel else None,
|
| 5245 |
)
|
| 5246 |
|
| 5247 |
def forward(
|
|
|
|
| 5343 |
if getattr(self.config, "use_repo", False) else None
|
| 5344 |
)
|
| 5345 |
|
| 5346 |
+
# ── Linear attention mask ─────────────────────────────────────────
|
| 5347 |
+
# Computed once; each layer picks the appropriate mask below.
|
| 5348 |
+
linear_attn_mask = self._update_linear_attn_mask(attention_mask)
|
| 5349 |
+
|
| 5350 |
# ── Attention Residuals state ─────────────────────────────────────
|
| 5351 |
# Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
|
| 5352 |
# decoder layer — all previous outputs are kept, max N=num_layers+1.
|
|
|
|
| 5361 |
if use_attn_res:
|
| 5362 |
attn_res_sources = [hidden_states] # b_0 = token embedding
|
| 5363 |
attn_res_partial = hidden_states # initial partial sum
|
| 5364 |
+
# Block boundary handling now lives inside NeoLLMDecoderLayer.forward(),
|
| 5365 |
+
# firing after the pre-attn AttnRes call (paper Fig. 2 timing).
|
| 5366 |
+
|
| 5367 |
+
# ── MUDD state ────────────────────────────────────────────────────
|
| 5368 |
+
# hiddens[0] = token embedding; hiddens[i] = output of layer i-1.
|
| 5369 |
+
# After each layer, its output is appended so layer i receives a
|
| 5370 |
+
# history of length i+1 (embedding + i preceding layer outputs).
|
| 5371 |
+
# mudd_streams is None for layer 0 (standard residual path there)
|
| 5372 |
+
# and a C-tuple of [B,T,D] tensors for layers 1…L.
|
| 5373 |
+
use_mudd = getattr(self.config, 'use_mudd', False)
|
| 5374 |
+
mudd_hiddens = None
|
| 5375 |
+
mudd_streams = None
|
| 5376 |
+
if use_mudd:
|
| 5377 |
+
mudd_hiddens = [hidden_states] # b_0 = token embedding
|
| 5378 |
+
|
| 5379 |
+
# ── DCA state ─────────────────────────────────────────────────────
|
| 5380 |
+
# all_tokens[0] = token embedding; grows by one per decoder layer.
|
| 5381 |
+
# Before each layer, the stack is built and k-DCA selection applied,
|
| 5382 |
+
# capping memory at 2*dca_k stored tensors regardless of depth.
|
| 5383 |
+
# With DCA enabled, dca_stack is never None (even layer 0 gets [embedding]).
|
| 5384 |
+
use_dca = getattr(self.config, 'use_dca', False)
|
| 5385 |
+
_dca_k = getattr(self.config, 'dca_k', 2)
|
| 5386 |
+
dca_all_tokens = None
|
| 5387 |
+
dca_stack = None
|
| 5388 |
+
if use_dca:
|
| 5389 |
+
dca_all_tokens = [hidden_states] # [embedding]
|
| 5390 |
+
|
| 5391 |
+
# ── StackTrans state ──────────────────────────────────────────────
|
| 5392 |
+
# stack_state / stack_mask start as None for the first layer;
|
| 5393 |
+
# StackMemory initialises them to zeros internally on first call.
|
| 5394 |
+
# After each layer, the returned (new_stack, new_mask) are passed
|
| 5395 |
+
# to the next layer as its initial stack — this is "vertical" state
|
| 5396 |
+
# propagation: information flows depth-wise through the stack.
|
| 5397 |
+
#
|
| 5398 |
+
# Temporal accumulation across generation steps is handled by the
|
| 5399 |
+
# StackMemory internal k_cache / action_cache mechanism:
|
| 5400 |
+
# - enable_cache is set True when use_cache=True (inference)
|
| 5401 |
+
# - reset_cache() is called when past_key_values is None
|
| 5402 |
+
# (new sequence, not a continuation step)
|
| 5403 |
+
# This matches the OLMo reference implementation exactly.
|
| 5404 |
+
use_stacktrans = self.use_stacktrans
|
| 5405 |
+
stack_state = None
|
| 5406 |
+
stack_mask = None
|
| 5407 |
+
if use_stacktrans:
|
| 5408 |
+
use_cache_flag = kwargs.get("use_cache", False)
|
| 5409 |
+
past_kv_flag = kwargs.get("past_key_values", None)
|
| 5410 |
+
for layer in self.layers:
|
| 5411 |
+
if layer.stack_memory is not None:
|
| 5412 |
+
layer.stack_memory.enable_cache = bool(use_cache_flag)
|
| 5413 |
+
if past_kv_flag is None:
|
| 5414 |
+
layer.stack_memory.reset_cache()
|
| 5415 |
|
| 5416 |
# Pre-allocate per-layer analysis list when analysis is active
|
| 5417 |
if analysis_state is not None:
|
|
|
|
| 5421 |
if output_hidden_states:
|
| 5422 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 5423 |
|
| 5424 |
+
# ── DCA: build k-selected stack for this layer ───────────────
|
| 5425 |
+
# Stack has layer_idx+1 entries before selection; after k-DCA
|
| 5426 |
+
# selection it has at most 2*dca_k entries (first k + last k).
|
| 5427 |
+
if use_dca:
|
| 5428 |
+
dca_stack = dca_select_layers(
|
| 5429 |
+
torch.stack(dca_all_tokens, dim=0), k=_dca_k
|
| 5430 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5431 |
|
| 5432 |
# Build per-layer analysis container (only in eval + analysis mode)
|
| 5433 |
layer_analysis = None
|
|
|
|
| 5436 |
layer_analysis.layer_idx = layer_idx
|
| 5437 |
analysis_state.layers.append(layer_analysis)
|
| 5438 |
|
| 5439 |
+
# Select the appropriate mask: causal for full attention,
|
| 5440 |
+
# padding-only for GatedDeltaNet linear attention.
|
| 5441 |
+
_layer_mask = (
|
| 5442 |
+
linear_attn_mask
|
| 5443 |
+
if getattr(decoder_layer, "is_linear_attn", False)
|
| 5444 |
+
else causal_mask
|
| 5445 |
+
)
|
| 5446 |
layer_outputs = decoder_layer(
|
| 5447 |
hidden_states,
|
| 5448 |
position_embeddings=position_embeddings,
|
| 5449 |
+
attention_mask=_layer_mask,
|
| 5450 |
first_layer_fan=self.first_layer_fan,
|
| 5451 |
z_tilde=z_tilde,
|
| 5452 |
B_vals=B_vals,
|
| 5453 |
attn_res_sources=attn_res_sources,
|
| 5454 |
attn_res_partial=attn_res_partial if use_attn_res else None,
|
| 5455 |
+
mudd_streams=mudd_streams,
|
| 5456 |
+
dca_stack=dca_stack,
|
| 5457 |
+
stack_state=stack_state,
|
| 5458 |
+
stack_mask=stack_mask,
|
| 5459 |
layer_analysis=layer_analysis,
|
| 5460 |
output_attentions=output_attentions,
|
| 5461 |
repo_rope_args=repo_rope_args,
|
|
|
|
| 5463 |
)
|
| 5464 |
hidden_states = layer_outputs[0]
|
| 5465 |
|
| 5466 |
+
# ── StackTrans: extract updated stack state for next layer ─────
|
| 5467 |
+
# layer_outputs always ends with (stack_state, stack_mask) —
|
| 5468 |
+
# both are None when use_stacktrans=False (zero cost).
|
| 5469 |
+
stack_state = layer_outputs[-2]
|
| 5470 |
+
stack_mask = layer_outputs[-1]
|
| 5471 |
+
|
| 5472 |
# Update AttnRes partial sum — the new partial is the layer output
|
| 5473 |
if use_attn_res:
|
| 5474 |
attn_res_partial = hidden_states
|
| 5475 |
|
| 5476 |
+
# Append layer output to DCA history for next layer's stack
|
| 5477 |
+
if use_dca:
|
| 5478 |
+
dca_all_tokens.append(hidden_states)
|
| 5479 |
+
|
| 5480 |
+
# ── MUDD: append current output and compute DA for next layer ──
|
| 5481 |
+
# mudd_hiddens grows by 1 each iteration; at layer i it has i+2
|
| 5482 |
+
# entries (embedding + i+1 layer outputs). The DA for layer i+1 takes this
|
| 5483 |
+
# full history and produces C streams via dynamic + static weights.
|
| 5484 |
+
# mudd_streams is passed to layer i+1 as its input streams.
|
| 5485 |
+
if use_mudd:
|
| 5486 |
+
mudd_hiddens.append(hidden_states)
|
| 5487 |
+
# Compute DA module output using the just-appended history
|
| 5488 |
+
# (mudd_hiddens now has layer_idx+2 entries)
|
| 5489 |
+
is_last_layer = (layer_idx == self.config.num_hidden_layers - 1)
|
| 5490 |
+
mudd_da_module = self.mudd_dynamic[layer_idx] if self.mudd_dynamic is not None else None
|
| 5491 |
+
if mudd_da_module is not None:
|
| 5492 |
+
raw_streams = mudd_da_module(
|
| 5493 |
+
hidden_states,
|
| 5494 |
+
mudd_hiddens,
|
| 5495 |
+
self.mudd_static[layer_idx],
|
| 5496 |
+
)
|
| 5497 |
+
else:
|
| 5498 |
+
# Static-only: apply weighted sum with learnable bias
|
| 5499 |
+
# stack history [L, B, T, D], weight by mudd_static
|
| 5500 |
+
stacked = torch.stack(mudd_hiddens, dim=0) # [L, B, T, D]
|
| 5501 |
+
a = self.mudd_static[layer_idx].to(hidden_states.dtype) # [C, L]
|
| 5502 |
+
raw_streams = tuple(
|
| 5503 |
+
torch.einsum('cl,lbtd->btd', a[c:c+1], stacked).squeeze(0)
|
| 5504 |
+
for c in range(a.shape[0])
|
| 5505 |
+
)
|
| 5506 |
+
if is_last_layer:
|
| 5507 |
+
# Last layer DA always produces C=1 → single final repr.
|
| 5508 |
+
# This is the MUDD-aggregated combination of all layer
|
| 5509 |
+
# histories weighted by the last layer's output as query.
|
| 5510 |
+
# Replace hidden_states so the final norm and lm_head
|
| 5511 |
+
# receive this aggregated representation, not the raw
|
| 5512 |
+
# last-layer output (paper forward loop: x = x[0] after loop).
|
| 5513 |
+
hidden_states = raw_streams[0]
|
| 5514 |
+
mudd_streams = None # no next layer
|
| 5515 |
+
elif len(raw_streams) == 1:
|
| 5516 |
+
# dense_type='l': broadcast to 4-tuple
|
| 5517 |
+
mudd_streams = (raw_streams[0],) * 4
|
| 5518 |
+
else:
|
| 5519 |
+
# 'qkvr': 4 streams → (xq, xk, xv, xr)
|
| 5520 |
+
mudd_streams = raw_streams
|
| 5521 |
+
|
| 5522 |
if output_attentions:
|
| 5523 |
all_attentions = all_attentions + (layer_outputs[1],)
|
| 5524 |
|
| 5525 |
+
# Collect JTok-M / VersatileFFN aux stats.
|
| 5526 |
+
# layer_outputs always ends with (stack_state, stack_mask) —
|
| 5527 |
+
# slice [1:-2] to skip hidden_states[0] and the two stack slots.
|
| 5528 |
+
inner_outputs = layer_outputs[1:-2]
|
| 5529 |
+
|
| 5530 |
+
if self.config.use_jtokm and len(inner_outputs) > (1 if output_attentions else 0):
|
| 5531 |
+
all_aux_stats.append(inner_outputs[-1])
|
| 5532 |
|
|
|
|
|
|
|
| 5533 |
if getattr(self.config, "use_versatile_ffn", False):
|
| 5534 |
+
for item in inner_outputs:
|
| 5535 |
if isinstance(item, tuple) and len(item) == 3:
|
|
|
|
| 5536 |
all_aux_stats.append(("versatile", item))
|
| 5537 |
break
|
| 5538 |
|
|
|
|
| 5540 |
and hasattr(decoder_layer, "current_layer_fan")):
|
| 5541 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 5542 |
|
| 5543 |
+
# ── DCA final GRN ─────────────────────────────────────────────────
|
| 5544 |
+
# Aggregates the full depth history (k-selected) into the final
|
| 5545 |
+
# hidden representation, matching the DCAGPT forward loop which
|
| 5546 |
+
# applies final_grn(stack(all_tokens)) before norm → lm_head.
|
| 5547 |
+
if use_dca and self.dca_final_grn is not None:
|
| 5548 |
+
final_stack = dca_select_layers(
|
| 5549 |
+
torch.stack(dca_all_tokens, dim=0), k=_dca_k
|
| 5550 |
+
)
|
| 5551 |
+
hidden_states = self.dca_final_grn(final_stack)
|
| 5552 |
+
|
| 5553 |
hidden_states = self.norm(hidden_states)
|
| 5554 |
|
| 5555 |
if output_hidden_states:
|
|
|
|
| 5562 |
analysis_state.attn_res_sources_final = (
|
| 5563 |
attn_res_sources if use_attn_res else None
|
| 5564 |
)
|
| 5565 |
+
analysis_state.dca_all_tokens_final = (
|
| 5566 |
+
dca_all_tokens if use_dca else None
|
| 5567 |
+
)
|
| 5568 |
|
| 5569 |
if not return_dict:
|
| 5570 |
return tuple(
|
|
|
|
| 5705 |
layers = None, # filled by NeoLLMModel.forward
|
| 5706 |
jtokm_aux_stats = [] if cfg.use_jtokm else None,
|
| 5707 |
attn_res_sources_final = [] if getattr(cfg, "use_attn_res", False) else None,
|
| 5708 |
+
dca_all_tokens_final = [] if getattr(cfg, "use_dca", False) else None,
|
| 5709 |
)
|
| 5710 |
|
| 5711 |
# ── Standard model API ────────────────────────────────────────────────
|
|
|
|
| 5843 |
# ==================== AUTOMODEL REGISTRATION ====================
|
| 5844 |
|
| 5845 |
__all__ = [
|
| 5846 |
+
"NeoLLMGatedDeltaNet",
|
| 5847 |
+
"StackMemory",
|
| 5848 |
+
"LAuReLLayer",
|
| 5849 |
+
"NeoLLMMUDDModule",
|
| 5850 |
+
"NeoLLMGRN",
|
| 5851 |
+
"dca_select_layers",
|
| 5852 |
"NeoLLMForCausalLM",
|
| 5853 |
"NeoLLMModel",
|
| 5854 |
"NeoLLMPreTrainedModel",
|
|
|
|
| 5866 |
"REPOModule",
|
| 5867 |
"VersatileFFN",
|
| 5868 |
"compute_versatile_aux_loss",
|
| 5869 |
+
# Analysis dataclasses
|
| 5870 |
"AnalysisState",
|
| 5871 |
"LayerAnalysis",
|
| 5872 |
"AttentionAnalysis",
|
|
|
|
| 5880 |
"VersatileFFNAnalysis",
|
| 5881 |
"JTokMAnalysis",
|
| 5882 |
"AttnResAnalysis",
|
| 5883 |
+
"DCAAnalysis",
|
| 5884 |
+
"StackMemoryAnalysis",
|
| 5885 |
+
"LAuReLAnalysis",
|
| 5886 |
"GeneratorAnalysis",
|
| 5887 |
]
|
| 5888 |
|