Upload modeling_neollm.py
Browse files- modeling_neollm.py +45 -1200
modeling_neollm.py
CHANGED
|
@@ -80,7 +80,7 @@ from transformers.utils import TransformersKwargs, logging
|
|
| 80 |
from configuration_neollm import NeoLLMConfig
|
| 81 |
|
| 82 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 83 |
-
|
| 84 |
logger = logging.get_logger(__name__)
|
| 85 |
|
| 86 |
|
|
@@ -339,22 +339,6 @@ class JTokMAnalysis:
|
|
| 339 |
lns_scale: Optional[float] = None # 1/β(2β) scaling factor
|
| 340 |
|
| 341 |
|
| 342 |
-
@dataclass
class DCAAnalysis:
    """
    GRN-v3 depth-wise aggregate weights from a DeepCrossAttention layer.
    Only populated when use_dca=True.

    grn_depth_weights: softmax-free aggregate scalars used to weight each
        source layer, shape [3, y, B, S] where 3 = Q/K/V streams,
        y = selected stack depth (at most 2*dca_k), B = batch, S = seq.
        These are the per-position, per-layer scalars *before* adding the
        static bias — useful to see which layers the dynamic component
        selectively suppresses (ReLU zeros out negative entries).
    """
    grn_depth_weights: Optional[torch.Tensor] = None  # [3, y, B, S]
