File size: 14,246 Bytes
4303959 3558023 4303959 622dbbd 4303959 622dbbd 4303959 883d17e 622dbbd 4303959 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 |
# Remote code: configuration and modeling for NSA
import math
from typing import Optional
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput
from .configuration_nsa import NSAConfig
# Hard-coded flag; nothing in the visible part of this file reads it.
_HAS_NSA = False # Do not attempt nested vendor import in HF dynamic loader
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Args:
        dim: Size of the last (feature) dimension; one scale weight per feature.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Scale starts at one, so the module initially only normalises.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` divided by the RMS of its last dimension, times ``weight``."""
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return (x * inv_rms) * self.weight
class MLP(nn.Module):
    """Bias-free two-layer feed-forward network with a SiLU activation in between.

    Args:
        dim: Input and output width.
        hidden_mult: The hidden width is ``hidden_mult * dim``.
    """

    def __init__(self, dim: int, hidden_mult: int = 4) -> None:
        super().__init__()
        hidden_dim = hidden_mult * dim
        self.fc1 = nn.Linear(dim, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand with ``fc1``, apply SiLU, then project back with ``fc2``."""
        expanded = self.fc1(x)
        activated = nn.functional.silu(expanded)
        return self.fc2(activated)
def _rope(q: torch.Tensor) -> torch.Tensor:
B, S, D = q.shape[0], q.shape[2], q.shape[-1]
if D % 2 != 0:
return q
device = q.device
half = D // 2
pos = torch.arange(S, device=device).float().unsqueeze(-1)
inv_freq = 1.0 / (10000 ** (torch.arange(0, half, device=device).float() / half))
angles = pos * inv_freq
cos = angles.cos().view(1, 1, S, half)
sin = angles.sin().view(1, 1, S, half)
q1, q2 = q[..., :half], q[..., half:]
return torch.cat([q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1)
def _avg_pool_time(x: torch.Tensor, kernel: int, stride: int) -> torch.Tensor:
if x.shape[2] < kernel:
return x[..., :0, :]
xt = x.permute(0, 3, 1, 2).contiguous()
y = torch.nn.functional.avg_pool2d(xt, kernel_size=(1, kernel), stride=(1, stride))
return y.permute(0, 2, 3, 1).contiguous()
def _window_mask(q: torch.Tensor, S: int, w: int) -> torch.Tensor:
B, h = q.shape[0], q.shape[1]
device = q.device
row = torch.arange(S, device=device).view(S, 1)
col = torch.arange(S, device=device).view(1, S)
allowed = (col <= row) & (col >= (row - (w - 1)))
M = torch.full((S, S), float('-inf'), device=device, dtype=q.dtype)
M.masked_fill_(allowed, 0.0)
return M.view(1, 1, S, S).expand(B, h, S, S)
def _selection_blocks(scores: torch.Tensor, l_sel: int, n_sel: int) -> torch.Tensor:
B, h, S = scores.shape
n_blocks = max(1, (S + l_sel - 1) // l_sel)
# Pad to multiple of l_sel
pad = n_blocks * l_sel - S
if pad > 0:
scores = torch.nn.functional.pad(scores, (0, pad), value=-1e9)
blk_scores = scores.view(B, h, n_blocks, l_sel).max(dim=-1).values
k = min(n_sel, n_blocks)
return torch.topk(blk_scores, k=k, dim=-1).indices
class EmbeddedNSAAttention(nn.Module):
    """Pure-PyTorch attention with three gated branches: compressed, selected, windowed.

    The input is projected to one set of queries and three independent
    key/value sets (one per branch). Keys and values are produced per KV group
    and shared by the query heads of that group. The three branch outputs are
    mixed by a softmax gate computed per token and per KV group from the
    group-averaged queries.

    Args:
        dim: Width of the input and of the output.
        n_heads: Number of query heads.
        n_kv_groups: Number of key/value groups; must divide ``n_heads``.
        d_k: Per-head query/key width.
        d_v: Per-head value width.
        l: Stored as ``self.l``; ``forward`` does not read it.
        d: Pooling kernel and stride of the compressed branch.
        l_sel: Block length of the selection branch.
        n_sel: Number of blocks the selection branch keeps.
        w: Sliding-window length in tokens, counting the current token.

    Raises:
        ValueError: If ``n_kv_groups`` is not a positive divisor of ``n_heads``.
    """

    def __init__(self, dim: int, n_heads: int, n_kv_groups: int, d_k: int, d_v: int,
                 l: int, d: int, l_sel: int, n_sel: int, w: int) -> None:
        super().__init__()
        # forward() reshapes heads as [groups, heads_per_group]; any other
        # combination can only fail later with an opaque reshape error.
        if n_kv_groups < 1 or n_heads % n_kv_groups != 0:
            raise ValueError(
                f"n_kv_groups ({n_kv_groups}) must be a positive divisor of n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.d_k = d_k
        self.d_v = d_v
        self.l = l
        self.stride = d
        self.l_sel = l_sel
        self.n_sel = n_sel
        self.w = w
        self.W_Q = nn.Linear(dim, n_heads * d_k, bias=False)
        self.W_K_cmp = nn.Linear(dim, n_kv_groups * d_k, bias=False)
        self.W_V_cmp = nn.Linear(dim, n_kv_groups * d_v, bias=False)
        self.W_K_sel = nn.Linear(dim, n_kv_groups * d_k, bias=False)
        self.W_V_sel = nn.Linear(dim, n_kv_groups * d_v, bias=False)
        self.W_K_win = nn.Linear(dim, n_kv_groups * d_k, bias=False)
        self.W_V_win = nn.Linear(dim, n_kv_groups * d_v, bias=False)
        # Gate MLP operates on per-group pooled Q with width d_k (matches training)
        gate_hidden = max(1, d_k // 2)
        self.gate_fc1 = nn.Linear(d_k, gate_hidden, bias=True)
        self.gate_fc2 = nn.Linear(gate_hidden, 3, bias=True)
        # Small-gain init with zero bias keeps the initial gate close to uniform.
        nn.init.xavier_uniform_(self.gate_fc2.weight, gain=0.1)
        nn.init.zeros_(self.gate_fc2.bias)
        self.out = nn.Linear(n_heads * d_v, dim, bias=False)

    @staticmethod
    def _masked_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          visible: torch.Tensor) -> torch.Tensor:
        """Softmax attention that yields zeros for queries with no visible key.

        Args:
            q: ``[B, h, S, dk]`` queries.
            k: ``[B, h, T, dk]`` keys.
            v: ``[B, h, T, dv]`` values.
            visible: Boolean mask broadcastable to ``[B, h, S, T]``; ``True``
                marks keys a query may attend to.

        Returns:
            ``[B, h, S, dv]`` attention output.
        """
        if k.shape[-2] == 0:
            return v.new_zeros(*q.shape[:-1], v.shape[-1])
        logits = torch.matmul(q / math.sqrt(q.shape[-1]), k.transpose(-2, -1))
        logits = logits.masked_fill(~visible, float('-inf'))
        has_key = visible.any(dim=-1, keepdim=True)
        # A row that is entirely -inf turns into NaN under softmax; give such
        # rows finite logits and zero their probabilities afterwards.
        logits = logits.masked_fill(~has_key, 0.0)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        probs = probs.masked_fill(~has_key, 0.0)
        return torch.matmul(probs, v)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the gated three-branch attention.

        Args:
            x: ``[B, S, dim]`` input.

        Returns:
            ``[B, S, dim]`` output of the final linear projection.
        """
        B, S, _ = x.shape
        h, dk, dv = self.n_heads, self.d_k, self.d_v
        g = self.n_kv_groups
        r = h // g  # query heads per KV group
        device = x.device

        def to_heads(proj: nn.Linear, width: int) -> torch.Tensor:
            # Project per group, then repeat each group for its r heads.
            # Head j uses group j // r, the same [g, r] layout the gate and
            # the output mixing below rely on.
            t = proj(x).view(B, S, g, width).permute(0, 2, 1, 3)  # [B,g,S,width]
            return t.unsqueeze(2).expand(B, g, r, S, width).reshape(B, h, S, width)

        def rope(t: torch.Tensor) -> torch.Tensor:
            # _rope reads dimension 2 as the sequence axis, so it must receive
            # [B,h,S,d]; the cast keeps q/k in the dtype of the values.
            return _rope(t).to(t.dtype)

        Q = rope(self.W_Q(x).view(B, S, h, dk).transpose(1, 2))  # [B,h,S,dk]
        Kc, Vc = rope(to_heads(self.W_K_cmp, dk)), to_heads(self.W_V_cmp, dv)
        Ks, Vs = rope(to_heads(self.W_K_sel, dk)), to_heads(self.W_V_sel, dv)
        Kw, Vw = rope(to_heads(self.W_K_win, dk)), to_heads(self.W_V_win, dv)

        # Compressed: average-pool along time. Pooled key j summarises source
        # positions [j*step, (j+1)*step - 1]; a query may only see it once the
        # whole span lies at or before the query position.
        step = max(1, self.stride)
        Kc_p = _avg_pool_time(Kc, kernel=step, stride=step)  # [B,h,S_c,dk]
        Vc_p = _avg_pool_time(Vc, kernel=step, stride=step)
        q_pos = torch.arange(S, device=device).view(S, 1)
        span_end = (torch.arange(Kc_p.shape[2], device=device).unsqueeze(0) + 1) * step - 1
        O_cmp = self._masked_attention(Q, Kc_p, Vc_p, span_end <= q_pos)

        # Selection: naive top-n blocks chosen once per (batch, head) for the
        # whole sequence, then restricted causally per query.
        # NOTE(review): because blocks are ranked using scores from every
        # position, which blocks an early query sees can depend on later
        # tokens whenever fewer than all blocks are kept.
        scores = (Q * Ks).mean(dim=-1)  # [B,h,S]
        blk_idx = _selection_blocks(scores, self.l_sel, self.n_sel)  # [B,h,n]
        n_blocks = max(1, (S + self.l_sel - 1) // self.l_sel)
        keep = torch.zeros((B, h, n_blocks), device=device, dtype=torch.bool)
        keep.scatter_(2, blk_idx, True)
        keep = keep.unsqueeze(-1).expand(B, h, n_blocks, self.l_sel).reshape(B, h, -1)[:, :, :S]
        causal = torch.tril(torch.ones((S, S), device=device, dtype=torch.bool))
        O_sel = self._masked_attention(Q, Ks, Vs, keep.unsqueeze(2) & causal)

        # Sliding window: every query sees at least itself, so a plain softmax is safe.
        M = _window_mask(Q, S, max(1, self.w))
        logits_w = torch.matmul(Q / math.sqrt(dk), Kw.transpose(-2, -1)) + M
        O_win = torch.matmul(torch.nn.functional.softmax(logits_w, dim=-1), Vw)

        # Gate: average queries over the heads of each group, then a small MLP
        # yields one softmax weight per branch for every token and group.
        Qg = Q.reshape(B, g, r, S, dk).mean(dim=2).permute(0, 2, 1, 3)  # [B,S,g,dk]
        g1 = torch.nn.functional.silu(self.gate_fc1(Qg))
        gate = torch.nn.functional.softmax(self.gate_fc2(g1), dim=-1)  # [B,S,g,3]
        gc = gate[..., 0:1].unsqueeze(-1)  # [B,S,g,1,1]
        gs = gate[..., 1:2].unsqueeze(-1)
        gw = gate[..., 2:3].unsqueeze(-1)

        def grouped(o: torch.Tensor) -> torch.Tensor:
            # [B,h,S,dv] -> [B,S,g,r,dv] so a group gate broadcasts over its heads.
            return o.permute(0, 2, 1, 3).reshape(B, S, g, r, dv)

        O = gc * grouped(O_cmp) + gs * grouped(O_sel) + gw * grouped(O_win)
        return self.out(O.reshape(B, S, h * dv))
class SimpleAttention(nn.Module):
    """Dense causal multi-head self-attention with bias-free projections.

    Args:
        dim: Width of the input and of the output.
        n_heads: Number of attention heads.
        d_k: Per-head query/key width.
        d_v: Per-head value width.
    """

    def __init__(self, dim: int, n_heads: int, d_k: int, d_v: int) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        qk_width = n_heads * d_k
        v_width = n_heads * d_v
        self.q_proj = nn.Linear(dim, qk_width, bias=False)
        self.k_proj = nn.Linear(dim, qk_width, bias=False)
        self.v_proj = nn.Linear(dim, v_width, bias=False)
        self.out = nn.Linear(v_width, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend causally over ``x`` of shape ``[B, S, dim]`` and return ``[B, S, dim]``."""
        batch, seq_len, _ = x.shape
        heads = self.n_heads

        def split_heads(t: torch.Tensor, width: int) -> torch.Tensor:
            # [B,S,heads*width] -> [B,heads,S,width]
            return t.view(batch, seq_len, heads, width).transpose(1, 2)

        queries = split_heads(self.q_proj(x), self.d_k)
        keys = split_heads(self.k_proj(x), self.d_k)
        values = split_heads(self.v_proj(x), self.d_v)
        mixed = torch.nn.functional.scaled_dot_product_attention(
            queries, keys, values, is_causal=True
        )
        # Back to [B,S,heads*d_v] for the output projection.
        merged = mixed.transpose(1, 2).contiguous().view(batch, seq_len, heads * self.d_v)
        return self.out(merged)
class SimpleBlock(nn.Module):
    """Pre-norm transformer block built from dense causal attention and an MLP.

    Each sub-layer sees an RMS-normalised copy of the residual stream and its
    output is added back onto that stream.

    Args:
        dim: Width of the residual stream.
        n_heads: Number of attention heads.
        d_k: Per-head query/key width.
        d_v: Per-head value width.
    """

    def __init__(self, dim: int, n_heads: int, d_k: int, d_v: int) -> None:
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = SimpleAttention(dim, n_heads, d_k, d_v)
        self.norm2 = RMSNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention sub-layer, then the MLP sub-layer, each with a residual add."""
        attn_out = self.attn(self.norm1(x))
        hidden = x + attn_out
        mlp_out = self.mlp(self.norm2(hidden))
        return hidden + mlp_out
class NSABlockRemote(nn.Module):
    """Pre-norm transformer block whose attention sub-layer is ``EmbeddedNSAAttention``.

    An RMSNorm is applied before the attention and before the MLP; each
    sub-layer output is added back onto the residual stream. All constructor
    arguments other than ``dim`` are forwarded unchanged to the attention module.
    """

    def __init__(self, dim: int, n_heads: int, n_kv_groups: int, d_k: int, d_v: int,
                 l: int, d: int, l_sel: int, n_sel: int, w: int) -> None:
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = EmbeddedNSAAttention(
            dim, n_heads, n_kv_groups, d_k, d_v, l, d, l_sel, n_sel, w
        )
        self.norm2 = RMSNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention sub-layer, then the MLP sub-layer, each with a residual add."""
        attn_out = self.attn(self.norm1(x))
        hidden = x + attn_out
        mlp_out = self.mlp(self.norm2(hidden))
        return hidden + mlp_out
class NSATinyLM(nn.Module):
    """Decoder stack: token embedding, attention blocks, final LayerNorm and LM head.

    Sizes are read from ``config``. ``n_kv_groups``, ``d_k`` and ``d_v`` fall
    back to ``1`` and ``hidden_size // num_attention_heads`` when the config
    has no such attribute. Attention hyper-parameters come from the
    ``config.nsa`` mapping (keys ``block``, ``stride``, ``sel_block``,
    ``sel_top_n``, ``window``) with built-in defaults.

    Setting the environment variable ``NSA_REMOTE_FORCE_SIMPLE`` to ``1``,
    ``true`` or ``yes`` (case-insensitive) at construction time builds
    ``SimpleBlock`` layers instead of ``NSABlockRemote`` layers.
    """

    def __init__(self, config: NSAConfig):
        super().__init__()
        self.config = config
        self.vocab_size = int(config.vocab_size)
        self.hidden_size = int(config.hidden_size)
        self.num_hidden_layers = int(config.num_hidden_layers)
        self.num_attention_heads = int(config.num_attention_heads)
        self.n_kv_groups = int(getattr(config, "n_kv_groups", 1))
        default_head_dim = self.hidden_size // self.num_attention_heads
        self.d_k = int(getattr(config, "d_k", default_head_dim))
        self.d_v = int(getattr(config, "d_v", default_head_dim))

        nsa_cfg = config.nsa or {}
        self.l = int(nsa_cfg.get("block", 32))
        self.d = int(nsa_cfg.get("stride", 16))
        self.l_sel = int(nsa_cfg.get("sel_block", 64))
        self.n_sel = int(nsa_cfg.get("sel_top_n", 16))
        self.w = int(nsa_cfg.get("window", 512))

        self.embed = nn.Embedding(self.vocab_size, self.hidden_size)

        import os as _os
        # Allow forcing simple fallback via env for integration tests
        force_simple = _os.getenv('NSA_REMOTE_FORCE_SIMPLE', '0').lower() in ('1', 'true', 'yes')
        if force_simple:
            layers = [
                SimpleBlock(self.hidden_size, self.num_attention_heads, self.d_k, self.d_v)
                for _ in range(self.num_hidden_layers)
            ]
        else:
            # Default path: the embedded minimal NSA block.
            layers = [
                NSABlockRemote(
                    self.hidden_size,
                    self.num_attention_heads,
                    self.n_kv_groups,
                    self.d_k,
                    self.d_v,
                    self.l,
                    self.d,
                    self.l_sel,
                    self.n_sel,
                    self.w,
                )
                for _ in range(self.num_hidden_layers)
            ]
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(self.hidden_size)
        self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids to next-token logits with a trailing ``vocab_size`` axis."""
        hidden = self.embed(input_ids)
        for block in self.blocks:
            hidden = block(hidden)
        return self.lm_head(self.norm(hidden))
class NSAForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal-LM wrapper that exposes ``NSATinyLM`` through the Transformers API.

    There is no key/value cache: every generation step re-runs the full
    sequence. ``attention_mask`` is accepted for API compatibility but is not
    passed to the underlying model, so padded positions are processed like
    any other token.
    """

    config_class = NSAConfig
    # The residual additions live in the block classes, so a whole block has
    # to stay on one device when a device map shards the model.
    _no_split_modules = ["NSABlockRemote", "SimpleBlock"]

    def __init__(self, config: NSAConfig):
        super().__init__(config)
        self.model = NSATinyLM(config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.model.embed

    def set_input_embeddings(self, new_emb):
        """Replace the token embedding module."""
        self.model.embed = new_emb

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Compute logits and, when ``labels`` is given, the shifted next-token loss.

        Args:
            input_ids: Token ids; required.
            attention_mask: Accepted but unused.
            labels: Target ids aligned with ``input_ids``; position ``t`` of
                the logits is scored against ``labels[:, t + 1]``.
            **kwargs: Ignored.

        Returns:
            ``CausalLMOutput`` with ``logits`` and ``loss`` (``None`` without labels).

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("input_ids is required")
        logits = self.model(input_ids)
        loss = None
        if labels is not None:
            # Shift for causal LM loss
            shift_logits = logits[:, :-1, :].contiguous()
            # With a sharded model the logits may live on a different device
            # than the labels supplied by the caller.
            shift_labels = labels[:, 1:].contiguous().to(shift_logits.device)
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the model inputs for one generation step (always the full sequence)."""
        # No past_key_values cache: rerun full sequence. Works everywhere, slower at decode.
        return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
|