Upload modeling_neollm.py
Browse files- modeling_neollm.py +49 -485
modeling_neollm.py
CHANGED
|
@@ -81,30 +81,6 @@ from configuration_neollm import NeoLLMConfig
|
|
| 81 |
|
| 82 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 83 |
|
| 84 |
-
# ── Optional fast-path dependencies (GatedDeltaNet linear attention) ─────────
|
| 85 |
-
try:
|
| 86 |
-
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update as _causal_conv1d_update
|
| 87 |
-
_causal_conv1d_available = True
|
| 88 |
-
except ImportError:
|
| 89 |
-
causal_conv1d_fn = None
|
| 90 |
-
_causal_conv1d_update = None
|
| 91 |
-
_causal_conv1d_available = False
|
| 92 |
-
|
| 93 |
-
try:
|
| 94 |
-
from fla.modules import FusedRMSNormGated
|
| 95 |
-
from fla.ops.gated_delta_rule import (
|
| 96 |
-
chunk_gated_delta_rule,
|
| 97 |
-
fused_recurrent_gated_delta_rule,
|
| 98 |
-
)
|
| 99 |
-
_fla_available = True
|
| 100 |
-
except ImportError:
|
| 101 |
-
FusedRMSNormGated = None
|
| 102 |
-
chunk_gated_delta_rule = None
|
| 103 |
-
fused_recurrent_gated_delta_rule = None
|
| 104 |
-
_fla_available = False
|
| 105 |
-
|
| 106 |
-
is_linear_attn_fast_path = _causal_conv1d_available and _fla_available
|
| 107 |
-
|
| 108 |
logger = logging.get_logger(__name__)
|
| 109 |
|
| 110 |
|
|
@@ -428,6 +404,40 @@ class StackMemoryAnalysis:
|
|
| 428 |
residual_scale: Optional[float] = None # res_weight scalar
|
| 429 |
|
| 430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
@dataclass
|
| 432 |
class LayerAnalysis:
|
| 433 |
"""
|
|
@@ -460,8 +470,8 @@ class LayerAnalysis:
|
|
| 460 |
attn_res: Optional[AttnResAnalysis] = None # if use_attn_res
|
| 461 |
dca: Optional[DCAAnalysis] = None # if use_dca
|
| 462 |
stack: Optional[StackMemoryAnalysis] = None # if use_stacktrans
|
| 463 |
-
laurel_attn: Optional[
|
| 464 |
-
laurel_mlp: Optional[
|
| 465 |
|
| 466 |
|
| 467 |
@dataclass
|
|
@@ -3685,39 +3695,6 @@ class StackMemory(nn.Module):
|
|
| 3685 |
|
| 3686 |
|
| 3687 |
@dataclass
class LAuReLAnalysis:
    """
    Internals of one LAuReL residual connection forward pass.

    Only populated when use_laurel=True AND model is in eval + analysis mode.
    Instantiated twice per layer: once for the attention residual, once for MLP.

    Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
    *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.

    Math (combined RW+LR, both sub-variants active):

        x_{i+1} = α · f(x_i) + β · (A·(B·x_i) + x_i)

    where [α, β] = softmax([a, b]), a,b ∈ ℝ learnable (RW component),
    B ∈ ℝ^{r×D} column-orthogonal init, A ∈ ℝ^{D×r} zero init (LR component).
    At step 0: A=0 → lr_term=0, so x_{i+1} = 0.5·f(x) + 0.5·x_i  (RW only)
               or x_{i+1} = f(x_i) + x_i  (LR only, standard residual).

    Fields:
        alpha_rw: softmax(a) — weight on f(x_i).  [scalar float]
                  None when use_laurel_rw=False.
        beta_rw:  softmax(b) — weight on g(x_i).  [scalar float]
                  None when use_laurel_rw=False.
        lr_term:  A·(B·x_res) — the low-rank residual augmentation.
                  Shape [B, S, D]. Zero at init. None when use_laurel_lr=False.
        output:   Final combined tensor before GPAS. Shape [B, S, D].
    """
    alpha_rw: Optional[float] = None          # softmax weight on f(x)
    beta_rw: Optional[float] = None           # softmax weight on g(x)
    lr_term: Optional[torch.Tensor] = None    # A(Bx) low-rank augmentation [B,S,D]
    output: Optional[torch.Tensor] = None     # combined pre-GPAS [B,S,D]
|
| 3719 |
-
|
| 3720 |
-
|
| 3721 |
class LAuReLLayer(nn.Module):
|
| 3722 |
"""
|
| 3723 |
LAuReL: Learned Augmented Residual Layer.
|
|
@@ -3836,358 +3813,6 @@ class LAuReLLayer(nn.Module):
|
|
| 3836 |
return out
|
| 3837 |
|
| 3838 |
|
| 3839 |
-
# ==================== GATED DELTA NET (LINEAR ATTENTION) ====================
|
| 3840 |
-
# Active when use_linear_attention=True. Replaces NeoLLMAttention every
|
| 3841 |
-
# `linear_attention_every_n` layers (pattern 0-indexed: layers 2, 5, 8 …).
|
| 3842 |
-
#
|
| 3843 |
-
# References:
|
| 3844 |
-
# Yang et al. (2024). "Gated Delta Networks." arXiv:2412.06464.
|
| 3845 |
-
# Li et al. (2026). "REPO." arXiv:2512.14391.
|
| 3846 |
-
|
| 3847 |
-
|
| 3848 |
-
def _apply_mask_to_padding_states(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Multiply ``hidden_states`` by ``attention_mask`` broadcast over the last axis.

    The mask is indexed as ``attention_mask[:, :, None]``, so it is treated as
    a 2-D tensor whose two axes line up with the first two axes of
    ``hidden_states``. Positions where the mask is 0 are therefore zeroed.

    The input is returned untouched (same object) when no mask is given, or
    when either of the mask's first two dimensions is 1 or smaller.

    Returns:
        A tensor with the dtype of ``hidden_states``.
    """
    if attention_mask is None:
        return hidden_states
    # Masks with a single row or a single column are skipped entirely.
    if attention_mask.shape[1] <= 1 or attention_mask.shape[0] <= 1:
        return hidden_states
    masked = hidden_states * attention_mask[:, :, None]
    # The product may have been promoted (e.g. by a float mask); cast back.
    return masked.to(hidden_states.dtype)
|
| 3861 |
-
|
| 3862 |
-
|
| 3863 |
-
def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    """
    Scale ``x`` to (approximately) unit L2 norm along ``dim``.

    ``eps`` is added to the squared norm *inside* the square root, so an
    all-zero slice yields zeros instead of a division by zero.
    """
    squared_norm = (x * x).sum(dim=dim, keepdim=True)
    return x / torch.sqrt(squared_norm + eps)
|
| 3865 |
-
|
| 3866 |
-
|
| 3867 |
-
def _torch_causal_conv1d_update(
    hidden_states, conv_state, weight, bias=None, activation=None
):
    """
    Pure-PyTorch depthwise causal conv1d step with a rolling state buffer.

    Args:
        hidden_states: new inputs, unpacked as ``(_, hidden_size, seq_len)``.
        conv_state: previous inputs along the last axis. **Updated in place**
            with the most recent ``conv_state.shape[-1]`` columns of
            ``cat([conv_state, hidden_states])``.
        weight: per-channel kernel; ``unsqueeze(1)`` is applied before
            ``F.conv1d`` with ``groups=hidden_size`` (depthwise).
        bias: optional bias forwarded to ``F.conv1d``.
        activation: accepted for signature compatibility but not consulted;
            SiLU is always applied to the output.

    Returns:
        SiLU of the last ``seq_len`` convolution outputs, cast back to the
        dtype of ``hidden_states``.
    """
    _, hidden_size, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]

    # Prepend the cached history so the kernel sees enough left context.
    window = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    # Roll the cache forward: keep only the newest `state_len` columns.
    conv_state.copy_(window[:, :, -state_len:])

    conv_out = F.conv1d(
        window, weight.unsqueeze(1), bias, padding=0, groups=hidden_size
    )
    return F.silu(conv_out[:, :, -seq_len:]).to(hidden_states.dtype)
|
| 3876 |
-
|
| 3877 |
-
|
| 3878 |
-
def _torch_chunk_gated_delta_rule(
    query, key, value, g, beta,
    chunk_size=64, initial_state=None, output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """
    Pure-PyTorch chunked gated delta rule.

    Every input has ``transpose(1, 2)`` applied and is computed in float32.
    After that transpose ``key`` is unpacked as
    ``(batch, heads, seq_len, k_head_dim)``, so callers pass tensors with the
    sequence axis at dim 1 and the head axis at dim 2. ``g`` (log decay) and
    ``beta`` carry no trailing feature axis.

    Args:
        query, key, value: projections; ``query`` is scaled by
            ``k_head_dim ** -0.5`` internally.
        g: per-step log decay; exponentiated internally.
        beta: per-step write strength.
        chunk_size: block length; the sequence is zero-padded up to a
            multiple of it and the padding is stripped from the output.
        initial_state: optional starting state, cast to the working
            dtype/device; zeros of shape ``(batch, heads, k_head_dim,
            v_head_dim)`` when ``None``.
        output_final_state: when false, the returned state is ``None``.
        use_qk_l2norm_in_kernel: L2-normalise ``query`` and ``key`` first.

    Returns:
        ``(output, final_state)``; ``output`` is transposed back to the input
        layout and cast to the original dtype of ``query``.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query, key = _l2norm(query), _l2norm(key)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]
    batch_size, num_heads, seq_len, k_head_dim = key.shape
    v_head_dim = value.shape[-1]

    # Zero-pad the sequence axis to a whole number of chunks.
    pad = (chunk_size - seq_len % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad))
    key = F.pad(key, (0, 0, 0, pad))
    value = F.pad(value, (0, 0, 0, pad))
    beta = F.pad(beta, (0, pad))
    g = F.pad(g, (0, pad))
    num_chunks = (seq_len + pad) // chunk_size

    query = query * (query.shape[-1] ** -0.5)
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)

    # Split the sequence axis into (num_chunks, chunk_size).
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)

    upper_incl_diag = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), 0
    )
    g = g.cumsum(-1)
    # The first tril() zeroes the upper-triangle differences before exp();
    # the second one clears the exp(0) == 1 entries that leaves behind.
    decay = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp()).tril()

    attn = -((k_beta @ key.transpose(-1, -2)) * decay).masked_fill(upper_incl_diag, 0)
    # Row-by-row forward substitution; each row depends on the rows above it,
    # so the statement order inside this loop must not change.
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)

    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))

    if initial_state is None:
        state = torch.zeros(
            batch_size, num_heads, k_head_dim, v_head_dim,
            dtype=value.dtype, device=value.device,
        )
    else:
        state = initial_state.to(value)
    out = torch.zeros_like(value)
    upper_strict = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), 1
    )

    for i in range(num_chunks):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        intra = (q_i @ k_i.transpose(-1, -2) * decay[:, :, i]).masked_fill_(upper_strict, 0)
        v_prime = k_cumdecay[:, :, i] @ state
        v_new = v_i - v_prime
        # Contribution of the carried-in state plus the within-chunk part.
        out[:, :, i] = (q_i * g[:, :, i, :, None].exp()) @ state + intra @ v_new
        state = (
            state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2)
            @ v_new
        )

    if not output_final_state:
        state = None
    # Merge chunks back into one sequence axis and drop the padding.
    out = out.reshape(out.shape[0], out.shape[1], -1, out.shape[-1])[:, :, :seq_len]
    return out.transpose(1, 2).contiguous().to(initial_dtype), state