|
| 356 |
-
|
| 357 |
-
|
| 358 |
@dataclass
|
| 359 |
class AttnResAnalysis:
|
| 360 |
"""
|
|
@@ -366,78 +350,6 @@ class AttnResAnalysis:
|
|
| 366 |
sources_count: Optional[int] = None # number of sources including partial
|
| 367 |
|
| 368 |
|
| 369 |
-
@dataclass
class StackMemoryAnalysis:
    """
    Internals of a StackMemory forward pass.
    Only populated when use_stacktrans=True AND model is in eval + analysis mode.

    Reference: Zhang, K. et al. (NeurIPS 2025). "Recursive Transformer:
    Boosting Reasoning Ability with State Stack."

    action_probs:   softmax distribution [push, pop, no-op] per head and
                    token position. Shape [B, S, H, 3]. Visualising this
                    across layers reveals the push-heavy early layers and
                    pop-heavy later layers described in the paper (§B.2).
    stack_in:       stack state entering this layer (the output of the
                    previous layer's StackMemory). Shape [B, H, slots, ds].
                    None for layer 0 (starts as all-zeros).
    stack_out:      updated stack state after processing this sequence.
                    Shape [B, H, slots, ds]. This is new_stack[:, -1] —
                    the stack at the final sequence position, passed to
                    the next layer as stack_in.
    mask_out:       validity mask for stack_out. Shape [B, H, slots].
                    Values near 1 indicate active slots; near 0 = empty.
    gate_weights:   softmax attention weights used for global reading.
                    Shape [B, S, H, slots]. High weight on slot i at
                    position t means the model retrieved from slot i there.
    memory_output:  weighted stack readout before up_proj.
                    Shape [B, S, stack_d_model].
    residual_scale: value of the learnable res_weight scalar at this step.
    """
    action_probs: Optional[torch.Tensor] = None    # [B,S,H,3]
    stack_in: Optional[torch.Tensor] = None        # [B,H,slots,ds] entering layer
    stack_out: Optional[torch.Tensor] = None       # [B,H,slots,ds] leaving layer
    mask_out: Optional[torch.Tensor] = None        # [B,H,slots]
    gate_weights: Optional[torch.Tensor] = None    # [B,S,H,slots]
    memory_output: Optional[torch.Tensor] = None   # [B,S,stack_d_model]
    residual_scale: Optional[float] = None         # res_weight scalar
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
@dataclass
class LAuReLAnalysis:
    """
    Internals of one LAuReL residual connection forward pass.
    Only populated when use_laurel=True AND model is in eval + analysis mode.
    Instantiated twice per layer: once for the attention residual, once for MLP.

    Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
    *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.

    Math (combined RW+LR, both sub-variants active):

        x_{i+1} = α · f(x_i) + β · (A·(B·x_i) + x_i)

    where [α, β] = softmax([a, b]), a, b ∈ ℝ learnable (RW component),
    B ∈ ℝ^{r×D} column-orthogonal init, A ∈ ℝ^{D×r} zero init (LR component).
    At step 0: A=0 → lr_term=0, so x_{i+1} = 0.5·f(x) + 0.5·x_i (RW only)
    or x_{i+1} = f(x_i) + x_i (LR only, standard residual).

    Fields:
        alpha_rw: softmax(a) — weight on f(x_i). [scalar float]
                  None when use_laurel_rw=False.
        beta_rw:  softmax(b) — weight on g(x_i). [scalar float]
                  None when use_laurel_rw=False.
        lr_term:  A·(B·x_res) — the low-rank residual augmentation.
                  Shape [B, S, D]. Zero at init. None when use_laurel_lr=False.
        output:   Final combined tensor before GPAS. Shape [B, S, D].
    """
    alpha_rw: Optional[float] = None         # softmax weight on f(x)
    beta_rw: Optional[float] = None          # softmax weight on g(x)
    lr_term: Optional[torch.Tensor] = None   # A(Bx) low-rank augmentation [B,S,D]
    output: Optional[torch.Tensor] = None    # combined pre-GPAS [B,S,D]
|
| 439 |
-
|
| 440 |
-
|
| 441 |
@dataclass
|
| 442 |
class LayerAnalysis:
|
| 443 |
"""
|
|
@@ -466,12 +378,8 @@ class LayerAnalysis:
|
|
| 466 |
gpas_mlp: Optional[GPASAnalysis] = None # GPAS after MLP residual
|
| 467 |
|
| 468 |
# Optional components (None when inactive)
|
| 469 |
-
jtokm:
|
| 470 |
-
attn_res:
|
| 471 |
-
dca: Optional[DCAAnalysis] = None # if use_dca
|
| 472 |
-
stack: Optional[StackMemoryAnalysis] = None # if use_stacktrans
|
| 473 |
-
laurel_attn: Optional[LAuReLAnalysis] = None # if use_laurel (attention residual)
|
| 474 |
-
laurel_mlp: Optional[LAuReLAnalysis] = None # if use_laurel (MLP residual)
|
| 475 |
|
| 476 |
|
| 477 |
@dataclass
|
|
@@ -536,7 +444,6 @@ class AnalysisState:
|
|
| 536 |
layers: Optional[List[LayerAnalysis]] = None
|
| 537 |
jtokm_aux_stats: Optional[list] = None
|
| 538 |
attn_res_sources_final: Optional[list] = None
|
| 539 |
-
dca_all_tokens_final: Optional[list] = None
|
| 540 |
logits: Optional[torch.Tensor] = None
|
| 541 |
|
| 542 |
class ScalarMultiplier(nn.Module):
|
|
@@ -2456,8 +2363,6 @@ class NeoLLMAttention(nn.Module):
|
|
| 2456 |
first_layer_fan: Optional[torch.Tensor] = None,
|
| 2457 |
attn_analysis: Optional[AttentionAnalysis] = None,
|
| 2458 |
repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
|
| 2459 |
-
mudd_xk: Optional[torch.Tensor] = None,
|
| 2460 |
-
mudd_xv: Optional[torch.Tensor] = None,
|
| 2461 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 2462 |
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 2463 |
input_shape = hidden_states.shape[:-1]
|
|
@@ -2468,14 +2373,6 @@ class NeoLLMAttention(nn.Module):
|
|
| 2468 |
h_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan
|
| 2469 |
current_layer_fan = h_fan.clone()
|
| 2470 |
|
| 2471 |
-
# ββ MUDD: separate K/V FAN paths βββββββββββββββββββββββββββββββββ
|
| 2472 |
-
# When mudd_xk/mudd_xv are provided (MUDD qkvr mode), they have already
|
| 2473 |
-
# been normalized by the decoder layer's K/V norm chain. Here they go
|
| 2474 |
-
# through their own FAN transform before k_proj/v_proj, keeping the
|
| 2475 |
-
# FANformer periodicity modeling orthogonally intact per stream.
|
| 2476 |
-
h_fan_k = self.fan_layer(mudd_xk) if mudd_xk is not None else h_fan
|
| 2477 |
-
h_fan_v = self.fan_layer(mudd_xv) if mudd_xv is not None else h_fan
|
| 2478 |
-
|
| 2479 |
query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim)
|
| 2480 |
kv_shape = (*input_shape, self.num_mea_component_heads, self.head_dim)
|
| 2481 |
|
|
@@ -2490,8 +2387,8 @@ class NeoLLMAttention(nn.Module):
|
|
| 2490 |
attn_analysis.gate_raw = gate.detach()
|
| 2491 |
|
| 2492 |
q = self.q_norm(q_raw.view(query_shape)).transpose(1, 2)
|
| 2493 |
-
k = self.k_norm(self.k_proj(
|
| 2494 |
-
v = self.v_proj(
|
| 2495 |
|
| 2496 |
if attn_analysis is not None:
|
| 2497 |
attn_analysis.q_post_norm = q.detach()
|
|
@@ -3168,651 +3065,6 @@ class NeoLLMMLP(nn.Module):
|
|
| 3168 |
return result
|
| 3169 |
|
| 3170 |
|
| 3171 |
-
class NeoLLMMUDDModule(nn.Module):
|
| 3172 |
-
"""
|
| 3173 |
-
Multiway Dynamic Dense (MUDD) Depth-wise Aggregate module.
|
| 3174 |
-
|
| 3175 |
-
Generates per-position, per-stream connection weights over all preceding
|
| 3176 |
-
layer outputs (and the token embedding) and produces up to C=4 aggregated
|
| 3177 |
-
streams (Q, K, V, R) for the next Transformer block.
|
| 3178 |
-
|
| 3179 |
-
Architecture (Xiao et al., 2025, arXiv:2502.12170):
|
| 3180 |
-
dw = GELU(RMSNorm(x) @ W1) @ W2 + a # [B, T, C*(lidx+2)]
|
| 3181 |
-
dw = reshape to [C, B, T, (lidx+2)]
|
| 3182 |
-
stream_c = Ξ£_j dw[c, :, :, j] * hiddens[j] for c in range(C)
|
| 3183 |
-
|
| 3184 |
-
W1 ~ N(0, 1/D), W2 = 0, a = identity on last index β reduces to standard
|
| 3185 |
-
Transformer at init (dynamic part is zero, static bias selects Xi).
|
| 3186 |
-
|
| 3187 |
-
Args:
|
| 3188 |
-
hidden_size: model dimension D
|
| 3189 |
-
lidx: layer index (0-based); history has lidx+2 entries
|
| 3190 |
-
num_ways: C, number of output streams (4 for "qkvr", 1 for "l")
|
| 3191 |
-
is_last: whether this is the last layer (controls expand_last)
|
| 3192 |
-
expand_last: multiply hid_dim by 4 for the final layer's DA module
|
| 3193 |
-
round64: round hid_dim up to the nearest multiple of 64
|
| 3194 |
-
"""
|
| 3195 |
-
|
| 3196 |
-
def __init__(
|
| 3197 |
-
self,
|
| 3198 |
-
hidden_size: int,
|
| 3199 |
-
lidx: int,
|
| 3200 |
-
num_ways: int = 4,
|
| 3201 |
-
is_last: bool = False,
|
| 3202 |
-
expand_last: bool = False,
|
| 3203 |
-
round64: bool = False,
|
| 3204 |
-
) -> None:
|
| 3205 |
-
super().__init__()
|
| 3206 |
-
self.lidx = lidx
|
| 3207 |
-
self.num_ways = num_ways
|
| 3208 |
-
l = lidx + 2 # history length: embedding + lidx layers
|
| 3209 |
-
hid_dim = l * num_ways
|
| 3210 |
-
out_dim = l * num_ways
|
| 3211 |
-
if is_last and expand_last:
|
| 3212 |
-
hid_dim *= 4
|
| 3213 |
-
if round64:
|
| 3214 |
-
hid_dim = (hid_dim // 64 + 1) * 64
|
| 3215 |
-
# RMSNorm without learnable scale (paper uses RMSnormNoscale)
|
| 3216 |
-
self.norm = nn.RMSNorm(hidden_size, elementwise_affine=False,
|
| 3217 |
-
eps=1e-6)
|
| 3218 |
-
self.w1 = nn.Linear(hidden_size, hid_dim, bias=False)
|
| 3219 |
-
self.act = nn.GELU()
|
| 3220 |
-
self.w2 = nn.Linear(hid_dim, out_dim, bias=False)
|
| 3221 |
-
self._reset_mudd_parameters(hidden_size)
|
| 3222 |
-
|
| 3223 |
-
def _reset_mudd_parameters(self, D: int) -> None:
|
| 3224 |
-
# W1 ~ N(0, 1/D); W2 = 0 β dynamic part starts at zero
|
| 3225 |
-
nn.init.normal_(self.w1.weight, mean=0.0, std=1.0 / D)
|
| 3226 |
-
nn.init.zeros_(self.w2.weight)
|
| 3227 |
-
|
| 3228 |
-
def forward(
|
| 3229 |
-
self,
|
| 3230 |
-
x: torch.Tensor, # [B, T, D] β current layer output (Xi)
|
| 3231 |
-
hiddens: list, # list of lidx+2 tensors [B, T, D]
|
| 3232 |
-
static_bias: torch.Tensor, # [C, lidx+2] β learnable static prior
|
| 3233 |
-
) -> tuple:
|
| 3234 |
-
"""
|
| 3235 |
-
Returns:
|
| 3236 |
-
Tuple of num_ways tensors, each [B, T, D] β the aggregated streams.
|
| 3237 |
-
"""
|
| 3238 |
-
B, T, D = x.shape
|
| 3239 |
-
# Dynamic weight generation: [B, T, C*(lidx+2)]
|
| 3240 |
-
dw = self.w2(self.act(self.w1(self.norm(x))))
|
| 3241 |
-
# Add static bias (broadcast over B and T)
|
| 3242 |
-
# static_bias: [C, L] β [1, 1, C*L] via reshape
|
| 3243 |
-
C, L = static_bias.shape
|
| 3244 |
-
dw = dw + static_bias.reshape(1, 1, C * L).to(dw.dtype)
|
| 3245 |
-
# Reshape to [C, B, T, L]
|
| 3246 |
-
dw = dw.view(B, T, C, L).permute(2, 0, 1, 3) # [C, B, T, L]
|
| 3247 |
-
# Stack history: [L, B, T, D]
|
| 3248 |
-
stacked = torch.stack(hiddens, dim=0) # [L, B, T, D]
|
| 3249 |
-
# Aggregate: Ξ£_j dw[c, :, :, j] * hiddens[j]
|
| 3250 |
-
# einsum "cbtl, lbtd -> cbtd"
|
| 3251 |
-
streams = torch.einsum('cbtl,lbtd->cbtd', dw, stacked) # [C, B, T, D]
|
| 3252 |
-
return tuple(streams[c] for c in range(C))
|
| 3253 |
-
|
| 3254 |
-
|
| 3255 |
-
def dca_select_layers(stacked: torch.Tensor, k: int) -> torch.Tensor:
    """
    k-DCA layer selection (Heddes et al., 2025, §3.1).

    Keeps only the first k and last k tensors from the depth stack,
    capping memory at 2k layer representations regardless of depth.
    When the stack has <= 2k entries all are kept (early layers).

    Args:
        stacked: [y, B, S, D] — stack of all layer outputs so far.
        k:       number of first/last layers to retain; must be >= 0.
    Returns:
        [min(y, 2k), B, S, D]. The input tensor itself is returned (no copy)
        when nothing has to be dropped.
    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    depth = stacked.shape[0]
    if depth <= 2 * k:
        return stacked
    # Index the tail from an explicit start: ``stacked[-k:]`` degenerates to
    # ``stacked[0:]`` (the whole stack) when k == 0.
    return torch.cat([stacked[:k], stacked[depth - k:]], dim=0)
|
| 3273 |
-
|
| 3274 |
-
|
| 3275 |
-
class NeoLLMGRN(nn.Module):
|
| 3276 |
-
"""
|
| 3277 |
-
Generalized Residual Network v3 (GRN-v3) from DeepCrossAttention
|
| 3278 |
-
(Heddes et al., 2025, arXiv:2502.06785, Β§3.1).
|
| 3279 |
-
|
| 3280 |
-
Produces `num_outputs` aggregated streams from a depth-wise stack of
|
| 3281 |
-
layer representations. Weights are simultaneously:
|
| 3282 |
-
|
| 3283 |
-
- **Input-dependent** (dynamic): a two-layer mapping
|
| 3284 |
-
``wΜ = ReLU(RMSNorm(G) @ W)`` produces one scalar per
|
| 3285 |
-
(output-stream, depth-position, batch-token). ``W`` is initialized
|
| 3286 |
-
to zero so the dynamic contribution starts neutral.
|
| 3287 |
-
- **Dimension-dependent** (static): a learnable bias ``b`` of shape
|
| 3288 |
-
``[num_outputs, num_stack_layers, hidden_size]`` initialized to ones
|
| 3289 |
-
provides a per-dimension, per-layer prior. At initialization the
|
| 3290 |
-
dynamic part is zero and the static bias sums to an equal-weight
|
| 3291 |
-
average over all stack entries, reducing to a standard residual mean.
|
| 3292 |
-
|
| 3293 |
-
The combined weight for output stream ``o``, stack position ``y``,
|
| 3294 |
-
batch ``b``, token ``n``, feature ``d`` is::
|
| 3295 |
-
|
| 3296 |
-
weight[o, y, b, n, d] = ReLU(dynamic[y, b, n, o]) + bias[o, y, d]
|
| 3297 |
-
|
| 3298 |
-
Output ``o`` is then the weighted sum over depth::
|
| 3299 |
-
|
| 3300 |
-
out[o, b, n, d] = Ξ£_y stack[y, b, n, d] * weight[o, y, b, n, d]
|
| 3301 |
-
|
| 3302 |
-
Reference:
|
| 3303 |
-
Heddes, M. et al. (2025). *DeepCrossAttention: Supercharging
|
| 3304 |
-
Transformer Residual Connections.* arXiv:2502.06785.
|
| 3305 |
-
|
| 3306 |
-
Args:
|
| 3307 |
-
hidden_size: model dimension D.
|
| 3308 |
-
num_stack_layers: number of depth entries this GRN will receive
|
| 3309 |
-
(= min(layer_idx+1, 2*dca_k)).
|
| 3310 |
-
num_outputs: number of output streams (3 for DCA Q/K/V,
|
| 3311 |
-
1 for the final aggregation GRN).
|
| 3312 |
-
eps: epsilon for the internal RMSNorm.
|
| 3313 |
-
"""
|
| 3314 |
-
|
| 3315 |
-
def __init__(
|
| 3316 |
-
self,
|
| 3317 |
-
hidden_size: int,
|
| 3318 |
-
num_stack_layers: int,
|
| 3319 |
-
num_outputs: int = 3,
|
| 3320 |
-
eps: float = 1e-6,
|
| 3321 |
-
) -> None:
|
| 3322 |
-
super().__init__()
|
| 3323 |
-
self.num_outputs = num_outputs
|
| 3324 |
-
self.num_stack_layers = num_stack_layers
|
| 3325 |
-
|
| 3326 |
-
# Dynamic component: RMSNorm(no scale) β Linear β ReLU
|
| 3327 |
-
# Linear maps D β num_outputs; init zeros so dynamic part = 0 at step 0.
|
| 3328 |
-
_linear = nn.Linear(hidden_size, num_outputs, bias=False)
|
| 3329 |
-
nn.init.zeros_(_linear.weight)
|
| 3330 |
-
self.norm_noscale = nn.RMSNorm(
|
| 3331 |
-
hidden_size, eps=eps, elementwise_affine=False
|
| 3332 |
-
)
|
| 3333 |
-
self.to_dynamic = nn.Sequential(_linear, nn.ReLU())
|
| 3334 |
-
|
| 3335 |
-
# Static bias: [num_outputs, num_stack_layers, hidden_size], init ones.
|
| 3336 |
-
# At init: weight = 0 + bias = 1 per entry β equal-weight average β residual.
|
| 3337 |
-
self.bias = nn.Parameter(
|
| 3338 |
-
torch.ones(num_outputs, num_stack_layers, hidden_size)
|
| 3339 |
-
)
|
| 3340 |
-
|
| 3341 |
-
def forward(
|
| 3342 |
-
self,
|
| 3343 |
-
stack: torch.Tensor,
|
| 3344 |
-
analysis: Optional["DCAAnalysis"] = None,
|
| 3345 |
-
) -> tuple:
|
| 3346 |
-
"""
|
| 3347 |
-
Args:
|
| 3348 |
-
stack: [y, B, S, D] β selected depth stack (y β€ 2*dca_k).
|
| 3349 |
-
analysis: optional DCAAnalysis to deposit grn_depth_weights.
|
| 3350 |
-
Returns:
|
| 3351 |
-
Tuple of num_outputs tensors each [B, S, D].
|
| 3352 |
-
When num_outputs=1 returns a single [B, S, D] tensor directly.
|
| 3353 |
-
"""
|
| 3354 |
-
y, B, S, D = stack.shape
|
| 3355 |
-
assert y == self.num_stack_layers, (
|
| 3356 |
-
f"NeoLLMGRN expected stack depth {self.num_stack_layers}, got {y}"
|
| 3357 |
-
)
|
| 3358 |
-
|
| 3359 |
-
# Dynamic aggregate: [y, B, S, D] β norm β [y, B, S, D]
|
| 3360 |
-
# β to_dynamic β [y, B, S, num_outputs]
|
| 3361 |
-
# β permute β [num_outputs, y, B, S]
|
| 3362 |
-
normed = self.norm_noscale(stack) # [y, B, S, D]
|
| 3363 |
-
dynamic = self.to_dynamic(normed) # [y, B, S, num_outputs]
|
| 3364 |
-
dynamic = dynamic.permute(3, 0, 1, 2) # [o, y, B, S]
|
| 3365 |
-
|
| 3366 |
-
if analysis is not None:
|
| 3367 |
-
analysis.grn_depth_weights = dynamic.detach()
|
| 3368 |
-
|
| 3369 |
-
# Combined weight: dynamic scalar + static bias per dimension
|
| 3370 |
-
# dynamic: [o, y, B, S] β [o, y, B, S, 1]
|
| 3371 |
-
# bias: [o, y, D] β [o, y, 1, 1, D]
|
| 3372 |
-
weights = dynamic.unsqueeze(-1) + self.bias.unsqueeze(2).unsqueeze(3)
|
| 3373 |
-
# weights: [o, y, B, S, D]
|
| 3374 |
-
|
| 3375 |
-
# Weighted depth-sum: Ξ£_y stack[y] * weights[o, y]
|
| 3376 |
-
# stack: [y, B, S, D] β [1, y, B, S, D]
|
| 3377 |
-
output = (stack.unsqueeze(0) * weights).sum(dim=1) # [o, B, S, D]
|
| 3378 |
-
|
| 3379 |
-
if self.num_outputs == 1:
|
| 3380 |
-
return output.squeeze(0) # [B, S, D]
|
| 3381 |
-
return tuple(output[i] for i in range(self.num_outputs))
|
| 3382 |
-
|
| 3383 |
-
|
| 3384 |
-
class StackMemory(nn.Module):
|
| 3385 |
-
"""
|
| 3386 |
-
Differentiable multi-head hidden-state stack for NeoLLM.
|
| 3387 |
-
|
| 3388 |
-
Implements the StackTrans module from Zhang et al. (NeurIPS 2025):
|
| 3389 |
-
"Recursive Transformer: Boosting Reasoning Ability with State Stack."
|
| 3390 |
-
|
| 3391 |
-
Architecture (one forward call, covering the full sequence in parallel):
|
| 3392 |
-
|
| 3393 |
-
1. down_proj : [B,S,D] β [B,S,stack_d_model]
|
| 3394 |
-
2. action_head: β [B,S,H,3] softmax (push / pop / no-op)
|
| 3395 |
-
3. k_values : reshape to [B,S,H,ds]
|
| 3396 |
-
4. _vectorized_update: applies soft push/pop/no-op to each
|
| 3397 |
-
(batch, head) stack in parallel across the sequence dim.
|
| 3398 |
-
This is the training-parallelism approximation from Β§3.3:
|
| 3399 |
-
every token sees the *same* initial stack, breaking strict
|
| 3400 |
-
temporal ordering within a sequence in exchange for full
|
| 3401 |
-
data-parallelism. Cross-token memory is recovered during
|
| 3402 |
-
autoregressive generation via the step() / enable_cache path.
|
| 3403 |
-
5. gate_proj : global read β softmax over all stack slots
|
| 3404 |
-
(paper Β§3.1: "query-over-stack attention"), masked by the
|
| 3405 |
-
validity mask. Returns weighted sum of the stack.
|
| 3406 |
-
6. up_proj : [B,S,stack_d_model] β [B,S,D]
|
| 3407 |
-
7. residual : output = up_proj_out * res_weight + hidden_states
|
| 3408 |
-
|
| 3409 |
-
Vertical passing (layer-to-layer):
|
| 3410 |
-
Returns new_stack[:, -1] and new_mask[:, -1] β the stack state
|
| 3411 |
-
at the last sequence position β which becomes the initial stack
|
| 3412 |
-
for the next decoder layer. This propagates hierarchical context
|
| 3413 |
-
depth-wise through the network.
|
| 3414 |
-
|
| 3415 |
-
Temporal accumulation (generation):
|
| 3416 |
-
During autoregressive decoding, enable_cache=True and step() is
|
| 3417 |
-
used: k_cache and action_cache store previous-token values so the
|
| 3418 |
-
update equation integrates the full generated history rather than
|
| 3419 |
-
starting from zeros each step.
|
| 3420 |
-
|
| 3421 |
-
Args:
|
| 3422 |
-
config: NeoLLMConfig instance. Reads:
|
| 3423 |
-
stacktrans_num_heads (H, number of stack heads)
|
| 3424 |
-
stacktrans_stack_slots (S, stack depth)
|
| 3425 |
-
stacktrans_stack_d_model (HΓds, low-rank dimension)
|
| 3426 |
-
stacktrans_forward_bs (batch size for cache buffers)
|
| 3427 |
-
"""
|
| 3428 |
-
|
| 3429 |
-
def __init__(self, config: NeoLLMConfig):
|
| 3430 |
-
super().__init__()
|
| 3431 |
-
self.num_stack_heads = config.stacktrans_num_heads
|
| 3432 |
-
self.stack_slots = config.stacktrans_stack_slots
|
| 3433 |
-
self.stack_d_model = config.stacktrans_stack_d_model
|
| 3434 |
-
self.head_dim = self.stack_d_model // self.num_stack_heads
|
| 3435 |
-
|
| 3436 |
-
# Dimension reduction / expansion (standard nn.Linear, no multipliers β
|
| 3437 |
-
# StackMemory is architecturally independent per the paper Β§A)
|
| 3438 |
-
self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=True)
|
| 3439 |
-
self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=True)
|
| 3440 |
-
|
| 3441 |
-
# Action prediction: push / pop / no-op probabilities, one triple per head
|
| 3442 |
-
self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
|
| 3443 |
-
|
| 3444 |
-
# Global read query: one scalar gate per stack slot per head
|
| 3445 |
-
self.gate_proj = nn.Linear(self.head_dim, 1, bias=True)
|
| 3446 |
-
|
| 3447 |
-
# Learnable residual gate (paper h'_t = g_hΒ·h_t + R_t, g_h scalar)
|
| 3448 |
-
self.res_weight = nn.Parameter(torch.ones(1))
|
| 3449 |
-
|
| 3450 |
-
# ββ Autoregressive generation cache ββββββββββββββββββββββββββββββ
|
| 3451 |
-
# k_cache and action_cache hold per-token values from previous steps
|
| 3452 |
-
# so step() can reconstruct the full sequence history. Only used when
|
| 3453 |
-
# enable_cache=True (set by NeoLLMModel.forward when use_cache=True).
|
| 3454 |
-
_fbs = getattr(config, "stacktrans_forward_bs", 1)
|
| 3455 |
-
_cs = getattr(config, "cache_size", 2048)
|
| 3456 |
-
self.register_buffer(
|
| 3457 |
-
"k_cache",
|
| 3458 |
-
torch.zeros(_fbs, _cs, self.num_stack_heads, self.head_dim),
|
| 3459 |
-
)
|
| 3460 |
-
self.register_buffer(
|
| 3461 |
-
"action_cache",
|
| 3462 |
-
torch.zeros(_fbs, _cs, self.num_stack_heads, 3),
|
| 3463 |
-
)
|
| 3464 |
-
self.cache_position = 0
|
| 3465 |
-
self.enable_cache = False
|
| 3466 |
-
|
| 3467 |
-
# ββ Cache helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3468 |
-
|
| 3469 |
-
def reset_cache(self) -> None:
|
| 3470 |
-
self.cache_position = 0
|
| 3471 |
-
|
| 3472 |
-
def _update_cache(
|
| 3473 |
-
self,
|
| 3474 |
-
k_values: torch.Tensor, # [B,S,H,ds] detached
|
| 3475 |
-
actions: torch.Tensor, # [B,S,H,3] detached
|
| 3476 |
-
) -> None:
|
| 3477 |
-
seq_len = k_values.shape[1]
|
| 3478 |
-
if self.cache_position + seq_len <= self.k_cache.shape[1]:
|
| 3479 |
-
self.k_cache [:, self.cache_position:self.cache_position + seq_len] = k_values
|
| 3480 |
-
self.action_cache[:, self.cache_position:self.cache_position + seq_len] = actions
|
| 3481 |
-
self.cache_position += seq_len
|
| 3482 |
-
else:
|
| 3483 |
-
self.reset_cache()
|
| 3484 |
-
|
| 3485 |
-
# ββ Core stack update βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3486 |
-
|
| 3487 |
-
def _vectorized_update(
|
| 3488 |
-
self,
|
| 3489 |
-
stack: torch.Tensor, # [B, H, slots, ds] (4-D) or [B,S,H,slots,ds] (5-D)
|
| 3490 |
-
mask: torch.Tensor, # [B, H, slots] (3-D) or [B,S,H,slots] (4-D)
|
| 3491 |
-
actions: torch.Tensor, # [B, S, H, 3]
|
| 3492 |
-
k_values: torch.Tensor, # [B, S, H, ds]
|
| 3493 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 3494 |
-
"""
|
| 3495 |
-
Vectorized soft push/pop/no-op stack update.
|
| 3496 |
-
|
| 3497 |
-
Every token position receives the *same* initial stack (the one
|
| 3498 |
-
passed in from the previous layer), and operations are applied in
|
| 3499 |
-
parallel across S. This is the Β§3.3 training-parallelism
|
| 3500 |
-
approximation: strict sequential dependency within a sequence is
|
| 3501 |
-
broken intentionally to allow full batch processing.
|
| 3502 |
-
|
| 3503 |
-
Returns:
|
| 3504 |
-
new_stack [B, S, H, slots, ds]
|
| 3505 |
-
new_mask [B, S, H, slots]
|
| 3506 |
-
"""
|
| 3507 |
-
batch_size, seq_len = actions.shape[:2]
|
| 3508 |
-
|
| 3509 |
-
# Broadcast 4-D initial state along the sequence dimension
|
| 3510 |
-
if stack.dim() == 4:
|
| 3511 |
-
stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1)
|
| 3512 |
-
mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1)
|
| 3513 |
-
|
| 3514 |
-
# Push: new value at top, shift everything down (overflow discarded)
|
| 3515 |
-
push_stack = torch.cat([k_values.unsqueeze(3), stack[:, :, :, :-1]], dim=3)
|
| 3516 |
-
push_mask = torch.cat([torch.ones_like(mask[:, :, :, :1]),
|
| 3517 |
-
mask[:, :, :, :-1]], dim=3)
|
| 3518 |
-
|
| 3519 |
-
# Pop: shift everything up, zero at bottom
|
| 3520 |
-
pop_stack = torch.cat([stack[:, :, :, 1:],
|
| 3521 |
-
torch.zeros_like(stack[:, :, :, :1])], dim=3)
|
| 3522 |
-
pop_mask = torch.cat([mask[:, :, :, 1:],
|
| 3523 |
-
torch.zeros_like(mask[:, :, :, :1])], dim=3)
|
| 3524 |
-
|
| 3525 |
-
# Soft combination weighted by action probabilities
|
| 3526 |
-
# actions: [B,S,H,3] β unsqueeze to [B,S,H,3,1,1] for stack broadcast
|
| 3527 |
-
aw = actions.unsqueeze(-1).unsqueeze(-1) # [B,S,H,3,1,1]
|
| 3528 |
-
stacks = torch.stack([push_stack, pop_stack, stack], dim=3) # [B,S,H,3,slots,ds]
|
| 3529 |
-
masks = torch.stack([push_mask, pop_mask, mask], dim=3) # [B,S,H,3,slots]
|
| 3530 |
-
|
| 3531 |
-
new_stack = (stacks * aw).sum(dim=3) # [B,S,H,slots,ds]
|
| 3532 |
-
new_mask = (masks * aw.squeeze(-1)).sum(dim=3) # [B,S,H,slots]
|
| 3533 |
-
return new_stack, new_mask
|
| 3534 |
-
|
| 3535 |
-
# ββ Training forward (full sequence) βββββββββββββββββββββββββββββββββ
|
| 3536 |
-
|
| 3537 |
-
def forward(
|
| 3538 |
-
self,
|
| 3539 |
-
hidden_states: torch.Tensor,
|
| 3540 |
-
stack: Optional[torch.Tensor] = None,
|
| 3541 |
-
mask: Optional[torch.Tensor] = None,
|
| 3542 |
-
analysis: Optional[StackMemoryAnalysis] = None,
|
| 3543 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 3544 |
-
"""
|
| 3545 |
-
Full-sequence forward pass (training and prefill).
|
| 3546 |
-
|
| 3547 |
-
Args:
|
| 3548 |
-
hidden_states: [B, S, D]
|
| 3549 |
-
stack: [B, H, slots, ds] β previous layer's stack state,
|
| 3550 |
-
or None (initialised to zeros for layer 0).
|
| 3551 |
-
mask: [B, H, slots] β validity mask for stack,
|
| 3552 |
-
or None (initialised to zeros for layer 0).
|
| 3553 |
-
analysis: StackMemoryAnalysis container; populated when
|
| 3554 |
-
model is in eval + analysis mode.
|
| 3555 |
-
|
| 3556 |
-
Returns:
|
| 3557 |
-
(output, new_stack, new_mask)
|
| 3558 |
-
output [B, S, D]
|
| 3559 |
-
new_stack [B, H, slots, ds] β stack at final sequence position
|
| 3560 |
-
new_mask [B, H, slots]
|
| 3561 |
-
"""
|
| 3562 |
-
batch_size, seq_len, _ = hidden_states.shape
|
| 3563 |
-
device = hidden_states.device
|
| 3564 |
-
|
| 3565 |
-
# Capture incoming stack for analysis before it is updated
|
| 3566 |
-
if analysis is not None:
|
| 3567 |
-
analysis.stack_in = stack.detach() if stack is not None else None
|
| 3568 |
-
|
| 3569 |
-
# Initialise empty stack / mask for layer 0
|
| 3570 |
-
if stack is None:
|
| 3571 |
-
stack = torch.zeros(
|
| 3572 |
-
batch_size, self.num_stack_heads, self.stack_slots, self.head_dim,
|
| 3573 |
-
device=device, dtype=hidden_states.dtype,
|
| 3574 |
-
)
|
| 3575 |
-
if mask is None:
|
| 3576 |
-
mask = torch.zeros(
|
| 3577 |
-
batch_size, self.num_stack_heads, self.stack_slots,
|
| 3578 |
-
device=device, dtype=hidden_states.dtype,
|
| 3579 |
-
)
|
| 3580 |
-
|
| 3581 |
-
# 1. Project down
|
| 3582 |
-
h_proj = self.down_proj(hidden_states) # [B,S,stack_d_model]
|
| 3583 |
-
|
| 3584 |
-
# 2. Action probabilities
|
| 3585 |
-
action_logits = self.action_head(h_proj) / math.sqrt(self.head_dim)
|
| 3586 |
-
actions = F.softmax(
|
| 3587 |
-
action_logits.view(batch_size, seq_len, self.num_stack_heads, 3), dim=-1
|
| 3588 |
-
) # [B,S,H,3]
|
| 3589 |
-
|
| 3590 |
-
# 3. Values to push
|
| 3591 |
-
k_values = h_proj.view(batch_size, seq_len, self.num_stack_heads, self.head_dim)
|
| 3592 |
-
|
| 3593 |
-
# 4. Vectorized stack update
|
| 3594 |
-
new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
|
| 3595 |
-
# new_stack: [B,S,H,slots,ds], new_mask: [B,S,H,slots]
|
| 3596 |
-
|
| 3597 |
-
# 5. Global read (query-over-stack attention, paper Β§3.1)
|
| 3598 |
-
gate_scores = self.gate_proj(new_stack).squeeze(-1) # [B,S,H,slots]
|
| 3599 |
-
gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
|
| 3600 |
-
memory_out = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
|
| 3601 |
-
# memory_out: [B,S,H,ds] β [B,S,stack_d_model]
|
| 3602 |
-
memory_out = memory_out.view(batch_size, seq_len, self.stack_d_model)
|
| 3603 |
-
|
| 3604 |
-
# 6. Project back up
|
| 3605 |
-
memory_out_proj = self.up_proj(memory_out) # [B,S,D]
|
| 3606 |
-
|
| 3607 |
-
# 7. Residual
|
| 3608 |
-
output = memory_out_proj * self.res_weight + hidden_states
|
| 3609 |
-
|
| 3610 |
-
# 8. Update generation cache (no-op during training)
|
| 3611 |
-
if self.enable_cache:
|
| 3612 |
-
self._update_cache(k_values.detach(), actions.detach())
|
| 3613 |
-
|
| 3614 |
-
# Populate analysis fields
|
| 3615 |
-
if analysis is not None:
|
| 3616 |
-
analysis.action_probs = actions.detach()
|
| 3617 |
-
analysis.stack_out = new_stack[:, -1].detach()
|
| 3618 |
-
analysis.mask_out = new_mask[:, -1].detach()
|
| 3619 |
-
analysis.gate_weights = gate_weights.detach()
|
| 3620 |
-
analysis.memory_output = memory_out.detach()
|
| 3621 |
-
analysis.residual_scale = self.res_weight.item()
|
| 3622 |
-
|
| 3623 |
-
# Return output + last-position stack state for next layer
|
| 3624 |
-
return output, new_stack[:, -1], new_mask[:, -1]
|
| 3625 |
-
|
| 3626 |
-
# ββ Autoregressive single-token forward ββββββββββββββββββββββββββββββ
|
| 3627 |
-
|
| 3628 |
-
def step(
|
| 3629 |
-
self,
|
| 3630 |
-
hidden_state: torch.Tensor, # [B, D]
|
| 3631 |
-
stack: torch.Tensor, # [B, H, slots, ds]
|
| 3632 |
-
mask: torch.Tensor, # [B, H, slots]
|
| 3633 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 3634 |
-
"""
|
| 3635 |
-
Single-token forward for autoregressive generation.
|
| 3636 |
-
|
| 3637 |
-
When enable_cache=False (simple path used by NeoLLM generation):
|
| 3638 |
-
Calls forward() with a length-1 sequence and unpacks the result.
|
| 3639 |
-
The stack state passed in carries all history from previous tokens
|
| 3640 |
-
(propagated by NeoLLMModel.forward across generation steps).
|
| 3641 |
-
|
| 3642 |
-
When enable_cache=True (full-history reconstruction path):
|
| 3643 |
-
Concatenates the current token with cached previous-token values
|
| 3644 |
-
and replays the full vectorized update, extracting only the last
|
| 3645 |
-
position. This gives a more accurate stack that sees full history
|
| 3646 |
-
at the cost of O(T) computation per step.
|
| 3647 |
-
|
| 3648 |
-
Returns:
|
| 3649 |
-
(output, new_stack, new_mask)
|
| 3650 |
-
output [B, D]
|
| 3651 |
-
new_stack [B, H, slots, ds]
|
| 3652 |
-
new_mask [B, H, slots]
|
| 3653 |
-
"""
|
| 3654 |
-
if not self.enable_cache:
|
| 3655 |
-
# Simple path: forward with seq_len=1, squeeze the sequence dim
|
| 3656 |
-
out, new_stack, new_mask = self.forward(
|
| 3657 |
-
hidden_state.unsqueeze(1), stack, mask
|
| 3658 |
-
)
|
| 3659 |
-
return out.squeeze(1), new_stack, new_mask
|
| 3660 |
-
|
| 3661 |
-
batch_size = hidden_state.shape[0]
|
| 3662 |
-
|
| 3663 |
-
# Compute features for the current token
|
| 3664 |
-
h_proj = self.down_proj(hidden_state) # [B, stack_d_model]
|
| 3665 |
-
a_logits = self.action_head(h_proj) / math.sqrt(self.head_dim)
|
| 3666 |
-
cur_act = F.softmax(
|
| 3667 |
-
a_logits.view(batch_size, 1, self.num_stack_heads, 3), dim=-1
|
| 3668 |
-
) # [B,1,H,3]
|
| 3669 |
-
cur_k = h_proj.view(batch_size, 1, self.num_stack_heads, self.head_dim)
|
| 3670 |
-
|
| 3671 |
-
# Prepend cached history (all previous tokens in this generation)
|
| 3672 |
-
if self.cache_position > 0:
|
| 3673 |
-
k_values = torch.cat([self.k_cache[:batch_size, :self.cache_position], cur_k], dim=1)
|
| 3674 |
-
actions = torch.cat([self.action_cache[:batch_size, :self.cache_position], cur_act], dim=1)
|
| 3675 |
-
else:
|
| 3676 |
-
k_values = cur_k
|
| 3677 |
-
actions = cur_act
|
| 3678 |
-
|
| 3679 |
-
# Full vectorized update over history + current token; take last position
|
| 3680 |
-
new_stack_seq, new_mask_seq = self._vectorized_update(stack, mask, actions, k_values)
|
| 3681 |
-
new_stack = new_stack_seq[:, -1] # [B,H,slots,ds]
|
| 3682 |
-
new_mask = new_mask_seq[:, -1] # [B,H,slots]
|
| 3683 |
-
|
| 3684 |
-
# Global read on the new stack state
|
| 3685 |
-
gate_scores = self.gate_proj(new_stack).squeeze(-1) # [B,H,slots]
|
| 3686 |
-
gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
|
| 3687 |
-
memory_out = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=2)
|
| 3688 |
-
memory_out = memory_out.view(batch_size, self.stack_d_model)
|
| 3689 |
-
|
| 3690 |
-
memory_out_proj = self.up_proj(memory_out) # [B,D]
|
| 3691 |
-
output = memory_out_proj * self.res_weight + hidden_state
|
| 3692 |
-
|
| 3693 |
-
self._update_cache(cur_k, cur_act)
|
| 3694 |
-
return output, new_stack, new_mask
|
| 3695 |
-
|
| 3696 |
-
|
| 3697 |
-
@dataclass
|
| 3698 |
-
class LAuReLLayer(nn.Module):
|
| 3699 |
-
"""
|
| 3700 |
-
LAuReL: Learned Augmented Residual Layer.
|
| 3701 |
-
|
| 3702 |
-
A lightweight replacement for the canonical residual connection
|
| 3703 |
-
that learns to blend the nonlinear sub-layer output f(x) with a
|
| 3704 |
-
richer linear function of the residual x, optionally augmented by a
|
| 3705 |
-
low-rank transformation.
|
| 3706 |
-
|
| 3707 |
-
Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
|
| 3708 |
-
*LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.
|
| 3709 |
-
|
| 3710 |
-
ββ Sub-variants ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3711 |
-
Controlled by config flags; any combination is valid:
|
| 3712 |
-
|
| 3713 |
-
**RW only** (use_laurel_rw=True, use_laurel_lr=False):
|
| 3714 |
-
|
| 3715 |
-
x_{i+1} = α · f(x_i) + β · x_i
|
| 3716 |
-
[α, β] = softmax([a, b]), a,b ∈ ℝ (2 params)
|
| 3717 |
-
|
| 3718 |
-
**LR only** (use_laurel_rw=False, use_laurel_lr=True):
|
| 3719 |
-
|
| 3720 |
-
x_{i+1} = f(x_i) + A·(B·x_i) + x_i
|
| 3721 |
-
B ∈ ℝ^{r×D} column-orthogonal init (down-projection)
|
| 3722 |
-
A ∈ ℝ^{D×r} zero init (up-projection)
|
| 3723 |
-
Params: 2·r·D per layer.
|
| 3724 |
-
|
| 3725 |
-
**RW + LR** (both True, paper recommendation):
|
| 3726 |
-
|
| 3727 |
-
x_{i+1} = α · f(x_i) + β · (A·(B·x_i) + x_i)
|
| 3728 |
-
|
| 3729 |
-
ββ Initialisation ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3730 |
-
RW: raw logits [a, b] = [0, 0] → α=β=0.5 at step 0.
|
| 3731 |
-
LR: A (up) = zeros → lr_term = 0 at step 0 → pure residual at init.
|
| 3732 |
-
This ensures the model starts as a standard residual and smoothly
|
| 3733 |
-
diverges as the gates and low-rank matrices are trained.
|
| 3734 |
-
|
| 3735 |
-
ββ Integration in NeoLLM βββββββββββββββββββββββββββββββββββββββββββ
|
| 3736 |
-
Applied immediately before GPAS at both residual sums per layer:
|
| 3737 |
-
|
| 3738 |
-
h_tilde = GPAS( LAuReL(attn_out, residual_attn) )
|
| 3739 |
-
output = GPAS( LAuReL(delta_m, residual_mlp) )
|
| 3740 |
-
|
| 3741 |
-
GPAS then applies its stop-gradient scaling on the combined stream,
|
| 3742 |
-
preserving gradient magnitudes across the depth of the network.
|
| 3743 |
-
The two techniques are structurally orthogonal: LAuReL controls the
|
| 3744 |
-
*mixing ratio* of f(x) and x at each residual junction; GPAS
|
| 3745 |
-
controls the *magnitude* of the combined stream with a learned gate
|
| 3746 |
-
and a stop-gradient operator that prevents gradient vanishing.
|
| 3747 |
-
|
| 3748 |
-
Args:
|
| 3749 |
-
config: NeoLLMConfig. Reads use_laurel_rw, use_laurel_lr,
|
| 3750 |
-
laurel_lr_rank, hidden_size.
|
| 3751 |
-
"""
|
| 3752 |
-
|
| 3753 |
-
def __init__(self, config: NeoLLMConfig):
|
| 3754 |
-
super().__init__()
|
| 3755 |
-
self.use_rw = getattr(config, "use_laurel_rw", True)
|
| 3756 |
-
self.use_lr = getattr(config, "use_laurel_lr", True)
|
| 3757 |
-
D = config.hidden_size
|
| 3758 |
-
r = getattr(config, "laurel_lr_rank", 32)
|
| 3759 |
-
|
| 3760 |
-
if self.use_rw:
|
| 3761 |
-
# Raw logits for softmax([α, β]).
|
| 3762 |
-
# Stored as a single 2-vector so softmax is one op.
|
| 3763 |
-
# Init to zero → α=β=0.5 at step 0.
|
| 3764 |
-
self.rw_logits = nn.Parameter(torch.zeros(2))
|
| 3765 |
-
|
| 3766 |
-
if self.use_lr:
|
| 3767 |
-
# down: B ∈ ℝ^{r×D}, column-orthogonal init (paper §3.3 LLM recommendation)
|
| 3768 |
-
# up: A ∈ ℝ^{D×r}, zero init → lr_term=0 at step 0 (LoRA-style)
|
| 3769 |
-
self.lr_down = nn.Linear(D, r, bias=False)
|
| 3770 |
-
self.lr_up = nn.Linear(r, D, bias=False)
|
| 3771 |
-
|
| 3772 |
-
def forward(
|
| 3773 |
-
self,
|
| 3774 |
-
f_out: torch.Tensor, # output of f(x): attn or MLP [B,S,D]
|
| 3775 |
-
x_res: torch.Tensor, # residual (skip connection) [B,S,D]
|
| 3776 |
-
analysis: Optional[LAuReLAnalysis] = None,
|
| 3777 |
-
) -> torch.Tensor:
|
| 3778 |
-
"""
|
| 3779 |
-
Args:
|
| 3780 |
-
f_out: Output of f(x) β attention output or MLP delta.
|
| 3781 |
-
x_res: Residual tensor β accumulated hidden state.
|
| 3782 |
-
analysis: Optional analysis container; populated in eval+analysis mode.
|
| 3783 |
-
|
| 3784 |
-
Returns:
|
| 3785 |
-
Combined tensor [B, S, D] to be fed into GPAS.
|
| 3786 |
-
"""
|
| 3787 |
-
# ββ LR component: AΒ·(BΒ·x_res) ββββββββββββββββββββββββββββββββββββ
|
| 3788 |
-
lr_term = None
|
| 3789 |
-
if self.use_lr:
|
| 3790 |
-
lr_term = self.lr_up(self.lr_down(x_res)) # [B,S,D]
|
| 3791 |
-
g_res = lr_term + x_res # enriched residual
|
| 3792 |
-
else:
|
| 3793 |
-
g_res = x_res
|
| 3794 |
-
|
| 3795 |
-
# ββ RW component: Ξ±Β·f + Ξ²Β·g ββββββββββββββββββββββββββββββββββββββ
|
| 3796 |
-
if self.use_rw:
|
| 3797 |
-
weights = torch.softmax(self.rw_logits, dim=0) # [2]
|
| 3798 |
-
alpha = weights[0]
|
| 3799 |
-
beta = weights[1]
|
| 3800 |
-
out = alpha * f_out + beta * g_res
|
| 3801 |
-
else:
|
| 3802 |
-
# LR only: standard sum with enriched residual
|
| 3803 |
-
out = f_out + g_res
|
| 3804 |
-
|
| 3805 |
-
if analysis is not None:
|
| 3806 |
-
if self.use_rw:
|
| 3807 |
-
analysis.alpha_rw = alpha.item()
|
| 3808 |
-
analysis.beta_rw = beta.item()
|
| 3809 |
-
if self.use_lr:
|
| 3810 |
-
analysis.lr_term = lr_term.detach()
|
| 3811 |
-
analysis.output = out.detach()
|
| 3812 |
-
|
| 3813 |
-
return out
|
| 3814 |
-
|
| 3815 |
-
|
| 3816 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 3817 |
"""
|
| 3818 |
Decoder layer with standard residual connections, optional JTok-M injection.
|
|
@@ -3868,78 +3120,10 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3868 |
self.attn_res_query_attn = nn.Parameter(torch.zeros(config.hidden_size))
|
| 3869 |
self.attn_res_query_mlp = nn.Parameter(torch.zeros(config.hidden_size))
|
| 3870 |
self.attn_res_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 3871 |
-
_num_blocks = getattr(config, 'attn_res_num_blocks', 0)
|
| 3872 |
-
self.attn_res_block_size = (
|
| 3873 |
-
max(config.num_hidden_layers // _num_blocks, 1) if _num_blocks > 0 else 1
|
| 3874 |
-
)
|
| 3875 |
else:
|
| 3876 |
self.attn_res_query_attn = None
|
| 3877 |
self.attn_res_query_mlp = None
|
| 3878 |
self.attn_res_norm = None
|
| 3879 |
-
self.attn_res_block_size = None
|
| 3880 |
-
|
| 3881 |
-
# ββ MUDD: separate K/V LayerNorms for qkvr+sepln mode ββββββββββββββ
|
| 3882 |
-
# Only instantiated when both mudd_dense_type='qkvr' AND mudd_sepln=True.
|
| 3883 |
-
# The existing input_layernorm handles the Q stream (unchanged).
|
| 3884 |
-
# Separate norms for K and V allow each stream to rescale independently.
|
| 3885 |
-
_use_mudd = getattr(config, 'use_mudd', False)
|
| 3886 |
-
_mudd_qkvr = getattr(config, 'mudd_dense_type', 'qkvr') == 'qkvr'
|
| 3887 |
-
_mudd_sepln = getattr(config, 'mudd_sepln', False)
|
| 3888 |
-
if _use_mudd and _mudd_qkvr and _mudd_sepln:
|
| 3889 |
-
self.mudd_k_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 3890 |
-
self.mudd_v_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 3891 |
-
else:
|
| 3892 |
-
self.mudd_k_norm = None
|
| 3893 |
-
self.mudd_v_norm = None
|
| 3894 |
-
|
| 3895 |
-
# ββ DCA (Heddes et al., 2025, arXiv:2502.06785) βββββββββββββββββββ
|
| 3896 |
-
# GRN-v3 module that aggregates the k-selected depth stack into 3
|
| 3897 |
-
# independent streams (Q, K, V). Each stream has its own dimension-
|
| 3898 |
-
# and input-dependent weights, enabling richer cross-layer interactions.
|
| 3899 |
-
# K and V get their own SeeDNorm + LNS norm chain (same scheme as
|
| 3900 |
-
# MUDD sepln) since they now arrive from a different aggregation path.
|
| 3901 |
-
# The residual connection uses the Q stream output (xq) as its base,
|
| 3902 |
-
# matching the DCA paper's decoder block design (residual = q_input).
|
| 3903 |
-
self.use_dca = getattr(config, 'use_dca', False)
|
| 3904 |
-
if self.use_dca:
|
| 3905 |
-
_dca_k = getattr(config, 'dca_k', 2)
|
| 3906 |
-
_num_stack = min(layer_idx + 1, 2 * _dca_k)
|
| 3907 |
-
self.dca_grn = NeoLLMGRN(
|
| 3908 |
-
hidden_size = config.hidden_size,
|
| 3909 |
-
num_stack_layers = _num_stack,
|
| 3910 |
-
num_outputs = 3,
|
| 3911 |
-
eps = config.rms_norm_eps,
|
| 3912 |
-
)
|
| 3913 |
-
self.dca_k_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 3914 |
-
self.dca_v_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 3915 |
-
else:
|
| 3916 |
-
self.dca_grn = None
|
| 3917 |
-
self.dca_k_norm = None
|
| 3918 |
-
self.dca_v_norm = None
|
| 3919 |
-
|
| 3920 |
-
# ββ StackTrans (Zhang et al., NeurIPS 2025) βββββββββββββββββββββββ
|
| 3921 |
-
# Differentiable multi-head hidden-state stack inserted at the very
|
| 3922 |
-
# beginning of the layer forward, before the attention sublayer.
|
| 3923 |
-
# Mutually exclusive with use_attn_res, use_mudd, use_dca.
|
| 3924 |
-
self.use_stacktrans = getattr(config, 'use_stacktrans', False)
|
| 3925 |
-
if self.use_stacktrans:
|
| 3926 |
-
self.stack_memory = StackMemory(config)
|
| 3927 |
-
else:
|
| 3928 |
-
self.stack_memory = None
|
| 3929 |
-
|
| 3930 |
-
# ββ LAuReL (Menghani, Kumar & Kumar, ICML 2025) βββββββββββββββββββ
|
| 3931 |
-
# Learned augmented residual connection replacing f(x)+x at both
|
| 3932 |
-
# the attention and MLP residual sums. Applied immediately before
|
| 3933 |
-
# GPAS, so GPAS still controls magnitude via stop-gradient scaling.
|
| 3934 |
-
# Two independent instances per layer (attention and MLP).
|
| 3935 |
-
# Compatible with use_stacktrans. Incompatible with MUDD/DCA/AttnRes.
|
| 3936 |
-
self.use_laurel = getattr(config, 'use_laurel', False)
|
| 3937 |
-
if self.use_laurel:
|
| 3938 |
-
self.laurel_attn = LAuReLLayer(config)
|
| 3939 |
-
self.laurel_mlp = LAuReLLayer(config)
|
| 3940 |
-
else:
|
| 3941 |
-
self.laurel_attn = None
|
| 3942 |
-
self.laurel_mlp = None
|
| 3943 |
|
| 3944 |
def _attn_res(
|
| 3945 |
self,
|
|
@@ -3989,10 +3173,6 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 3989 |
B_vals: Optional[torch.Tensor] = None,
|
| 3990 |
attn_res_sources: Optional[list] = None,
|
| 3991 |
attn_res_partial: Optional[torch.Tensor] = None,
|
| 3992 |
-
mudd_streams: Optional[tuple] = None,
|
| 3993 |
-
dca_stack: Optional[torch.Tensor] = None,
|
| 3994 |
-
stack_state: Optional[torch.Tensor] = None,
|
| 3995 |
-
stack_mask: Optional[torch.Tensor] = None,
|
| 3996 |
layer_analysis: Optional[LayerAnalysis] = None,
|
| 3997 |
output_attentions: Optional[bool] = False,
|
| 3998 |
repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
|
|
@@ -4002,63 +3182,6 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 4002 |
if layer_analysis is not None:
|
| 4003 |
layer_analysis.hidden_states_input = hidden_states.detach()
|
| 4004 |
|
| 4005 |
-
# ββ StackTrans: hidden-state stack (pre-attention, pre-norm) βββββ
|
| 4006 |
-
# Executed first so attention sees the stack-enriched representation.
|
| 4007 |
-
# stack_state / stack_mask carry the stack from the previous layer;
|
| 4008 |
-
# both are None for layer 0 (StackMemory initialises to zeros then).
|
| 4009 |
-
# Mutually exclusive with MUDD / DCA / AttnRes β those branches are
|
| 4010 |
-
# all skipped when use_stacktrans=True (enforced in NeoLLMConfig).
|
| 4011 |
-
if self.use_stacktrans and self.stack_memory is not None:
|
| 4012 |
-
st_analysis = layer_analysis.stack if layer_analysis is not None else None
|
| 4013 |
-
hidden_states, stack_state, stack_mask = self.stack_memory(
|
| 4014 |
-
hidden_states, stack_state, stack_mask, analysis=st_analysis
|
| 4015 |
-
)
|
| 4016 |
-
|
| 4017 |
-
# ββ MUDD: unpack streams for Q/K/V/R (layer > 0 only) ββββββββββββ
|
| 4018 |
-
# mudd_streams is a 4-tuple (xq, xk, xv, xr) when use_mudd=True and
|
| 4019 |
-
# layer_idx > 0; None for layer 0 (standard residual there).
|
| 4020 |
-
# xr replaces hidden_states as the residual throughout this layer.
|
| 4021 |
-
# xq/xk/xv are the aggregated inputs for Q, K, V projections.
|
| 4022 |
-
# When mudd_dense_type='l' (single stream), all four are equal.
|
| 4023 |
-
# When mudd_sepln=True each stream has its own norm applied below.
|
| 4024 |
-
mudd_xk = None
|
| 4025 |
-
mudd_xv = None
|
| 4026 |
-
if mudd_streams is not None:
|
| 4027 |
-
xq_mudd, xk_mudd, xv_mudd, xr_mudd = mudd_streams
|
| 4028 |
-
# Replace hidden_states with xr for residual connections
|
| 4029 |
-
hidden_states = xr_mudd
|
| 4030 |
-
# Norm K and V streams β use separate SeeDNorm if sepln, else
|
| 4031 |
-
# they will share the main input_layernorm path via h_attn below
|
| 4032 |
-
if self.mudd_k_norm is not None:
|
| 4033 |
-
mudd_xk = self.lns_attn(self.mudd_k_norm(xk_mudd))
|
| 4034 |
-
mudd_xv = self.lns_attn(self.mudd_v_norm(xv_mudd))
|
| 4035 |
-
else:
|
| 4036 |
-
# No sepln: K/V also go through the Q-path norm chain
|
| 4037 |
-
mudd_xk = self.lns_attn(self.input_layernorm(xk_mudd))
|
| 4038 |
-
mudd_xv = self.lns_attn(self.input_layernorm(xv_mudd))
|
| 4039 |
-
# Override hidden_states for the Q path
|
| 4040 |
-
hidden_states_for_attn = xq_mudd
|
| 4041 |
-
else:
|
| 4042 |
-
hidden_states_for_attn = hidden_states
|
| 4043 |
-
|
| 4044 |
-
# ββ DCA: GRN-v3 depth-wise aggregation βββββββββββββββββββββββββββ
|
| 4045 |
-
# When active, runs the per-layer GRN on the k-selected depth stack
|
| 4046 |
-
# to produce three independent aggregated streams (Q, K, V).
|
| 4047 |
-
# xq replaces hidden_states as both the Q projection input AND the
|
| 4048 |
-
# post-attention residual (DCA paper: residual = q_input).
|
| 4049 |
-
# xk and xv go through separate SeeDNorm+LNS chains and are injected
|
| 4050 |
-
# into NeoLLMAttention via the existing mudd_xk/mudd_xv parameters.
|
| 4051 |
-
dca_residual = None
|
| 4052 |
-
dca_a = layer_analysis.dca if layer_analysis is not None else None
|
| 4053 |
-
if self.use_dca and dca_stack is not None:
|
| 4054 |
-
xq, xk, xv = self.dca_grn(dca_stack, analysis=dca_a)
|
| 4055 |
-
dca_residual = xq
|
| 4056 |
-
hidden_states_for_attn = xq
|
| 4057 |
-
# K and V streams: SeeDNorm + LNS before k_proj / v_proj
|
| 4058 |
-
# (reuses the mudd_xk/mudd_xv injection path in NeoLLMAttention)
|
| 4059 |
-
mudd_xk = self.lns_attn(self.dca_k_norm(xk))
|
| 4060 |
-
mudd_xv = self.lns_attn(self.dca_v_norm(xv))
|
| 4061 |
-
|
| 4062 |
# ββ Attention Residuals: compute pre-attention input ββββββββββββββ
|
| 4063 |
# When active, the input to the attention sublayer is no longer the
|
| 4064 |
# raw hidden_states (accumulated residual) but a softmax-weighted
|
|
@@ -4072,19 +3195,10 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 4072 |
attn_res_sources, attn_res_partial, self.attn_res_query_attn,
|
| 4073 |
ar_analysis, "attn",
|
| 4074 |
)
|
| 4075 |
-
# ββ Block boundary fires HERE β after pre-attn, before attn sublayer ββ
|
| 4076 |
-
# Paper pseudocode (Fig. 2) timing: the completed partial of the previous
|
| 4077 |
-
# block is pushed to sources AFTER the pre-attn AttnRes call, so the first
|
| 4078 |
-
# layer of a new block still sees the old partial as an intra-block source
|
| 4079 |
-
# (no duplicate) and the new intra-block accumulation starts from zeros.
|
| 4080 |
-
if self.layer_idx > 0 and self.layer_idx % self.attn_res_block_size == 0:
|
| 4081 |
-
attn_res_sources.append(attn_res_partial) # in-place; outer loop sees this
|
| 4082 |
-
attn_res_partial = torch.zeros_like(attn_res_partial) # fresh delta start
|
| 4083 |
residual_attn = attn_res_partial
|
| 4084 |
else:
|
| 4085 |
-
h_attn =
|
| 4086 |
-
|
| 4087 |
-
residual_attn = dca_residual if dca_residual is not None else hidden_states
|
| 4088 |
|
| 4089 |
# ββ Attention block βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 4090 |
sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
|
|
@@ -4100,8 +3214,6 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 4100 |
first_layer_fan=first_layer_fan,
|
| 4101 |
attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
|
| 4102 |
repo_rope_args=repo_rope_args,
|
| 4103 |
-
mudd_xk=mudd_xk,
|
| 4104 |
-
mudd_xv=mudd_xv,
|
| 4105 |
**kwargs,
|
| 4106 |
)
|
| 4107 |
|
|
@@ -4109,18 +3221,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 4109 |
layer_analysis.attn_contribution = hidden_states.detach()
|
| 4110 |
|
| 4111 |
gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
|
| 4112 |
-
|
| 4113 |
-
# ββ Attention residual sum ββββββββββββββββββββββββββββββββββββββββ
|
| 4114 |
-
# Standard: GPAS(residual_attn + hidden_states)
|
| 4115 |
-
# LAuReL: GPAS(LAuReL(f_out=hidden_states, x_res=residual_attn))
|
| 4116 |
-
# Both paths feed into GPAS which applies stop-gradient scaling.
|
| 4117 |
-
if self.use_laurel and self.laurel_attn is not None:
|
| 4118 |
-
la_attn_a = layer_analysis.laurel_attn if layer_analysis is not None else None
|
| 4119 |
-
combined_attn = self.laurel_attn(hidden_states, residual_attn, analysis=la_attn_a)
|
| 4120 |
-
else:
|
| 4121 |
-
combined_attn = residual_attn + hidden_states
|
| 4122 |
-
|
| 4123 |
-
h_tilde = self.gpas_attn(combined_attn, analysis=gpas_attn_a)
|
| 4124 |
|
| 4125 |
if layer_analysis is not None:
|
| 4126 |
layer_analysis.h_tilde = h_tilde.detach()
|
|
@@ -4156,11 +3257,8 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 4156 |
if layer_analysis is not None:
|
| 4157 |
layer_analysis.mlp_contribution = delta_m.detach()
|
| 4158 |
|
| 4159 |
-
# ββ MLP residual
|
| 4160 |
-
|
| 4161 |
-
# x_res = residual_mlp. JTok-M delta_r is additive alongside delta_m,
|
| 4162 |
-
# so the nonlinear component is delta_m + delta_r in that path.
|
| 4163 |
-
gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
|
| 4164 |
if self.use_jtokm and z_tilde is not None and B_vals is not None:
|
| 4165 |
orig_shape = h_tilde.shape
|
| 4166 |
h_flat = h_tilde.reshape(-1, self.hidden_size)
|
|
@@ -4171,21 +3269,11 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 4171 |
delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
|
| 4172 |
delta_r = delta_r.reshape(orig_shape)
|
| 4173 |
|
| 4174 |
-
|
| 4175 |
-
|
| 4176 |
-
la_mlp_a = layer_analysis.laurel_mlp if layer_analysis is not None else None
|
| 4177 |
-
combined_mlp = self.laurel_mlp(f_mlp, residual_mlp, analysis=la_mlp_a)
|
| 4178 |
-
else:
|
| 4179 |
-
combined_mlp = residual_mlp + f_mlp
|
| 4180 |
-
hidden_states = self.gpas_mlp(combined_mlp, analysis=gpas_mlp_a)
|
| 4181 |
else:
|
| 4182 |
-
|
| 4183 |
-
|
| 4184 |
-
la_mlp_a = layer_analysis.laurel_mlp if layer_analysis is not None else None
|
| 4185 |
-
combined_mlp = self.laurel_mlp(delta_m, residual_mlp, analysis=la_mlp_a)
|
| 4186 |
-
else:
|
| 4187 |
-
combined_mlp = residual_mlp + delta_m
|
| 4188 |
-
hidden_states = self.gpas_mlp(combined_mlp, analysis=gpas_mlp_a)
|
| 4189 |
|
| 4190 |
if layer_analysis is not None:
|
| 4191 |
layer_analysis.hidden_states_output = hidden_states.detach()
|
|
@@ -4197,9 +3285,6 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 4197 |
outputs += (aux_stats,)
|
| 4198 |
if versatile_aux is not None:
|
| 4199 |
outputs += (versatile_aux,)
|
| 4200 |
-
# StackTrans: always append stack state (None, None when inactive)
|
| 4201 |
-
# so NeoLLMModel.forward can extract them by position -2 and -1.
|
| 4202 |
-
outputs += (stack_state, stack_mask)
|
| 4203 |
return outputs
|
| 4204 |
|
| 4205 |
|
|
@@ -4583,45 +3668,6 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
| 4583 |
module.attn_res_query_attn.data.zero_()
|
| 4584 |
module.attn_res_query_mlp.data.zero_()
|
| 4585 |
|
| 4586 |
-
elif isinstance(module, StackMemory):
|
| 4587 |
-
# Truncated-normal for all Linear weights (matches NeoLLM convention).
|
| 4588 |
-
# Biases zeroed. res_weight starts at 1.0 so the stack readout
|
| 4589 |
-
# contributes equally to the residual from step 0.
|
| 4590 |
-
std = getattr(self.config, "initializer_range", 0.02)
|
| 4591 |
-
cutoff = getattr(self.config, "init_cutoff_factor", 3.0) * std
|
| 4592 |
-
for attr in ("down_proj", "up_proj", "action_head", "gate_proj"):
|
| 4593 |
-
layer = getattr(module, attr, None)
|
| 4594 |
-
if layer is not None and hasattr(layer, "weight"):
|
| 4595 |
-
nn.init.trunc_normal_(
|
| 4596 |
-
layer.weight, mean=0.0, std=std, a=-cutoff, b=cutoff
|
| 4597 |
-
)
|
| 4598 |
-
if layer.bias is not None:
|
| 4599 |
-
nn.init.zeros_(layer.bias)
|
| 4600 |
-
if hasattr(module, "res_weight"):
|
| 4601 |
-
module.res_weight.data.fill_(1.0)
|
| 4602 |
-
|
| 4603 |
-
elif isinstance(module, LAuReLLayer):
|
| 4604 |
-
# RW: raw logits initialised to zero → softmax([0,0]) = [0.5, 0.5].
|
| 4605 |
-
# The model quickly learns the optimal α,β weighting.
|
| 4606 |
-
# LR: lr_down (B, down-projection) β column orthogonal init,
|
| 4607 |
-
# as recommended by the LAuReL paper §3.3 for LLMs.
|
| 4608 |
-
# Column orthogonal preserves the L2 norm of the projected
|
| 4609 |
-
# representation, ensuring stable gradient magnitudes
|
| 4610 |
-
# through the low-rank bottleneck at init.
|
| 4611 |
-
# lr_up (A, up-projection) β zero init β lr_term = AΒ·Bx = 0
|
| 4612 |
-
# at step 0, so the module starts as a standard residual.
|
| 4613 |
-
# Gradient flows back through lr_down immediately via
|
| 4614 |
-
# chain rule; A learns from step 1 onward.
|
| 4615 |
-
if hasattr(module, "rw_logits"):
|
| 4616 |
-
nn.init.zeros_(module.rw_logits)
|
| 4617 |
-
if hasattr(module, "lr_down"):
|
| 4618 |
-
# Column-orthogonal: each column of weight^T is orthonormal.
|
| 4619 |
-
# nn.init.orthogonal_ produces a row-orthogonal matrix (rows
|
| 4620 |
-
# are orthonormal). Transposing gives column-orthogonal.
|
| 4621 |
-
nn.init.orthogonal_(module.lr_down.weight)
|
| 4622 |
-
if hasattr(module, "lr_up"):
|
| 4623 |
-
nn.init.zeros_(module.lr_up.weight)
|
| 4624 |
-
|
| 4625 |
elif isinstance(module, SpellingBeeEmbedding):
|
| 4626 |
# byte_emb initialised identically to token embeddings: std=1/βd.
|
| 4627 |
# Ensures E[βe_byteβΒ²] β 1 at init, matching etok, so the
|
|
@@ -4691,82 +3737,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 4691 |
self.gradient_checkpointing = False
|
| 4692 |
self.first_layer_fan = None
|
| 4693 |
|
| 4694 |
-
# ββ StackTrans state flag βββββββββββββββββββββββββββββββββββββββββ
|
| 4695 |
-
self.use_stacktrans = getattr(config, 'use_stacktrans', False)
|
| 4696 |
-
|
| 4697 |
-
# ββ Residual-replacement mutex ββββββββββββββββββββββββββββββββββββ
|
| 4698 |
-
# AttnRes, MUDD, and DCA all replace the residual aggregation
|
| 4699 |
-
# mechanism β at most one can be active at a time.
|
| 4700 |
-
_use_mudd = getattr(config, 'use_mudd', False)
|
| 4701 |
-
_use_attn_res = getattr(config, 'use_attn_res', False)
|
| 4702 |
-
_use_dca = getattr(config, 'use_dca', False)
|
| 4703 |
-
_active_count = sum([_use_mudd, _use_attn_res, _use_dca])
|
| 4704 |
-
if _active_count > 1:
|
| 4705 |
-
active = [n for n, f in [('use_mudd', _use_mudd),
|
| 4706 |
-
('use_attn_res', _use_attn_res),
|
| 4707 |
-
('use_dca', _use_dca)] if f]
|
| 4708 |
-
raise ValueError(
|
| 4709 |
-
f"use_mudd, use_attn_res, and use_dca are mutually exclusive β "
|
| 4710 |
-
f"got {active} simultaneously active. Set exactly one to True."
|
| 4711 |
-
)
|
| 4712 |
-
if _use_mudd:
|
| 4713 |
-
_mudd_dense_type = getattr(config, 'mudd_dense_type', 'qkvr')
|
| 4714 |
-
_mudd_dynamic = getattr(config, 'mudd_dynamic_dense', True)
|
| 4715 |
-
_mudd_round64 = getattr(config, 'mudd_round64', False)
|
| 4716 |
-
_mudd_expand_last = getattr(config, 'mudd_expand_last', False)
|
| 4717 |
-
_C = 4 if _mudd_dense_type == 'qkvr' else 1
|
| 4718 |
-
|
| 4719 |
-
# Static bias: one [C, lidx+2] parameter per layer.
|
| 4720 |
-
# Initialized with 1 at index [c, lidx+1] (identity on Xi) so that
|
| 4721 |
-
# at init (W2=0) each DA output = Xi β reducing to standard Transformer.
|
| 4722 |
-
_static_list = []
|
| 4723 |
-
for lidx in range(config.num_hidden_layers):
|
| 4724 |
-
# Last layer always uses C=1: its DA output is the final
|
| 4725 |
-
# model representation fed to the norm and lm_head, collapsing
|
| 4726 |
-
# all history into a single stream (paper code, both files).
|
| 4727 |
-
_c = 1 if lidx == config.num_hidden_layers - 1 else _C
|
| 4728 |
-
a = torch.zeros(_c, lidx + 2)
|
| 4729 |
-
a[:, lidx + 1] = 1.0 # last entry = current layer = identity
|
| 4730 |
-
_static_list.append(nn.Parameter(a))
|
| 4731 |
-
self.mudd_static = nn.ParameterList(_static_list)
|
| 4732 |
-
|
| 4733 |
-
# Dynamic DA modules (one per layer)
|
| 4734 |
-
if _mudd_dynamic:
|
| 4735 |
-
self.mudd_dynamic = nn.ModuleList([
|
| 4736 |
-
NeoLLMMUDDModule(
|
| 4737 |
-
hidden_size = config.hidden_size,
|
| 4738 |
-
lidx = lidx,
|
| 4739 |
-
# Last layer: C=1 β collapses to single final repr
|
| 4740 |
-
num_ways = 1 if lidx == config.num_hidden_layers - 1 else _C,
|
| 4741 |
-
is_last = (lidx == config.num_hidden_layers - 1),
|
| 4742 |
-
expand_last = _mudd_expand_last,
|
| 4743 |
-
round64 = _mudd_round64,
|
| 4744 |
-
)
|
| 4745 |
-
for lidx in range(config.num_hidden_layers)
|
| 4746 |
-
])
|
| 4747 |
-
else:
|
| 4748 |
-
self.mudd_dynamic = None
|
| 4749 |
-
else:
|
| 4750 |
-
self.mudd_static = None
|
| 4751 |
-
self.mudd_dynamic = None
|
| 4752 |
-
|
| 4753 |
-
# ββ DCA final GRN (Heddes et al., 2025) βββββββββββββββββββββββββββ
|
| 4754 |
-
# Applied once after all decoder layers to aggregate the full depth
|
| 4755 |
-
# stack into the final hidden representation before the output norm.
|
| 4756 |
-
# num_stack_layers = min(2*k, L+1) β same cap as per-layer GRNs.
|
| 4757 |
-
# num_outputs=1 collapses to a single [B, S, D] tensor.
|
| 4758 |
-
if _use_dca and getattr(config, 'dca_use_final_grn', True):
|
| 4759 |
-
_dca_k = getattr(config, 'dca_k', 2)
|
| 4760 |
-
_dca_eps = getattr(config, 'dca_grn_eps', config.rms_norm_eps)
|
| 4761 |
-
self.dca_final_grn = NeoLLMGRN(
|
| 4762 |
-
hidden_size = config.hidden_size,
|
| 4763 |
-
num_stack_layers = min(2 * _dca_k, config.num_hidden_layers + 1),
|
| 4764 |
-
num_outputs = 1,
|
| 4765 |
-
eps = _dca_eps,
|
| 4766 |
-
)
|
| 4767 |
-
else:
|
| 4768 |
-
self.dca_final_grn = None
|
| 4769 |
-
|
| 4770 |
self.post_init()
|
| 4771 |
|
| 4772 |
def get_input_embeddings(self):
|
|
@@ -4794,9 +3764,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 4794 |
getattr(cfg, "use_repo", False)
|
| 4795 |
and layer_idx >= getattr(cfg, "repo_start_layer", cfg.num_hidden_layers // 3)
|
| 4796 |
)
|
| 4797 |
-
_versatile
|
| 4798 |
-
_use_stacktrans = getattr(cfg, "use_stacktrans", False)
|
| 4799 |
-
_use_laurel = getattr(cfg, "use_laurel", False)
|
| 4800 |
return LayerAnalysis(
|
| 4801 |
seednorm_pre_attn = SeeDNormAnalysis(),
|
| 4802 |
seednorm_post_attn = SeeDNormAnalysis(),
|
|
@@ -4810,14 +3778,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 4810 |
polynorm = PolyNormAnalysis() if not _versatile else None,
|
| 4811 |
versatile = VersatileFFNAnalysis() if _versatile else None,
|
| 4812 |
),
|
| 4813 |
-
gpas_attn
|
| 4814 |
-
gpas_mlp
|
| 4815 |
-
jtokm
|
| 4816 |
-
attn_res
|
| 4817 |
-
dca = DCAAnalysis() if getattr(cfg, "use_dca", False) else None,
|
| 4818 |
-
stack = StackMemoryAnalysis() if _use_stacktrans else None,
|
| 4819 |
-
laurel_attn = LAuReLAnalysis() if _use_laurel else None,
|
| 4820 |
-
laurel_mlp = LAuReLAnalysis() if _use_laurel else None,
|
| 4821 |
)
|
| 4822 |
|
| 4823 |
def forward(
|
|
@@ -4933,57 +3897,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 4933 |
if use_attn_res:
|
| 4934 |
attn_res_sources = [hidden_states] # b_0 = token embedding
|
| 4935 |
attn_res_partial = hidden_states # initial partial sum
|
| 4936 |
-
|
| 4937 |
-
|
| 4938 |
-
|
| 4939 |
-
|
| 4940 |
-
|
| 4941 |
-
|
| 4942 |
-
|
| 4943 |
-
# mudd_streams is None for layer 0 (standard residual path there)
|
| 4944 |
-
# and a C-tuple of [B,T,D] tensors for layers 1…L.
|
| 4945 |
-
use_mudd = getattr(self.config, 'use_mudd', False)
|
| 4946 |
-
mudd_hiddens = None
|
| 4947 |
-
mudd_streams = None
|
| 4948 |
-
if use_mudd:
|
| 4949 |
-
mudd_hiddens = [hidden_states] # b_0 = token embedding
|
| 4950 |
-
|
| 4951 |
-
# ── DCA state ──────────────────────────────────────────────────────
|
| 4952 |
-
# all_tokens[0] = token embedding; grows by one per decoder layer.
|
| 4953 |
-
# Before each layer, the stack is built and k-DCA selection applied,
|
| 4954 |
-
# capping memory at 2*dca_k stored tensors regardless of depth.
|
| 4955 |
-
# dca_stack is always non-None (even layer 0 gets [embedding]).
|
| 4956 |
-
use_dca = getattr(self.config, 'use_dca', False)
|
| 4957 |
-
_dca_k = getattr(self.config, 'dca_k', 2)
|
| 4958 |
-
dca_all_tokens = None
|
| 4959 |
-
dca_stack = None
|
| 4960 |
-
if use_dca:
|
| 4961 |
-
dca_all_tokens = [hidden_states] # [embedding]
|
| 4962 |
-
|
| 4963 |
-
# ββ StackTrans state ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 4964 |
-
# stack_state / stack_mask start as None for the first layer;
|
| 4965 |
-
# StackMemory initialises them to zeros internally on first call.
|
| 4966 |
-
# After each layer, the returned (new_stack, new_mask) are passed
|
| 4967 |
-
# to the next layer as its initial stack — this is "vertical" state
|
| 4968 |
-
# propagation: information flows depth-wise through the stack.
|
| 4969 |
-
#
|
| 4970 |
-
# Temporal accumulation across generation steps is handled by the
|
| 4971 |
-
# StackMemory internal k_cache / action_cache mechanism:
|
| 4972 |
-
# - enable_cache is set True when use_cache=True (inference)
|
| 4973 |
-
# - reset_cache() is called when past_key_values is None
|
| 4974 |
-
# (new sequence, not a continuation step)
|
| 4975 |
-
# This matches the OLMo reference implementation exactly.
|
| 4976 |
-
use_stacktrans = self.use_stacktrans
|
| 4977 |
-
stack_state = None
|
| 4978 |
-
stack_mask = None
|
| 4979 |
-
if use_stacktrans:
|
| 4980 |
-
use_cache_flag = kwargs.get("use_cache", False)
|
| 4981 |
-
past_kv_flag = kwargs.get("past_key_values", None)
|
| 4982 |
-
for layer in self.layers:
|
| 4983 |
-
if layer.stack_memory is not None:
|
| 4984 |
-
layer.stack_memory.enable_cache = bool(use_cache_flag)
|
| 4985 |
-
if past_kv_flag is None:
|
| 4986 |
-
layer.stack_memory.reset_cache()
|
| 4987 |
|
| 4988 |
# Pre-allocate per-layer analysis list when analysis is active
|
| 4989 |
if analysis_state is not None:
|
|
@@ -4993,13 +3913,17 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 4993 |
if output_hidden_states:
|
| 4994 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 4995 |
|
| 4996 |
-
# ββ
|
| 4997 |
-
#
|
| 4998 |
-
#
|
| 4999 |
-
|
| 5000 |
-
|
| 5001 |
-
|
| 5002 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5003 |
|
| 5004 |
# Build per-layer analysis container (only in eval + analysis mode)
|
| 5005 |
layer_analysis = None
|
|
@@ -5017,10 +3941,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 5017 |
B_vals=B_vals,
|
| 5018 |
attn_res_sources=attn_res_sources,
|
| 5019 |
attn_res_partial=attn_res_partial if use_attn_res else None,
|
| 5020 |
-
mudd_streams=mudd_streams,
|
| 5021 |
-
dca_stack=dca_stack,
|
| 5022 |
-
stack_state=stack_state,
|
| 5023 |
-
stack_mask=stack_mask,
|
| 5024 |
layer_analysis=layer_analysis,
|
| 5025 |
output_attentions=output_attentions,
|
| 5026 |
repo_rope_args=repo_rope_args,
|
|
@@ -5028,76 +3948,23 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 5028 |
)
|
| 5029 |
hidden_states = layer_outputs[0]
|
| 5030 |
|
| 5031 |
-
# ββ StackTrans: extract updated stack state for next layer βββββ
|
| 5032 |
-
# layer_outputs always ends with (stack_state, stack_mask) —
|
| 5033 |
-
# both are None when use_stacktrans=False (zero cost).
|
| 5034 |
-
stack_state = layer_outputs[-2]
|
| 5035 |
-
stack_mask = layer_outputs[-1]
|
| 5036 |
-
|
| 5037 |
# Update AttnRes partial sum — the new partial is the layer output
|
| 5038 |
if use_attn_res:
|
| 5039 |
attn_res_partial = hidden_states
|
| 5040 |
|
| 5041 |
-
# Append layer output to DCA history for next layer's stack
|
| 5042 |
-
if use_dca:
|
| 5043 |
-
dca_all_tokens.append(hidden_states)
|
| 5044 |
-
|
| 5045 |
-
# ββ MUDD: append current output and compute DA for next layer ββ
|
| 5046 |
-
# mudd_hiddens grows by 1 each iteration; at layer i it has i+2
|
| 5047 |
-
# entries (embedding + i outputs). The DA for layer i+1 takes this
|
| 5048 |
-
# full history and produces C streams via dynamic + static weights.
|
| 5049 |
-
# mudd_streams is passed to layer i+1 as its input streams.
|
| 5050 |
-
if use_mudd:
|
| 5051 |
-
mudd_hiddens.append(hidden_states)
|
| 5052 |
-
# Compute DA module output using the just-appended history
|
| 5053 |
-
# (mudd_hiddens now has layer_idx+2 entries)
|
| 5054 |
-
is_last_layer = (layer_idx == self.config.num_hidden_layers - 1)
|
| 5055 |
-
mudd_da_module = self.mudd_dynamic[layer_idx] if self.mudd_dynamic is not None else None
|
| 5056 |
-
if mudd_da_module is not None:
|
| 5057 |
-
raw_streams = mudd_da_module(
|
| 5058 |
-
hidden_states,
|
| 5059 |
-
mudd_hiddens,
|
| 5060 |
-
self.mudd_static[layer_idx],
|
| 5061 |
-
)
|
| 5062 |
-
else:
|
| 5063 |
-
# Static-only: apply weighted sum with learnable bias
|
| 5064 |
-
# stack history [L, B, T, D], weight by mudd_static
|
| 5065 |
-
stacked = torch.stack(mudd_hiddens, dim=0) # [L, B, T, D]
|
| 5066 |
-
a = self.mudd_static[layer_idx].to(hidden_states.dtype) # [C, L]
|
| 5067 |
-
raw_streams = tuple(
|
| 5068 |
-
torch.einsum('cl,lbtd->btd', a[c:c+1], stacked).squeeze(0)
|
| 5069 |
-
for c in range(a.shape[0])
|
| 5070 |
-
)
|
| 5071 |
-
if is_last_layer:
|
| 5072 |
-
# Last layer DA always produces C=1 β single final repr.
|
| 5073 |
-
# This is the MUDD-aggregated combination of all layer
|
| 5074 |
-
# histories weighted by the last layer's output as query.
|
| 5075 |
-
# Replace hidden_states so the final norm and lm_head
|
| 5076 |
-
# receive this aggregated representation, not the raw
|
| 5077 |
-
# last-layer output (paper forward loop: x = x[0] after loop).
|
| 5078 |
-
hidden_states = raw_streams[0]
|
| 5079 |
-
mudd_streams = None # no next layer
|
| 5080 |
-
elif len(raw_streams) == 1:
|
| 5081 |
-
# dense_type='l': broadcast to 4-tuple
|
| 5082 |
-
mudd_streams = (raw_streams[0],) * 4
|
| 5083 |
-
else:
|
| 5084 |
-
# 'qkvr': 4 streams → (xq, xk, xv, xr)
|
| 5085 |
-
mudd_streams = raw_streams
|
| 5086 |
-
|
| 5087 |
if output_attentions:
|
| 5088 |
all_attentions = all_attentions + (layer_outputs[1],)
|
| 5089 |
|
| 5090 |
-
# Collect JTok-M
|
| 5091 |
-
|
| 5092 |
-
|
| 5093 |
-
inner_outputs = layer_outputs[1:-2]
|
| 5094 |
-
|
| 5095 |
-
if self.config.use_jtokm and len(inner_outputs) > (1 if output_attentions else 0):
|
| 5096 |
-
all_aux_stats.append(inner_outputs[-1])
|
| 5097 |
|
|
|
|
|
|
|
| 5098 |
if getattr(self.config, "use_versatile_ffn", False):
|
| 5099 |
-
for item in
|
| 5100 |
if isinstance(item, tuple) and len(item) == 3:
|
|
|
|
| 5101 |
all_aux_stats.append(("versatile", item))
|
| 5102 |
break
|
| 5103 |
|
|
@@ -5105,16 +3972,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 5105 |
and hasattr(decoder_layer, "current_layer_fan")):
|
| 5106 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 5107 |
|
| 5108 |
-
# ββ DCA final GRN ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 5109 |
-
# Aggregates the full depth history (k-selected) into the final
|
| 5110 |
-
# hidden representation, matching the DCAGPT forward loop which
|
| 5111 |
-
# applies final_grn(stack(all_tokens)) before norm → lm_head.
|
| 5112 |
-
if use_dca and self.dca_final_grn is not None:
|
| 5113 |
-
final_stack = dca_select_layers(
|
| 5114 |
-
torch.stack(dca_all_tokens, dim=0), k=_dca_k
|
| 5115 |
-
)
|
| 5116 |
-
hidden_states = self.dca_final_grn(final_stack)
|
| 5117 |
-
|
| 5118 |
hidden_states = self.norm(hidden_states)
|
| 5119 |
|
| 5120 |
if output_hidden_states:
|
|
@@ -5127,9 +3984,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 5127 |
analysis_state.attn_res_sources_final = (
|
| 5128 |
attn_res_sources if use_attn_res else None
|
| 5129 |
)
|
| 5130 |
-
analysis_state.dca_all_tokens_final = (
|
| 5131 |
-
dca_all_tokens if use_dca else None
|
| 5132 |
-
)
|
| 5133 |
|
| 5134 |
if not return_dict:
|
| 5135 |
return tuple(
|
|
@@ -5270,7 +4124,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 5270 |
layers = None, # filled by NeoLLMModel.forward
|
| 5271 |
jtokm_aux_stats = [] if cfg.use_jtokm else None,
|
| 5272 |
attn_res_sources_final = [] if getattr(cfg, "use_attn_res", False) else None,
|
| 5273 |
-
dca_all_tokens_final = [] if getattr(cfg, "use_dca", False) else None,
|
| 5274 |
)
|
| 5275 |
|
| 5276 |
# ββ Standard model API ββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -5408,11 +4261,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 5408 |
# ==================== AUTOMODEL REGISTRATION ====================
|
| 5409 |
|
| 5410 |
__all__ = [
|
| 5411 |
-
"StackMemory",
|
| 5412 |
-
"LAuReLLayer",
|
| 5413 |
-
"NeoLLMMUDDModule",
|
| 5414 |
-
"NeoLLMGRN",
|
| 5415 |
-
"dca_select_layers",
|
| 5416 |
"NeoLLMForCausalLM",
|
| 5417 |
"NeoLLMModel",
|
| 5418 |
"NeoLLMPreTrainedModel",
|
|
@@ -5430,7 +4278,7 @@ __all__ = [
|
|
| 5430 |
"REPOModule",
|
| 5431 |
"VersatileFFN",
|
| 5432 |
"compute_versatile_aux_loss",
|
| 5433 |
-
# Analysis dataclasses
|
| 5434 |
"AnalysisState",
|
| 5435 |
"LayerAnalysis",
|
| 5436 |
"AttentionAnalysis",
|
|
@@ -5444,9 +4292,6 @@ __all__ = [
|
|
| 5444 |
"VersatileFFNAnalysis",
|
| 5445 |
"JTokMAnalysis",
|
| 5446 |
"AttnResAnalysis",
|
| 5447 |
-
"DCAAnalysis",
|
| 5448 |
-
"StackMemoryAnalysis",
|
| 5449 |
-
"LAuReLAnalysis",
|
| 5450 |
"GeneratorAnalysis",
|
| 5451 |
]
|
| 5452 |
|
|
|
|
| 80 |
from configuration_neollm import NeoLLMConfig
|
| 81 |
|
| 82 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 83 |
+
torch._dynamo.config.capture_scalar_outputs = True
|
| 84 |
logger = logging.get_logger(__name__)
|
| 85 |
|
| 86 |
|
|
|
|
| 339 |
lns_scale: Optional[float] = None # 1/√(2ℓ) scaling factor
|
| 340 |
|
| 341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
@dataclass
|
| 343 |
class AttnResAnalysis:
|
| 344 |
"""
|
|
|
|
| 350 |
sources_count: Optional[int] = None # number of sources including partial
|
| 351 |
|
| 352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
@dataclass
|
| 354 |
class LayerAnalysis:
|
| 355 |
"""
|
|
|
|
| 378 |
gpas_mlp: Optional[GPASAnalysis] = None # GPAS after MLP residual
|
| 379 |
|
| 380 |
# Optional components (None when inactive)
|
| 381 |
+
jtokm: Optional[JTokMAnalysis] = None # if use_jtokm
|
| 382 |
+
attn_res: Optional[AttnResAnalysis] = None # if use_attn_res
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
|
| 384 |
|
| 385 |
@dataclass
|
|
|
|
| 444 |
layers: Optional[List[LayerAnalysis]] = None
|
| 445 |
jtokm_aux_stats: Optional[list] = None
|
| 446 |
attn_res_sources_final: Optional[list] = None
|
|
|
|
| 447 |
logits: Optional[torch.Tensor] = None
|
| 448 |
|
| 449 |
class ScalarMultiplier(nn.Module):
|
|
|
|
| 2363 |
first_layer_fan: Optional[torch.Tensor] = None,
|
| 2364 |
attn_analysis: Optional[AttentionAnalysis] = None,
|
| 2365 |
repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
|
|
|
|
|
|
|
| 2366 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 2367 |
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 2368 |
input_shape = hidden_states.shape[:-1]
|
|
|
|
| 2373 |
h_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan
|
| 2374 |
current_layer_fan = h_fan.clone()
|
| 2375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2376 |
query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim)
|
| 2377 |
kv_shape = (*input_shape, self.num_mea_component_heads, self.head_dim)
|
| 2378 |
|
|
|
|
| 2387 |
attn_analysis.gate_raw = gate.detach()
|
| 2388 |
|
| 2389 |
q = self.q_norm(q_raw.view(query_shape)).transpose(1, 2)
|
| 2390 |
+
k = self.k_norm(self.k_proj(h_fan).view(kv_shape)).transpose(1, 2)
|
| 2391 |
+
v = self.v_proj(h_fan).view(kv_shape).transpose(1, 2)
|
| 2392 |
|
| 2393 |
if attn_analysis is not None:
|
| 2394 |
attn_analysis.q_post_norm = q.detach()
|
|
|
|
| 3065 |
return result
|
| 3066 |
|
| 3067 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3068 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 3069 |
"""
|
| 3070 |
Decoder layer with standard residual connections, optional JTok-M injection.
|
|
|
|
| 3120 |
self.attn_res_query_attn = nn.Parameter(torch.zeros(config.hidden_size))
|
| 3121 |
self.attn_res_query_mlp = nn.Parameter(torch.zeros(config.hidden_size))
|
| 3122 |
self.attn_res_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3123 |
else:
|
| 3124 |
self.attn_res_query_attn = None
|
| 3125 |
self.attn_res_query_mlp = None
|
| 3126 |
self.attn_res_norm = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3127 |
|
| 3128 |
def _attn_res(
|
| 3129 |
self,
|
|
|
|
| 3173 |
B_vals: Optional[torch.Tensor] = None,
|
| 3174 |
attn_res_sources: Optional[list] = None,
|
| 3175 |
attn_res_partial: Optional[torch.Tensor] = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3176 |
layer_analysis: Optional[LayerAnalysis] = None,
|
| 3177 |
output_attentions: Optional[bool] = False,
|
| 3178 |
repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
|
|
|
|
| 3182 |
if layer_analysis is not None:
|
| 3183 |
layer_analysis.hidden_states_input = hidden_states.detach()
|
| 3184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3185 |
# ── Attention Residuals: compute pre-attention input ──────────────
|
| 3186 |
# When active, the input to the attention sublayer is no longer the
|
| 3187 |
# raw hidden_states (accumulated residual) but a softmax-weighted
|
|
|
|
| 3195 |
attn_res_sources, attn_res_partial, self.attn_res_query_attn,
|
| 3196 |
ar_analysis, "attn",
|
| 3197 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3198 |
residual_attn = attn_res_partial
|
| 3199 |
else:
|
| 3200 |
+
h_attn = hidden_states
|
| 3201 |
+
residual_attn = hidden_states
|
|
|
|
| 3202 |
|
| 3203 |
# ── Attention block ────────────────────────────────────────────────
|
| 3204 |
sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
|
|
|
|
| 3214 |
first_layer_fan=first_layer_fan,
|
| 3215 |
attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
|
| 3216 |
repo_rope_args=repo_rope_args,
|
|
|
|
|
|
|
| 3217 |
**kwargs,
|
| 3218 |
)
|
| 3219 |
|
|
|
|
| 3221 |
layer_analysis.attn_contribution = hidden_states.detach()
|
| 3222 |
|
| 3223 |
gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
|
| 3224 |
+
h_tilde = self.gpas_attn(residual_attn + hidden_states, analysis=gpas_attn_a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3225 |
|
| 3226 |
if layer_analysis is not None:
|
| 3227 |
layer_analysis.h_tilde = h_tilde.detach()
|
|
|
|
| 3257 |
if layer_analysis is not None:
|
| 3258 |
layer_analysis.mlp_contribution = delta_m.detach()
|
| 3259 |
|
| 3260 |
+
# ── JTok-M injection (additive alongside MLP residual) ────────────
|
| 3261 |
+
aux_stats = None
|
|
|
|
|
|
|
|
|
|
| 3262 |
if self.use_jtokm and z_tilde is not None and B_vals is not None:
|
| 3263 |
orig_shape = h_tilde.shape
|
| 3264 |
h_flat = h_tilde.reshape(-1, self.hidden_size)
|
|
|
|
| 3269 |
delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
|
| 3270 |
delta_r = delta_r.reshape(orig_shape)
|
| 3271 |
|
| 3272 |
+
gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
|
| 3273 |
+
hidden_states = self.gpas_mlp(residual_mlp + delta_m + delta_r, analysis=gpas_mlp_a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3274 |
else:
|
| 3275 |
+
gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
|
| 3276 |
+
hidden_states = self.gpas_mlp(residual_mlp + delta_m, analysis=gpas_mlp_a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3277 |
|
| 3278 |
if layer_analysis is not None:
|
| 3279 |
layer_analysis.hidden_states_output = hidden_states.detach()
|
|
|
|
| 3285 |
outputs += (aux_stats,)
|
| 3286 |
if versatile_aux is not None:
|
| 3287 |
outputs += (versatile_aux,)
|
|
|
|
|
|
|
|
|
|
| 3288 |
return outputs
|
| 3289 |
|
| 3290 |
|
|
|
|
| 3668 |
module.attn_res_query_attn.data.zero_()
|
| 3669 |
module.attn_res_query_mlp.data.zero_()
|
| 3670 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3671 |
elif isinstance(module, SpellingBeeEmbedding):
|
| 3672 |
# byte_emb initialised identically to token embeddings: std=1/√d.
|
| 3673 |
# Ensures E[‖e_byte‖²] ≈ 1 at init, matching etok, so the
|
|
|
|
| 3737 |
self.gradient_checkpointing = False
|
| 3738 |
self.first_layer_fan = None
|
| 3739 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3740 |
self.post_init()
|
| 3741 |
|
| 3742 |
def get_input_embeddings(self):
|
|
|
|
| 3764 |
getattr(cfg, "use_repo", False)
|
| 3765 |
and layer_idx >= getattr(cfg, "repo_start_layer", cfg.num_hidden_layers // 3)
|
| 3766 |
)
|
| 3767 |
+
_versatile = getattr(cfg, "use_versatile_ffn", False)
|
|
|
|
|
|
|
| 3768 |
return LayerAnalysis(
|
| 3769 |
seednorm_pre_attn = SeeDNormAnalysis(),
|
| 3770 |
seednorm_post_attn = SeeDNormAnalysis(),
|
|
|
|
| 3778 |
polynorm = PolyNormAnalysis() if not _versatile else None,
|
| 3779 |
versatile = VersatileFFNAnalysis() if _versatile else None,
|
| 3780 |
),
|
| 3781 |
+
gpas_attn = GPASAnalysis(),
|
| 3782 |
+
gpas_mlp = GPASAnalysis(),
|
| 3783 |
+
jtokm = JTokMAnalysis() if cfg.use_jtokm else None,
|
| 3784 |
+
attn_res = AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3785 |
)
|
| 3786 |
|
| 3787 |
def forward(
|
|
|
|
| 3897 |
if use_attn_res:
|
| 3898 |
attn_res_sources = [hidden_states] # b_0 = token embedding
|
| 3899 |
attn_res_partial = hidden_states # initial partial sum
|
| 3900 |
+
|
| 3901 |
+
num_blocks = getattr(self.config, 'attn_res_num_blocks', 0)
|
| 3902 |
+
block_size = (
|
| 3903 |
+
max(self.config.num_hidden_layers // num_blocks, 1)
|
| 3904 |
+
if num_blocks > 0
|
| 3905 |
+
else 1 # Full AttnRes: every layer is its own "block"
|
| 3906 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3907 |
|
| 3908 |
# Pre-allocate per-layer analysis list when analysis is active
|
| 3909 |
if analysis_state is not None:
|
|
|
|
| 3913 |
if output_hidden_states:
|
| 3914 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 3915 |
|
| 3916 |
+
# ── Block AttnRes: boundary handling ──────────────────────────
|
| 3917 |
+
# At each block boundary (excluding layer 0): append the current
|
| 3918 |
+
# partial sum to sources as a completed block summary, then reset
|
| 3919 |
+
# partial to the current hidden_states so the new block starts from it — matching
|
| 3920 |
+
# the paper's pseudocode exactly.
|
| 3921 |
+
# For Full AttnRes (block_size=1): every layer is a boundary, so
|
| 3922 |
+
# partial is appended and reset after every layer. The partial is
|
| 3923 |
+
# re-seeded from the previous hidden_states below.
|
| 3924 |
+
if use_attn_res and layer_idx > 0 and layer_idx % block_size == 0:
|
| 3925 |
+
attn_res_sources = attn_res_sources + [attn_res_partial]
|
| 3926 |
+
attn_res_partial = hidden_states # start new block from current output
|
| 3927 |
|
| 3928 |
# Build per-layer analysis container (only in eval + analysis mode)
|
| 3929 |
layer_analysis = None
|
|
|
|
| 3941 |
B_vals=B_vals,
|
| 3942 |
attn_res_sources=attn_res_sources,
|
| 3943 |
attn_res_partial=attn_res_partial if use_attn_res else None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3944 |
layer_analysis=layer_analysis,
|
| 3945 |
output_attentions=output_attentions,
|
| 3946 |
repo_rope_args=repo_rope_args,
|
|
|
|
| 3948 |
)
|
| 3949 |
hidden_states = layer_outputs[0]
|
| 3950 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3951 |
# Update AttnRes partial sum — the new partial is the layer output
|
| 3952 |
if use_attn_res:
|
| 3953 |
attn_res_partial = hidden_states
|
| 3954 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3955 |
if output_attentions:
|
| 3956 |
all_attentions = all_attentions + (layer_outputs[1],)
|
| 3957 |
|
| 3958 |
+
# Collect JTok-M aux stats (last element if present)
|
| 3959 |
+
if self.config.use_jtokm and len(layer_outputs) > (2 if output_attentions else 1):
|
| 3960 |
+
all_aux_stats.append(layer_outputs[-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3961 |
|
| 3962 |
+
# Collect VersatileFFN aux stats (second-to-last if jtokm also present,
|
| 3963 |
+
# or last if jtokm is absent). Only non-None during training.
|
| 3964 |
if getattr(self.config, "use_versatile_ffn", False):
|
| 3965 |
+
for item in layer_outputs[1:]:
|
| 3966 |
if isinstance(item, tuple) and len(item) == 3:
|
| 3967 |
+
# (p_sum, f_sum, N_tokens) signature
|
| 3968 |
all_aux_stats.append(("versatile", item))
|
| 3969 |
break
|
| 3970 |
|
|
|
|
| 3972 |
and hasattr(decoder_layer, "current_layer_fan")):
|
| 3973 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 3974 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3975 |
hidden_states = self.norm(hidden_states)
|
| 3976 |
|
| 3977 |
if output_hidden_states:
|
|
|
|
| 3984 |
analysis_state.attn_res_sources_final = (
|
| 3985 |
attn_res_sources if use_attn_res else None
|
| 3986 |
)
|
|
|
|
|
|
|
|
|
|
| 3987 |
|
| 3988 |
if not return_dict:
|
| 3989 |
return tuple(
|
|
|
|
| 4124 |
layers = None, # filled by NeoLLMModel.forward
|
| 4125 |
jtokm_aux_stats = [] if cfg.use_jtokm else None,
|
| 4126 |
attn_res_sources_final = [] if getattr(cfg, "use_attn_res", False) else None,
|
|
|
|
| 4127 |
)
|
| 4128 |
|
| 4129 |
# ── Standard model API ────────────────────────────────────────────────
|
|
|
|
| 4261 |
# ==================== AUTOMODEL REGISTRATION ====================
|
| 4262 |
|
| 4263 |
__all__ = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4264 |
"NeoLLMForCausalLM",
|
| 4265 |
"NeoLLMModel",
|
| 4266 |
"NeoLLMPreTrainedModel",
|
|
|
|
| 4278 |
"REPOModule",
|
| 4279 |
"VersatileFFN",
|
| 4280 |
"compute_versatile_aux_loss",
|
| 4281 |
+
# Analysis dataclasses — exported so external tools can type-hint against them
|
| 4282 |
"AnalysisState",
|
| 4283 |
"LayerAnalysis",
|
| 4284 |
"AttentionAnalysis",
|
|
|
|
| 4292 |
"VersatileFFNAnalysis",
|
| 4293 |
"JTokMAnalysis",
|
| 4294 |
"AttnResAnalysis",
|
|
|
|
|
|
|
|
|
|
| 4295 |
"GeneratorAnalysis",
|
| 4296 |
]
|
| 4297 |
|