KitsuVp committed on
Commit
cdb1739
·
verified ·
1 Parent(s): b964bba

Upload modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +49 -485
modeling_neollm.py CHANGED
@@ -81,30 +81,6 @@ from configuration_neollm import NeoLLMConfig
81
 
82
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
83
 
84
- # ── Optional fast-path dependencies (GatedDeltaNet linear attention) ─────────
85
- try:
86
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update as _causal_conv1d_update
87
- _causal_conv1d_available = True
88
- except ImportError:
89
- causal_conv1d_fn = None
90
- _causal_conv1d_update = None
91
- _causal_conv1d_available = False
92
-
93
- try:
94
- from fla.modules import FusedRMSNormGated
95
- from fla.ops.gated_delta_rule import (
96
- chunk_gated_delta_rule,
97
- fused_recurrent_gated_delta_rule,
98
- )
99
- _fla_available = True
100
- except ImportError:
101
- FusedRMSNormGated = None
102
- chunk_gated_delta_rule = None
103
- fused_recurrent_gated_delta_rule = None
104
- _fla_available = False
105
-
106
- is_linear_attn_fast_path = _causal_conv1d_available and _fla_available
107
-
108
  logger = logging.get_logger(__name__)
109
 
110
 
@@ -428,6 +404,40 @@ class StackMemoryAnalysis:
428
  residual_scale: Optional[float] = None # res_weight scalar
429
 
430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  @dataclass
432
  class LayerAnalysis:
433
  """
@@ -460,8 +470,8 @@ class LayerAnalysis:
460
  attn_res: Optional[AttnResAnalysis] = None # if use_attn_res
461
  dca: Optional[DCAAnalysis] = None # if use_dca
462
  stack: Optional[StackMemoryAnalysis] = None # if use_stacktrans
463
- laurel_attn: Optional["LAuReLAnalysis"] = None # if use_laurel (attention residual)
464
- laurel_mlp: Optional["LAuReLAnalysis"] = None # if use_laurel (MLP residual)
465
 
466
 
467
  @dataclass
@@ -3685,39 +3695,6 @@ class StackMemory(nn.Module):
3685
 
3686
 
3687
  @dataclass
3688
- class LAuReLAnalysis:
3689
- """
3690
- Internals of one LAuReL residual connection forward pass.
3691
- Only populated when use_laurel=True AND model is in eval + analysis mode.
3692
- Instantiated twice per layer: once for the attention residual, once for MLP.
3693
-
3694
- Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
3695
- *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.
3696
-
3697
- Math (combined RW+LR, both sub-variants active):
3698
-
3699
- x_{i+1} = α · f(x_i) + β · (A·(B·x_i) + x_i)
3700
-
3701
- where [α, β] = softmax([a, b]), a,b ∈ ℝ learnable (RW component),
3702
- B ∈ ℝ^{r×D} column-orthogonal init, A ∈ ℝ^{D×r} zero init (LR component).
3703
- At step 0: A=0 → lr_term=0, so x_{i+1} = 0.5·f(x) + 0.5·x_i (RW only)
3704
- or x_{i+1} = f(x_i) + x_i (LR only, standard residual).
3705
-
3706
- Fields:
3707
- alpha_rw: softmax(a) — weight on f(x_i). [scalar float]
3708
- None when use_laurel_rw=False.
3709
- beta_rw: softmax(b) — weight on g(x_i). [scalar float]
3710
- None when use_laurel_rw=False.
3711
- lr_term: A·(B·x_res) — the low-rank residual augmentation.
3712
- Shape [B, S, D]. Zero at init. None when use_laurel_lr=False.
3713
- output: Final combined tensor before GPAS. Shape [B, S, D].
3714
- """
3715
- alpha_rw: Optional[float] = None # softmax weight on f(x)
3716
- beta_rw: Optional[float] = None # softmax weight on g(x)
3717
- lr_term: Optional[torch.Tensor] = None # A(Bx) low-rank augmentation [B,S,D]
3718
- output: Optional[torch.Tensor] = None # combined pre-GPAS [B,S,D]
3719
-
3720
-
3721
  class LAuReLLayer(nn.Module):
3722
  """
3723
  LAuReL: Learned Augmented Residual Layer.