|
| 3934 |
-
|
| 3935 |
-
|
| 3936 |
-
def _torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state,
    use_qk_l2norm_in_kernel=False,
):
    """
    Pure-PyTorch step-by-step (recurrent) gated delta rule.

    Every input has ``transpose(1, 2)`` applied and is computed in float32.
    After that transpose ``key`` is unpacked as
    ``(batch, heads, seq_len, k_head_dim)``, so callers pass tensors with the
    sequence axis at dim 1 and the head axis at dim 2.

    Per time step the state ``S`` (``k_head_dim x v_head_dim`` per head) is
    updated as::

        S     <- S * exp(g_t)
        delta <- (v_t - k_t^T S) * beta_t
        S     <- S + k_t (outer) delta
        o_t   <- q_t^T S            # q already scaled by k_head_dim ** -0.5

    Args:
        initial_state: starting state or ``None`` for zeros.
        output_final_state: when false, the returned state is ``None``.
        use_qk_l2norm_in_kernel: L2-normalise ``query`` and ``key`` first.

    Returns:
        ``(output, final_state)``; ``output`` is transposed back to the input
        layout and cast to the original dtype of ``query``.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query, key = _l2norm(query), _l2norm(key)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]
    batch_size, num_heads, seq_len, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    query = query * (query.shape[-1] ** -0.5)

    out = torch.zeros(
        batch_size, num_heads, seq_len, v_head_dim,
        dtype=value.dtype, device=value.device,
    )
    if initial_state is None:
        state = torch.zeros(
            batch_size, num_heads, k_head_dim, v_head_dim,
            dtype=value.dtype, device=value.device,
        )
    else:
        state = initial_state.to(value)

    for t in range(seq_len):
        q_t, k_t, v_t = query[:, :, t], key[:, :, t], value[:, :, t]
        decay_t = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, t].unsqueeze(-1)
        # Decay first, then read the memory with the current key.
        state = state * decay_t
        kv_mem = (state * k_t.unsqueeze(-1)).sum(-2)
        delta = (v_t - kv_mem) * beta_t
        state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        out[:, :, t] = (state * q_t.unsqueeze(-1)).sum(-2)

    if not output_final_state:
        state = None
    return out.transpose(1, 2).contiguous().to(initial_dtype), state
|
| 3961 |
-
|
| 3962 |
-
|
| 3963 |
-
class _NeoLLMRMSNormGated(nn.Module):
    """
    Gated RMSNorm fallback when FLA unavailable.

    Computes ``weight * rmsnorm(x) * silu(gate)`` over the last axis.
    """

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        """
        Args:
            hidden_size: length of the normalised (last) axis.
            eps: added to the mean square before ``rsqrt``.
            **kwargs: accepted and ignored, so the constructor tolerates the
                extra keyword arguments callers pass to the fused variant.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x, gate):
        """Return the gated, RMS-normalised ``x`` in the dtype ``x`` came in with."""
        input_dtype = x.dtype
        # Normalise in float32 for stability, whatever the input precision.
        x = x.float()
        mean_square = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.eps)
        gated = self.weight * x.to(input_dtype) * F.silu(gate.float())
        return gated.to(input_dtype)