@@ -3836,358 +3813,6 @@ class LAuReLLayer(nn.Module):
3836
  return out
3837
 
3838
 
3839
- # ==================== GATED DELTA NET (LINEAR ATTENTION) ====================
3840
- # Active when use_linear_attention=True. Replaces NeoLLMAttention every
3841
- # `linear_attention_every_n` layers (pattern 0-indexed: layers 2, 5, 8 …).
3842
- #
3843
- # References:
3844
- # Yang et al. (2024). "Gated Delta Networks." arXiv:2412.06464.
3845
- # Li et al. (2026). "REPO." arXiv:2512.14391.
3846
-
3847
-
3848
- def _apply_mask_to_padding_states(
3849
- hidden_states: torch.Tensor,
3850
- attention_mask: Optional[torch.Tensor],
3851
- ) -> torch.Tensor:
3852
- if (
3853
- attention_mask is not None
3854
- and attention_mask.shape[1] > 1
3855
- and attention_mask.shape[0] > 1
3856
- ):
3857
- hidden_states = (
3858
- hidden_states * attention_mask[:, :, None]
3859
- ).to(hidden_states.dtype)
3860
- return hidden_states
3861
-
3862
-
3863
- def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
3864
- return x / torch.sqrt((x * x).sum(dim=dim, keepdim=True) + eps)
3865
-
3866
-
3867
- def _torch_causal_conv1d_update(
3868
- hidden_states, conv_state, weight, bias=None, activation=None
3869
- ):
3870
- _, hidden_size, seq_len = hidden_states.shape
3871
- state_len = conv_state.shape[-1]
3872
- combined = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
3873
- conv_state.copy_(combined[:, :, -state_len:])
3874
- out = F.conv1d(combined, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
3875
- return F.silu(out[:, :, -seq_len:]).to(hidden_states.dtype)
3876
-
3877
-
3878
- def _torch_chunk_gated_delta_rule(
3879
- query, key, value, g, beta,
3880
- chunk_size=64, initial_state=None, output_final_state=False,
3881
- use_qk_l2norm_in_kernel=False,
3882
- ):
3883
- initial_dtype = query.dtype
3884
- if use_qk_l2norm_in_kernel:
3885
- query, key = _l2norm(query), _l2norm(key)
3886
- query, key, value, beta, g = [
3887
- x.transpose(1, 2).contiguous().to(torch.float32)
3888
- for x in (query, key, value, beta, g)
3889
- ]
3890
- bs, seq, nh, kdim = key.shape
3891
- vdim = value.shape[-1]
3892
- pad = (chunk_size - nh % chunk_size) % chunk_size
3893
- for t in (query, key, value):
3894
- t = F.pad(t, (0, 0, 0, pad))
3895
- query = F.pad(query, (0, 0, 0, pad))
3896
- key = F.pad(key, (0, 0, 0, pad))
3897
- value = F.pad(value, (0, 0, 0, pad))
3898
- beta = F.pad(beta, (0, pad))
3899
- g = F.pad(g, (0, pad))
3900
- tot = nh + pad
3901
- scale = query.shape[-1] ** -0.5
3902
- query = query * scale
3903
- vb = value * beta.unsqueeze(-1)
3904
- kb = key * beta.unsqueeze(-1)
3905
- query, key, value, kb, vb = [
3906
- x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
3907
- for x in (query, key, value, kb, vb)
3908
- ]
3909
- g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
3910
- triu = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), 0)
3911
- g = g.cumsum(-1)
3912
- dm = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp()).tril()
3913
- attn = -((kb @ key.transpose(-1, -2)) * dm).masked_fill(triu, 0)
3914
- for i in range(1, chunk_size):
3915
- r = attn[..., i, :i].clone(); s = attn[..., :i, :i].clone()
3916
- attn[..., i, :i] = r + (r.unsqueeze(-1) * s).sum(-2)
3917
- eye = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
3918
- attn = attn + eye
3919
- value = attn @ vb
3920
- kcd = attn @ (kb * g.exp().unsqueeze(-1))
3921
- st = torch.zeros(bs, seq, kdim, vdim, dtype=value.dtype, device=value.device) if initial_state is None else initial_state.to(value)
3922
- out = torch.zeros_like(value)
3923
- triu2 = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), 1)
3924
- for i in range(tot // chunk_size):
3925
- qi, ki, vi = query[:,:,i], key[:,:,i], value[:,:,i]
3926
- a = (qi @ ki.transpose(-1,-2) * dm[:,:,i]).masked_fill_(triu2, 0)
3927
- vp = kcd[:,:,i] @ st
3928
- vn = vi - vp
3929
- out[:,:,i] = (qi * g[:,:,i,:,None].exp()) @ st + a @ vn
3930
- st = st * g[:,:,i,-1,None,None].exp() + (ki * (g[:,:,i,-1,None]-g[:,:,i]).exp()[...,None]).transpose(-1,-2) @ vn
3931
- if not output_final_state: st = None
3932
- out = out.reshape(out.shape[0], out.shape[1], -1, out.shape[-1])[:,:,:nh]
3933
- return out.transpose(1,2).contiguous().to(initial_dtype), st
3934
-
3935
-
3936
- def _torch_recurrent_gated_delta_rule(
3937
- query, key, value, g, beta, initial_state, output_final_state,
3938
- use_qk_l2norm_in_kernel=False,
3939
- ):
3940
- initial_dtype = query.dtype
3941
- if use_qk_l2norm_in_kernel:
3942
- query, key = _l2norm(query), _l2norm(key)
3943
- query, key, value, beta, g = [
3944
- x.transpose(1,2).contiguous().to(torch.float32)
3945
- for x in (query, key, value, beta, g)
3946
- ]
3947
- bs, seq, nh, kdim = key.shape
3948
- vdim = value.shape[-1]
3949
- query = query * (query.shape[-1] ** -0.5)
3950
- out = torch.zeros(bs, seq, nh, vdim, dtype=value.dtype, device=value.device)
3951
- st = torch.zeros(bs, seq, kdim, vdim, dtype=value.dtype, device=value.device) if initial_state is None else initial_state.to(value)
3952
- for i in range(nh):
3953
- qt, kt, vt = query[:,:,i], key[:,:,i], value[:,:,i]
3954
- gt, bt = g[:,:,i].exp().unsqueeze(-1).unsqueeze(-1), beta[:,:,i].unsqueeze(-1)
3955
- st = st * gt
3956
- delta = (vt - (st * kt.unsqueeze(-1)).sum(-2)) * bt
3957
- st = st + kt.unsqueeze(-1) * delta.unsqueeze(-2)
3958
- out[:,:,i] = (st * qt.unsqueeze(-1)).sum(-2)
3959
- if not output_final_state: st = None
3960
- return out.transpose(1,2).contiguous().to(initial_dtype), st
3961
-
3962
-
3963
- class _NeoLLMRMSNormGated(nn.Module):
3964
- """Gated RMSNorm fallback when FLA unavailable."""
3965
- def __init__(self, hidden_size, eps=1e-6, **kwargs):
3966
- super().__init__()
3967
- self.weight = nn.Parameter(torch.ones(hidden_size))
3968
- self.eps = eps
3969
- def forward(self, x, gate):
3970
- dtype = x.dtype
3971
- x = x.float()
3972
- x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
3973
- return (self.weight * x.to(dtype) * F.silu(gate.float())).to(dtype)
3974
-
3975
-
3976
- class NeoLLMGatedDeltaNet(nn.Module):
3977
- """
3978
- GatedDeltaNet linear attention with FANformer integration.
3979
-
3980
- Replaces NeoLLMAttention on every ``linear_attention_every_n``-th layer
3981
- (0-indexed: layers 2, 5, 8 … for every_n=3).
3982
-
3983
- REPO (use_repo=True AND use_repo_in_linear_attn=True):
3984
- Applies continuous per-head positions to Q and K via _apply_repo_rope,
3985
- matching the full-attention REPO path identically.
3986
-
3987
- Without REPO the gated delta rule operates without explicit positional
3988
- encoding (its recurrent state is implicitly position-aware).
3989
-
3990
- References:
3991
- Yang et al. (2024). arXiv:2412.06464.
3992
- Li et al. (2026). arXiv:2512.14391.
3993
- """
3994
-
3995
- def __init__(self, config: NeoLLMConfig, layer_idx: int):
3996
- super().__init__()
3997
- self.hidden_size = config.hidden_size
3998
- self.num_v_heads = config.linear_num_value_heads
3999
- self.num_k_heads = config.linear_num_key_heads
4000
- self.head_k_dim = config.linear_key_head_dim
4001
- self.head_v_dim = config.linear_value_head_dim
4002
- self.key_dim = self.head_k_dim * self.num_k_heads
4003
- self.value_dim = self.head_v_dim * self.num_v_heads
4004
- self.conv_kernel_size = config.linear_conv_kernel_dim
4005
- self.layer_idx = layer_idx
4006
-
4007
- # ── FANformer (same ratio as full-attention layers) ────────────────
4008
- self.fan_layer = FANLayer(
4009
- hidden_size=config.hidden_size,
4010
- fan_ratio=getattr(config, "fan_ratio", 0.125),
4011
- )
4012
- _fan_dim = config.hidden_size + int(
4013
- config.hidden_size * getattr(config, "fan_ratio", 0.125)
4014
- )
4015
-
4016
- # ── Causal conv1d on concatenated Q/K/V ──────────────────────────
4017
- self.conv_dim = self.key_dim * 2 + self.value_dim
4018
- self.conv1d = nn.Conv1d(
4019
- self.conv_dim, self.conv_dim, bias=False,
4020
- kernel_size=self.conv_kernel_size,
4021
- groups=self.conv_dim,
4022
- padding=self.conv_kernel_size - 1,
4023
- )
4024
-
4025
- # ── QKVz + ba projections (all from FAN-transformed features) ─────
4026
- _ratio = self.num_v_heads // self.num_k_heads
4027
- self.in_proj_qkvz = nn.Linear(
4028
- _fan_dim, self.key_dim * 2 + self.value_dim * 2, bias=False
4029
- )
4030
- self.in_proj_ba = nn.Linear(
4031
- _fan_dim, self.num_v_heads * 2, bias=False
4032
- )
4033
-
4034
- # ── Delta-rule gating parameters ──────────────────────────────────
4035
- self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
4036
- A = torch.empty(self.num_v_heads).uniform_(0, 16)
4037
- self.A_log = nn.Parameter(torch.log(A))
4038
-
4039
- # ── Output normalisation ───────────────────────────────────────────
4040
- _NormCls = FusedRMSNormGated if FusedRMSNormGated is not None else _NeoLLMRMSNormGated
4041
- _norm_kw = (
4042
- dict(activation="silu",
4043
- device=torch.cuda.current_device(),
4044
- dtype=getattr(config, "dtype", None) or torch.get_default_dtype())
4045
- if FusedRMSNormGated is not None else {}
4046
- )
4047
- self.norm = _NormCls(self.head_v_dim, eps=config.rms_norm_eps, **_norm_kw)
4048
- self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
4049
- self.dropout = nn.Dropout(config.dropout_rate)
4050
-
4051
- # ── Kernel dispatch (fast β†’ fallback) ─────────────────────────────
4052
- self._conv1d_fn = causal_conv1d_fn # None if not installed
4053
- self._chunk_fn = (chunk_gated_delta_rule
4054
- if chunk_gated_delta_rule is not None
4055
- else _torch_chunk_gated_delta_rule)
4056
- self._recur_fn = (fused_recurrent_gated_delta_rule
4057
- if fused_recurrent_gated_delta_rule is not None
4058
- else _torch_recurrent_gated_delta_rule)
4059
-
4060
- if not is_linear_attn_fast_path:
4061
- logger.warning_once(
4062
- "NeoLLMGatedDeltaNet: causal_conv1d / flash-linear-attention "
4063
- "not installed β€” using pure-PyTorch fallbacks. "
4064
- "Install both packages for full performance."
4065
- )
4066
-
4067
- # ── REPO: continuous per-head positions on Q and K ─────────────────
4068
- # Controlled by use_repo AND use_repo_in_linear_attn flags.
4069
- # Only active for layers at or above repo_start_layer.
4070
- self.use_repo = (
4071
- getattr(config, "use_repo", False)
4072
- and getattr(config, "use_repo_in_linear_attn", False)
4073
- and layer_idx >= getattr(config, "repo_start_layer",
4074
- config.num_hidden_layers // 3)
4075
- )
4076
- if self.use_repo:
4077
- _d_p = getattr(config, "repo_d_p", config.hidden_size // 8)
4078
- self.repo_module = REPOModule(
4079
- hidden_size=config.hidden_size,
4080
- d_p=_d_p,
4081
- num_heads=self.num_v_heads,
4082
- )
4083
- else:
4084
- self.repo_module = None
4085
-
4086
- def _fix_qkvz(
4087
- self,
4088
- mixed_qkvz: torch.Tensor,
4089
- mixed_ba: torch.Tensor,
4090
- ) -> Tuple[torch.Tensor, ...]:
4091
- """Split fused projection into (q, k, v, z, b, a)."""
4092
- ratio = self.num_v_heads // self.num_k_heads
4093
- mixed_qkvz = mixed_qkvz.view(
4094
- *mixed_qkvz.shape[:-1],
4095
- self.num_k_heads,
4096
- 2 * self.head_k_dim + 2 * ratio * self.head_v_dim,
4097
- )
4098
- mixed_ba = mixed_ba.view(
4099
- *mixed_ba.shape[:-1],
4100
- self.num_k_heads,
4101
- 2 * ratio,
4102
- )
4103
- q, k, v, z = torch.split(
4104
- mixed_qkvz,
4105
- [self.head_k_dim, self.head_k_dim,
4106
- ratio * self.head_v_dim, ratio * self.head_v_dim],
4107
- dim=3,
4108
- )
4109
- b, a = torch.split(mixed_ba, ratio, dim=3)
4110
- v = v.reshape(v.shape[0], v.shape[1], -1, self.head_v_dim)
4111
- z = z.reshape(z.shape[0], z.shape[1], -1, self.head_v_dim)
4112
- b = b.reshape(b.shape[0], b.shape[1], self.num_v_heads)
4113
- a = a.reshape(a.shape[0], a.shape[1], self.num_v_heads)
4114
- return q, k, v, z, b, a
4115
-
4116
- def forward(
4117
- self,
4118
- hidden_states: torch.Tensor,
4119
- attention_mask: Optional[torch.Tensor] = None,
4120
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
4121
- repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
4122
- ) -> torch.Tensor:
4123
- hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
4124
- B, S, _ = hidden_states.shape
4125
-
4126
- # ── FANformer ─────────────────────────────────────────────────────
4127
- h_fan = self.fan_layer(hidden_states)
4128
-
4129
- # ── QKVz and ba projections ───────────────────────────────────────
4130
- q, k, v, z, b, a = self._fix_qkvz(
4131
- self.in_proj_qkvz(h_fan), self.in_proj_ba(h_fan)
4132
- )
4133
-
4134
- # ── Causal conv1d on flattened Q/K/V ─────────────────────────────
4135
- qkv = torch.cat(
4136
- [q.reshape(B, S, -1), k.reshape(B, S, -1), v.reshape(B, S, -1)], dim=-1
4137
- ).transpose(1, 2) # [B, conv_dim, S]
4138
-
4139
- if self._conv1d_fn is not None:
4140
- qkv = self._conv1d_fn(
4141
- x=qkv, weight=self.conv1d.weight.squeeze(1),
4142
- bias=self.conv1d.bias, activation="silu", seq_idx=None,
4143
- )
4144
- else:
4145
- qkv = F.silu(self.conv1d(qkv)[:, :, :S])
4146
- qkv = qkv.transpose(1, 2) # [B, S, conv_dim]
4147
-
4148
- q_f, k_f, v_f = torch.split(
4149
- qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1
4150
- )
4151
- q = q_f.reshape(B, S, -1, self.head_k_dim)
4152
- k = k_f.reshape(B, S, -1, self.head_k_dim)
4153
- v = v_f.reshape(B, S, -1, self.head_v_dim)
4154
-
4155
- # ── REPO: continuous per-head positions ───────────────────────────
4156
- # Transpose to [B, H, S, dk] for _apply_repo_rope, then back.
4157
- if self.use_repo and self.repo_module is not None and repo_rope_args is not None:
4158
- inv_freq, attn_scaling = repo_rope_args
4159
- z_pos = self.repo_module(hidden_states) # [B, H, S]
4160
- q_t, k_t = q.transpose(1, 2), k.transpose(1, 2)
4161
- q_t, k_t = _apply_repo_rope(q_t, k_t, z_pos, inv_freq, attn_scaling)
4162
- q, k = q_t.transpose(1, 2), k_t.transpose(1, 2)
4163
-
4164
- # ── GQA-like head expansion ─────────────────────────────��──────────
4165
- ratio = self.num_v_heads // self.num_k_heads
4166
- if ratio > 1:
4167
- q = q.repeat_interleave(ratio, dim=2)
4168
- k = k.repeat_interleave(ratio, dim=2)
4169
-
4170
- beta = b.sigmoid()
4171
- g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
4172
-
4173
- # ── Chunk gated delta rule (fused or fallback) ────────────────────
4174
- core_out, _ = self._chunk_fn(
4175
- q, k, v, g=g, beta=beta,
4176
- initial_state=None, output_final_state=False,
4177
- use_qk_l2norm_in_kernel=True,
4178
- )
4179
-
4180
- # ── Gated RMSNorm + output projection ─────────────────────────────
4181
- z_shape = z.shape
4182
- core_out = core_out.reshape(-1, core_out.shape[-1])
4183
- core_out = self.norm(core_out, z.reshape(-1, z.shape[-1]))
4184
- core_out = core_out.reshape(z_shape).reshape(B, S, -1)
4185
- return self.dropout(self.out_proj(core_out))
4186
-
4187
-
4188
- # ==================== DECODER LAYER ==========================================
4189
-
4190
-
4191
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
4192
  """
4193
  Decoder layer with standard residual connections, optional JTok-M injection.
@@ -4210,23 +3835,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
4210
  self.layer_idx = layer_idx
4211
  self.use_jtokm = config.use_jtokm
4212
 
4213
- # ── Token-mixer selection ─────────────────────────────────────────
4214
- # use_linear_attention=True: replace full attention every
4215
- # `linear_attention_every_n` layers (0-indexed pattern:
4216
- # e.g. every_n=3 β†’ layers 2, 5, 8, 11 …).
4217
- # All other layers keep NeoLLMAttention unchanged.
4218
- _every_n = getattr(config, "linear_attention_every_n", 3)
4219
- self.is_linear_attn = (
4220
- getattr(config, "use_linear_attention", False)
4221
- and (layer_idx + 1) % _every_n == 0
4222
- )
4223
- if self.is_linear_attn:
4224
- self.linear_attn = NeoLLMGatedDeltaNet(config, layer_idx)
4225
- self.self_attn = None
4226
- else:
4227
- self.self_attn = NeoLLMAttention(config, layer_idx)
4228
- self.linear_attn = None
4229
-
4230
  self.mlp = (
4231
  VersatileFFN(config)
4232
  if getattr(config, "use_versatile_ffn", False)
@@ -4484,32 +4093,17 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
4484
  if layer_analysis is not None:
4485
  layer_analysis.lns_attn_output = h_lns.detach()
4486
 
4487
- if self.is_linear_attn:
4488
- # ── GatedDeltaNet linear attention path ───────────────────────
4489
- # Does not use: first_layer_fan, mudd_xk/xv, attn_analysis.
4490
- # attention_mask here is already the linear_attn_mask (no causal
4491
- # bias, just padding) β€” NeoLLMModel.forward selects it per layer.
4492
- hidden_states = self.linear_attn(
4493
- hidden_states=h_lns,
4494
- attention_mask=attention_mask,
4495
- position_embeddings=position_embeddings,
4496
- repo_rope_args=repo_rope_args,
4497
- )
4498
- attn_weights = None
4499
- self.current_layer_fan = None
4500
- else:
4501
- # ── Standard full attention path ──────────────────────────────
4502
- hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
4503
- hidden_states=h_lns,
4504
- attention_mask=attention_mask,
4505
- position_embeddings=position_embeddings,
4506
- first_layer_fan=first_layer_fan,
4507
- attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
4508
- repo_rope_args=repo_rope_args,
4509
- mudd_xk=mudd_xk,
4510
- mudd_xv=mudd_xv,
4511
- **kwargs,
4512
- )
4513
 
4514
  if layer_analysis is not None:
4515
  layer_analysis.attn_contribution = hidden_states.detach()
@@ -4933,9 +4527,6 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
4933
  if hasattr(module, "alpha_ma"):
4934
  module.alpha_ma.zero_()
4935
 
4936
- elif isinstance(module, NeoLLMGatedDeltaNet):
4937
- module.dt_bias.data.fill_(1.0)
4938
- module.A_log.data.uniform_(0, 16).log_()
4939
  elif isinstance(module, GPAS):
4940
  module.alpha.data.fill_(0.0)
4941
 
@@ -5178,21 +4769,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
5178
 
5179
  self.post_init()
5180
 
5181
- def _update_linear_attn_mask(
5182
- self,
5183
- attention_mask: Optional[torch.Tensor],
5184
- ) -> Optional[torch.Tensor]:
5185
- """
5186
- Return mask for GatedDeltaNet layers (no causal bias, padding only).
5187
- Returns None when all tokens are valid (GatedDeltaNet handles via
5188
- _apply_mask_to_padding_states internally).
5189
- """
5190
- if attention_mask is None:
5191
- return None
5192
- if torch.all(attention_mask == 1):
5193
- return None
5194
- return attention_mask
5195
-
5196
  def get_input_embeddings(self):
5197
  if self.config.use_token_generator:
5198
  return self.token_generator
@@ -5343,10 +4919,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
5343
  if getattr(self.config, "use_repo", False) else None
5344
  )
5345
 
5346
- # ── Linear attention mask ──────────────────────────────────────────
5347
- # Computed once; each layer picks the appropriate mask below.
5348
- linear_attn_mask = self._update_linear_attn_mask(attention_mask)
5349
-
5350
  # ── Attention Residuals state ──────────────────────────────────────
5351
  # Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
5352
  # decoder layer β€” all previous outputs are kept, max N=num_layers+1.
@@ -5436,17 +5008,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
5436
  layer_analysis.layer_idx = layer_idx
5437
  analysis_state.layers.append(layer_analysis)
5438
 
5439
- # Select the appropriate mask: causal for full attention,
5440
- # padding-only for GatedDeltaNet linear attention.
5441
- _layer_mask = (
5442
- linear_attn_mask
5443
- if getattr(decoder_layer, "is_linear_attn", False)
5444
- else causal_mask
5445
- )
5446
  layer_outputs = decoder_layer(
5447
  hidden_states,
5448
  position_embeddings=position_embeddings,
5449
- attention_mask=_layer_mask,
5450
  first_layer_fan=self.first_layer_fan,
5451
  z_tilde=z_tilde,
5452
  B_vals=B_vals,
@@ -5843,7 +5408,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
5843
  # ==================== AUTOMODEL REGISTRATION ====================
5844
 
5845
  __all__ = [
5846
- "NeoLLMGatedDeltaNet",
5847
  "StackMemory",
5848
  "LAuReLLayer",
5849
  "NeoLLMMUDDModule",
 
81
 
82
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  logger = logging.get_logger(__name__)
85
 
86
 
 
404
  residual_scale: Optional[float] = None # res_weight scalar
405
 
406
 
407
+ @dataclass
408
+ class LAuReLAnalysis:
409
+ """
410
+ Internals of one LAuReL residual connection forward pass.
411
+ Only populated when use_laurel=True AND model is in eval + analysis mode.
412
+ Instantiated twice per layer: once for the attention residual, once for MLP.
413
+
414
+ Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
415
+ *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.
416
+
417
+ Math (combined RW+LR, both sub-variants active):
418
+
419
+ x_{i+1} = α · f(x_i) + β · (A·(B·x_i) + x_i)
420
+
421
+ where [α, β] = softmax([a, b]), a,b ∈ ℝ learnable (RW component),
422
+ B ∈ ℝ^{r×D} column-orthogonal init, A ∈ ℝ^{D×r} zero init (LR component).
423
+ At step 0: A=0 → lr_term=0, so x_{i+1} = 0.5·f(x) + 0.5·x_i (RW only)
424
+ or x_{i+1} = f(x_i) + x_i (LR only, standard residual).
425
+
426
+ Fields:
427
+ alpha_rw: softmax(a) — weight on f(x_i). [scalar float]
428
+ None when use_laurel_rw=False.
429
+ beta_rw: softmax(b) — weight on g(x_i). [scalar float]
430
+ None when use_laurel_rw=False.
431
+ lr_term: A·(B·x_res) — the low-rank residual augmentation.
432
+ Shape [B, S, D]. Zero at init. None when use_laurel_lr=False.
433
+ output: Final combined tensor before GPAS. Shape [B, S, D].
434
+ """
435
+ alpha_rw: Optional[float] = None # softmax weight on f(x)
436
+ beta_rw: Optional[float] = None # softmax weight on g(x)
437
+ lr_term: Optional[torch.Tensor] = None # A(Bx) low-rank augmentation [B,S,D]
438
+ output: Optional[torch.Tensor] = None # combined pre-GPAS [B,S,D]
439
+
440
+
441
  @dataclass
442
  class LayerAnalysis:
443
  """
 
470
  attn_res: Optional[AttnResAnalysis] = None # if use_attn_res
471
  dca: Optional[DCAAnalysis] = None # if use_dca
472
  stack: Optional[StackMemoryAnalysis] = None # if use_stacktrans
473
+ laurel_attn: Optional[LAuReLAnalysis] = None # if use_laurel (attention residual)
474
+ laurel_mlp: Optional[LAuReLAnalysis] = None # if use_laurel (MLP residual)
475
 
476
 
477
  @dataclass
 
3695
 
3696
 
3697
  @dataclass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3698
  class LAuReLLayer(nn.Module):
3699
  """
3700
  LAuReL: Learned Augmented Residual Layer.
 
3813
  return out
3814
 
3815
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3816
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3817
  """
3818
  Decoder layer with standard residual connections, optional JTok-M injection.
 
3835
  self.layer_idx = layer_idx
3836
  self.use_jtokm = config.use_jtokm
3837
 
3838
+ self.self_attn = NeoLLMAttention(config, layer_idx)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3839
  self.mlp = (
3840
  VersatileFFN(config)
3841
  if getattr(config, "use_versatile_ffn", False)
 
4093
  if layer_analysis is not None:
4094
  layer_analysis.lns_attn_output = h_lns.detach()
4095
 
4096
+ hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
4097
+ hidden_states=h_lns,
4098
+ attention_mask=attention_mask,
4099
+ position_embeddings=position_embeddings,
4100
+ first_layer_fan=first_layer_fan,
4101
+ attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
4102
+ repo_rope_args=repo_rope_args,
4103
+ mudd_xk=mudd_xk,
4104
+ mudd_xv=mudd_xv,
4105
+ **kwargs,
4106
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4107
 
4108
  if layer_analysis is not None:
4109
  layer_analysis.attn_contribution = hidden_states.detach()
 
4527
  if hasattr(module, "alpha_ma"):
4528
  module.alpha_ma.zero_()
4529
 
 
 
 
4530
  elif isinstance(module, GPAS):
4531
  module.alpha.data.fill_(0.0)
4532
 
 
4769
 
4770
  self.post_init()
4771
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4772
  def get_input_embeddings(self):
4773
  if self.config.use_token_generator:
4774
  return self.token_generator
 
4919
  if getattr(self.config, "use_repo", False) else None
4920
  )
4921
 
 
 
 
 
4922
  # ── Attention Residuals state ──────────────────────────────────────
4923
  # Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
4924
  # decoder layer — all previous outputs are kept, max N=num_layers+1.
 
5008
  layer_analysis.layer_idx = layer_idx
5009
  analysis_state.layers.append(layer_analysis)
5010
 
 
 
 
 
 
 
 
5011
  layer_outputs = decoder_layer(
5012
  hidden_states,
5013
  position_embeddings=position_embeddings,
5014
+ attention_mask=causal_mask,
5015
  first_layer_fan=self.first_layer_fan,
5016
  z_tilde=z_tilde,
5017
  B_vals=B_vals,
 
5408
  # ==================== AUTOMODEL REGISTRATION ====================
5409
 
5410
  __all__ = [
 
5411
  "StackMemory",
5412
  "LAuReLLayer",
5413
  "NeoLLMMUDDModule",