|
| 3974 |
-
|
| 3975 |
-
|
| 3976 |
-
class NeoLLMGatedDeltaNet(nn.Module):
    """
    GatedDeltaNet linear attention with FANformer integration.

    Replaces NeoLLMAttention on every ``linear_attention_every_n``-th layer
    (0-indexed: layers 2, 5, 8 … for every_n=3).

    REPO (use_repo=True AND use_repo_in_linear_attn=True):
        Applies continuous per-head positions to Q and K via _apply_repo_rope,
        matching the full-attention REPO path identically.

    Without REPO the gated delta rule operates without explicit positional
    encoding (its recurrent state is implicitly position-aware).

    References:
        Yang et al. (2024). arXiv:2412.06464.
        Li et al.   (2026). arXiv:2512.14391.
    """

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        """
        Build the projections, depthwise conv, gating parameters and norm.

        Raises:
            ValueError: if ``linear_num_value_heads`` is not a multiple of
                ``linear_num_key_heads`` (the fused projection could not be
                split per key head otherwise).
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx

        # Fail early with a readable message instead of a shape error inside
        # `_fix_qkvz`'s `.view(...)` on the first forward pass.
        if self.num_v_heads % self.num_k_heads != 0:
            raise ValueError(
                f"linear_num_value_heads ({self.num_v_heads}) must be a multiple of "
                f"linear_num_key_heads ({self.num_k_heads})"
            )

        # ── FANformer (same ratio as full-attention layers) ──
        fan_ratio = getattr(config, "fan_ratio", 0.125)
        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=fan_ratio,
        )
        fan_dim = config.hidden_size + int(config.hidden_size * fan_ratio)

        # ── Causal conv1d on concatenated Q/K/V ──
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            self.conv_dim, self.conv_dim, bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=self.conv_kernel_size - 1,
        )

        # ── QKVz + ba projections (all from FAN-transformed features) ──
        self.in_proj_qkvz = nn.Linear(
            fan_dim, self.key_dim * 2 + self.value_dim * 2, bias=False
        )
        self.in_proj_ba = nn.Linear(
            fan_dim, self.num_v_heads * 2, bias=False
        )

        # ── Delta-rule gating parameters ──
        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))

        # ── Output normalisation ──
        if FusedRMSNormGated is not None:
            norm_cls = FusedRMSNormGated
            norm_kwargs = dict(
                activation="silu",
                dtype=getattr(config, "dtype", None) or torch.get_default_dtype(),
            )
            # `torch.cuda.current_device()` raises on a host without CUDA, so
            # only pin the device when one is actually available.
            if torch.cuda.is_available():
                norm_kwargs["device"] = torch.cuda.current_device()
        else:
            norm_cls = _NeoLLMRMSNormGated
            norm_kwargs = {}
        self.norm = norm_cls(self.head_v_dim, eps=config.rms_norm_eps, **norm_kwargs)
        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

        # ── Kernel dispatch (fast → fallback) ──
        self._conv1d_fn = causal_conv1d_fn  # None if not installed
        self._chunk_fn = (
            chunk_gated_delta_rule
            if chunk_gated_delta_rule is not None
            else _torch_chunk_gated_delta_rule
        )
        self._recur_fn = (
            fused_recurrent_gated_delta_rule
            if fused_recurrent_gated_delta_rule is not None
            else _torch_recurrent_gated_delta_rule
        )

        if not is_linear_attn_fast_path:
            logger.warning_once(
                "NeoLLMGatedDeltaNet: causal_conv1d / flash-linear-attention "
                "not installed — using pure-PyTorch fallbacks. "
                "Install both packages for full performance."
            )

        # ── REPO: continuous per-head positions on Q and K ──
        # Controlled by use_repo AND use_repo_in_linear_attn flags.
        # Only active for layers at or above repo_start_layer.
        self.use_repo = (
            getattr(config, "use_repo", False)
            and getattr(config, "use_repo_in_linear_attn", False)
            and layer_idx >= getattr(config, "repo_start_layer",
                                     config.num_hidden_layers // 3)
        )
        if self.use_repo:
            repo_d_p = getattr(config, "repo_d_p", config.hidden_size // 8)
            self.repo_module = REPOModule(
                hidden_size=config.hidden_size,
                d_p=repo_d_p,
                num_heads=self.num_v_heads,
            )
        else:
            self.repo_module = None

    def _fix_qkvz(
        self,
        mixed_qkvz: torch.Tensor,
        mixed_ba: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        """Split fused projection into (q, k, v, z, b, a)."""
        ratio = self.num_v_heads // self.num_k_heads
        # Regroup the flat feature axis per key head; each key head owns its
        # q, k and `ratio` value/gate heads.
        mixed_qkvz = mixed_qkvz.view(
            *mixed_qkvz.shape[:-1],
            self.num_k_heads,
            2 * self.head_k_dim + 2 * ratio * self.head_v_dim,
        )
        mixed_ba = mixed_ba.view(
            *mixed_ba.shape[:-1],
            self.num_k_heads,
            2 * ratio,
        )
        q, k, v, z = torch.split(
            mixed_qkvz,
            [self.head_k_dim, self.head_k_dim,
             ratio * self.head_v_dim, ratio * self.head_v_dim],
            dim=3,
        )
        b, a = torch.split(mixed_ba, ratio, dim=3)
        # Flatten (key head, ratio) into one value-head axis.
        v = v.reshape(v.shape[0], v.shape[1], -1, self.head_v_dim)
        z = z.reshape(z.shape[0], z.shape[1], -1, self.head_v_dim)
        b = b.reshape(b.shape[0], b.shape[1], self.num_v_heads)
        a = a.reshape(a.shape[0], a.shape[1], self.num_v_heads)
        return q, k, v, z, b, a

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
    ) -> torch.Tensor:
        """
        Run one linear-attention pass over the full sequence.

        Args:
            hidden_states: unpacked as ``(B, S, _)``.
            attention_mask: padding mask handed to
                ``_apply_mask_to_padding_states``; may be ``None``.
            position_embeddings: accepted for interface parity; not used by
                this method.
            repo_rope_args: ``(inv_freq, attn_scaling)``; only consulted when
                REPO is enabled for this layer.

        Returns:
            The output projection of the gated, normalised delta-rule output,
            after dropout.
        """
        hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
        B, S, _ = hidden_states.shape

        # ── FANformer ──
        h_fan = self.fan_layer(hidden_states)

        # ── QKVz and ba projections ──
        q, k, v, z, b, a = self._fix_qkvz(
            self.in_proj_qkvz(h_fan), self.in_proj_ba(h_fan)
        )

        # ── Causal conv1d on flattened Q/K/V ──
        qkv = torch.cat(
            [q.reshape(B, S, -1), k.reshape(B, S, -1), v.reshape(B, S, -1)], dim=-1
        ).transpose(1, 2)  # [B, conv_dim, S]

        if self._conv1d_fn is not None:
            qkv = self._conv1d_fn(
                x=qkv, weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias, activation="silu", seq_idx=None,
            )
        else:
            # nn.Conv1d pads both ends; keeping the first S outputs makes the
            # convolution causal.
            qkv = F.silu(self.conv1d(qkv)[:, :, :S])
        qkv = qkv.transpose(1, 2)  # [B, S, conv_dim]

        q_flat, k_flat, v_flat = torch.split(
            qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1
        )
        q = q_flat.reshape(B, S, -1, self.head_k_dim)
        k = k_flat.reshape(B, S, -1, self.head_k_dim)
        v = v_flat.reshape(B, S, -1, self.head_v_dim)

        # ── REPO: continuous per-head positions ──
        # Transpose to [B, H, S, dk] for _apply_repo_rope, then back.
        # NOTE(review): repo_module is built with num_heads=num_v_heads, while
        # q/k still carry num_k_heads heads at this point (expansion happens
        # below). Confirm this is intended when num_v_heads != num_k_heads.
        if self.use_repo and self.repo_module is not None and repo_rope_args is not None:
            inv_freq, attn_scaling = repo_rope_args
            z_pos = self.repo_module(hidden_states)  # [B, H, S]
            q_t, k_t = q.transpose(1, 2), k.transpose(1, 2)
            q_t, k_t = _apply_repo_rope(q_t, k_t, z_pos, inv_freq, attn_scaling)
            q, k = q_t.transpose(1, 2), k_t.transpose(1, 2)

        # ── GQA-like head expansion ──
        ratio = self.num_v_heads // self.num_k_heads
        if ratio > 1:
            q = q.repeat_interleave(ratio, dim=2)
            k = k.repeat_interleave(ratio, dim=2)

        beta = b.sigmoid()
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

        # ── Chunk gated delta rule (fused or fallback) ──
        core_out, _ = self._chunk_fn(
            q, k, v, g=g, beta=beta,
            initial_state=None, output_final_state=False,
            use_qk_l2norm_in_kernel=True,
        )

        # ── Gated RMSNorm + output projection ──
        # The norm works on 2-D input: flatten to (tokens * heads, head_v_dim).
        z_shape = z.shape
        core_out = core_out.reshape(-1, core_out.shape[-1])
        core_out = self.norm(core_out, z.reshape(-1, z.shape[-1]))
        core_out = core_out.reshape(z_shape).reshape(B, S, -1)
        return self.dropout(self.out_proj(core_out))
|
| 4186 |
-
|
| 4187 |
-
|
| 4188 |
-
# ==================== DECODER LAYER ==========================================
|
| 4189 |
-
|
| 4190 |
-
|
| 4191 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 4192 |
"""
|
| 4193 |
Decoder layer with standard residual connections, optional JTok-M injection.
|
|
@@ -4210,23 +3835,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 4210 |
self.layer_idx = layer_idx
|
| 4211 |
self.use_jtokm = config.use_jtokm
|
| 4212 |
|
| 4213 |
-
|
| 4214 |
-
# use_linear_attention=True: replace full attention every
|
| 4215 |
-
# `linear_attention_every_n` layers (0-indexed pattern:
|
| 4216 |
-
# e.g. every_n=3 → layers 2, 5, 8, 11 …).
|
| 4217 |
-
# All other layers keep NeoLLMAttention unchanged.
|
| 4218 |
-
_every_n = getattr(config, "linear_attention_every_n", 3)
|
| 4219 |
-
self.is_linear_attn = (
|
| 4220 |
-
getattr(config, "use_linear_attention", False)
|
| 4221 |
-
and (layer_idx + 1) % _every_n == 0
|
| 4222 |
-
)
|
| 4223 |
-
if self.is_linear_attn:
|
| 4224 |
-
self.linear_attn = NeoLLMGatedDeltaNet(config, layer_idx)
|
| 4225 |
-
self.self_attn = None
|
| 4226 |
-
else:
|
| 4227 |
-
self.self_attn = NeoLLMAttention(config, layer_idx)
|
| 4228 |
-
self.linear_attn = None
|
| 4229 |
-
|
| 4230 |
self.mlp = (
|
| 4231 |
VersatileFFN(config)
|
| 4232 |
if getattr(config, "use_versatile_ffn", False)
|
|
@@ -4484,32 +4093,17 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 4484 |
if layer_analysis is not None:
|
| 4485 |
layer_analysis.lns_attn_output = h_lns.detach()
|
| 4486 |
|
| 4487 |
-
|
| 4488 |
-
|
| 4489 |
-
|
| 4490 |
-
|
| 4491 |
-
|
| 4492 |
-
|
| 4493 |
-
|
| 4494 |
-
|
| 4495 |
-
|
| 4496 |
-
|
| 4497 |
-
|
| 4498 |
-
attn_weights = None
|
| 4499 |
-
self.current_layer_fan = None
|
| 4500 |
-
else:
|
| 4501 |
-
# ββ Standard full attention path ββββββββββββββββββββββββββββββ
|
| 4502 |
-
hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
|
| 4503 |
-
hidden_states=h_lns,
|
| 4504 |
-
attention_mask=attention_mask,
|
| 4505 |
-
position_embeddings=position_embeddings,
|
| 4506 |
-
first_layer_fan=first_layer_fan,
|
| 4507 |
-
attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
|
| 4508 |
-
repo_rope_args=repo_rope_args,
|
| 4509 |
-
mudd_xk=mudd_xk,
|
| 4510 |
-
mudd_xv=mudd_xv,
|
| 4511 |
-
**kwargs,
|
| 4512 |
-
)
|
| 4513 |
|
| 4514 |
if layer_analysis is not None:
|
| 4515 |
layer_analysis.attn_contribution = hidden_states.detach()
|
|
@@ -4933,9 +4527,6 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
| 4933 |
if hasattr(module, "alpha_ma"):
|
| 4934 |
module.alpha_ma.zero_()
|
| 4935 |
|
| 4936 |
-
elif isinstance(module, NeoLLMGatedDeltaNet):
|
| 4937 |
-
module.dt_bias.data.fill_(1.0)
|
| 4938 |
-
module.A_log.data.uniform_(0, 16).log_()
|
| 4939 |
elif isinstance(module, GPAS):
|
| 4940 |
module.alpha.data.fill_(0.0)
|
| 4941 |
|
|
@@ -5178,21 +4769,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 5178 |
|
| 5179 |
self.post_init()
|
| 5180 |
|
| 5181 |
-
def _update_linear_attn_mask(
    self,
    attention_mask: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """
    Return the padding mask to hand to GatedDeltaNet layers (no causal bias).

    Returns ``None`` when no mask was supplied or when every entry equals 1,
    i.e. there is nothing to mask out; otherwise the mask is returned
    unchanged (same object) so the layer can zero its padded positions.
    """
    if attention_mask is None or torch.all(attention_mask == 1):
        return None
    return attention_mask
|
| 5195 |
-
|
| 5196 |
def get_input_embeddings(self):
|
| 5197 |
if self.config.use_token_generator:
|
| 5198 |
return self.token_generator
|
|
@@ -5343,10 +4919,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 5343 |
if getattr(self.config, "use_repo", False) else None
|
| 5344 |
)
|
| 5345 |
|
| 5346 |
-
# ββ Linear attention mask ββββββββββββββββββββββββββββββββββββββββββ
|
| 5347 |
-
# Computed once; each layer picks the appropriate mask below.
|
| 5348 |
-
linear_attn_mask = self._update_linear_attn_mask(attention_mask)
|
| 5349 |
-
|
| 5350 |
# ββ Attention Residuals state ββββββββββββββββββββββββββββββββββββββ
|
| 5351 |
# Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
|
| 5352 |
# decoder layer — all previous outputs are kept, max N=num_layers+1.
|
|
@@ -5436,17 +5008,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 5436 |
layer_analysis.layer_idx = layer_idx
|
| 5437 |
analysis_state.layers.append(layer_analysis)
|
| 5438 |
|
| 5439 |
-
# Select the appropriate mask: causal for full attention,
|
| 5440 |
-
# padding-only for GatedDeltaNet linear attention.
|
| 5441 |
-
_layer_mask = (
|
| 5442 |
-
linear_attn_mask
|
| 5443 |
-
if getattr(decoder_layer, "is_linear_attn", False)
|
| 5444 |
-
else causal_mask
|
| 5445 |
-
)
|
| 5446 |
layer_outputs = decoder_layer(
|
| 5447 |
hidden_states,
|
| 5448 |
position_embeddings=position_embeddings,
|
| 5449 |
-
attention_mask=
|
| 5450 |
first_layer_fan=self.first_layer_fan,
|
| 5451 |
z_tilde=z_tilde,
|
| 5452 |
B_vals=B_vals,
|
|
@@ -5843,7 +5408,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 5843 |
# ==================== AUTOMODEL REGISTRATION ====================
|
| 5844 |
|
| 5845 |
__all__ = [
|
| 5846 |
-
"NeoLLMGatedDeltaNet",
|
| 5847 |
"StackMemory",
|
| 5848 |
"LAuReLLayer",
|
| 5849 |
"NeoLLMMUDDModule",
|
|
|
|
| 81 |
|
| 82 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
logger = logging.get_logger(__name__)
|
| 85 |
|
| 86 |
|
|
|
|
| 404 |
residual_scale: Optional[float] = None # res_weight scalar
|
| 405 |
|
| 406 |
|
| 407 |
+
@dataclass
class LAuReLAnalysis:
    """
    Internals of one LAuReL residual connection forward pass.

    Only populated when use_laurel=True AND model is in eval + analysis mode.
    Instantiated twice per layer: once for the attention residual, once for MLP.

    Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
    *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.

    Math (combined RW+LR, both sub-variants active):

        x_{i+1} = α · f(x_i) + β · (A·(B·x_i) + x_i)

    where [α, β] = softmax([a, b]), a,b ∈ ℝ learnable (RW component),
    B ∈ ℝ^{r×D} column-orthogonal init, A ∈ ℝ^{D×r} zero init (LR component).
    At step 0: A=0 → lr_term=0, so x_{i+1} = 0.5·f(x) + 0.5·x_i  (RW only)
               or x_{i+1} = f(x_i) + x_i  (LR only, standard residual).

    Fields:
        alpha_rw: softmax(a) — weight on f(x_i).  [scalar float]
                  None when use_laurel_rw=False.
        beta_rw:  softmax(b) — weight on g(x_i).  [scalar float]
                  None when use_laurel_rw=False.
        lr_term:  A·(B·x_res) — the low-rank residual augmentation.
                  Shape [B, S, D]. Zero at init. None when use_laurel_lr=False.
        output:   Final combined tensor before GPAS. Shape [B, S, D].
    """
    alpha_rw: Optional[float] = None          # softmax weight on f(x)
    beta_rw: Optional[float] = None           # softmax weight on g(x)
    lr_term: Optional[torch.Tensor] = None    # A(Bx) low-rank augmentation [B,S,D]
    output: Optional[torch.Tensor] = None     # combined pre-GPAS [B,S,D]
|
| 439 |
+
|
| 440 |
+
|
| 441 |
@dataclass
|
| 442 |
class LayerAnalysis:
|
| 443 |
"""
|
|
|
|
| 470 |
attn_res: Optional[AttnResAnalysis] = None # if use_attn_res
|
| 471 |
dca: Optional[DCAAnalysis] = None # if use_dca
|
| 472 |
stack: Optional[StackMemoryAnalysis] = None # if use_stacktrans
|
| 473 |
+
laurel_attn: Optional[LAuReLAnalysis] = None # if use_laurel (attention residual)
|
| 474 |
+
laurel_mlp: Optional[LAuReLAnalysis] = None # if use_laurel (MLP residual)
|
| 475 |
|
| 476 |
|
| 477 |
@dataclass
|
|
|
|
| 3695 |
|
| 3696 |
|
| 3697 |
@dataclass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3698 |
class LAuReLLayer(nn.Module):
|
| 3699 |
"""
|
| 3700 |
LAuReL: Learned Augmented Residual Layer.
|
|
|
|
| 3813 |
return out
|
| 3814 |
|
| 3815 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3816 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 3817 |
"""
|
| 3818 |
Decoder layer with standard residual connections, optional JTok-M injection.
|
|
|
|
| 3835 |
self.layer_idx = layer_idx
|
| 3836 |
self.use_jtokm = config.use_jtokm
|
| 3837 |
|
| 3838 |
+
self.self_attn = NeoLLMAttention(config, layer_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3839 |
self.mlp = (
|
| 3840 |
VersatileFFN(config)
|
| 3841 |
if getattr(config, "use_versatile_ffn", False)
|
|
|
|
| 4093 |
if layer_analysis is not None:
|
| 4094 |
layer_analysis.lns_attn_output = h_lns.detach()
|
| 4095 |
|
| 4096 |
+
hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
|
| 4097 |
+
hidden_states=h_lns,
|
| 4098 |
+
attention_mask=attention_mask,
|
| 4099 |
+
position_embeddings=position_embeddings,
|
| 4100 |
+
first_layer_fan=first_layer_fan,
|
| 4101 |
+
attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
|
| 4102 |
+
repo_rope_args=repo_rope_args,
|
| 4103 |
+
mudd_xk=mudd_xk,
|
| 4104 |
+
mudd_xv=mudd_xv,
|
| 4105 |
+
**kwargs,
|
| 4106 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4107 |
|
| 4108 |
if layer_analysis is not None:
|
| 4109 |
layer_analysis.attn_contribution = hidden_states.detach()
|
|
|
|
| 4527 |
if hasattr(module, "alpha_ma"):
|
| 4528 |
module.alpha_ma.zero_()
|
| 4529 |
|
|
|
|
|
|
|
|
|
|
| 4530 |
elif isinstance(module, GPAS):
|
| 4531 |
module.alpha.data.fill_(0.0)
|
| 4532 |
|
|
|
|
| 4769 |
|
| 4770 |
self.post_init()
|
| 4771 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4772 |
def get_input_embeddings(self):
|
| 4773 |
if self.config.use_token_generator:
|
| 4774 |
return self.token_generator
|
|
|
|
| 4919 |
if getattr(self.config, "use_repo", False) else None
|
| 4920 |
)
|
| 4921 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4922 |
# ββ Attention Residuals state ββββββββββββββββββββββββββββββββββββββ
|
| 4923 |
# Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
|
| 4924 |
# decoder layer — all previous outputs are kept, max N=num_layers+1.
|
|
|
|
| 5008 |
layer_analysis.layer_idx = layer_idx
|
| 5009 |
analysis_state.layers.append(layer_analysis)
|
| 5010 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5011 |
layer_outputs = decoder_layer(
|
| 5012 |
hidden_states,
|
| 5013 |
position_embeddings=position_embeddings,
|
| 5014 |
+
attention_mask=causal_mask,
|
| 5015 |
first_layer_fan=self.first_layer_fan,
|
| 5016 |
z_tilde=z_tilde,
|
| 5017 |
B_vals=B_vals,
|
|
|
|
| 5408 |
# ==================== AUTOMODEL REGISTRATION ====================
|
| 5409 |
|
| 5410 |
__all__ = [
|
|
|
|
| 5411 |
"StackMemory",
|
| 5412 |
"LAuReLLayer",
|
| 5413 |
"NeoLLMMUDDModule",
|