# nsa/core/attention_kernels.py — from the nsa-117m-byte repository
# "NSA 117M initial export" (commit 4303959, verified)
from __future__ import annotations
import os
import time
from typing import Dict, Tuple
import torch
import torch.nn.functional as F
from nsa.core.debug import log
from nsa.core.packing import (
build_cu_seqlens_for_buckets,
build_length_buckets,
compute_compressed_lengths,
compute_sliding_lengths,
)
from nsa.kernels.flash_wrappers import (
attention_bgh,
attention_fa2_dense_batch,
attention_fa2_varlen,
fa2_supported,
fa2_supported_verbose,
is_flash_varlen_available,
)
# Simple grow-on-demand workspaces for varlen packing to avoid frequent allocations.
# Keyed by (device, dtypes, head count, head dims); each value holds preallocated
# q/k/v/cu_seqlen buffers that callers slice down to the size they need.
_VARLEN_WS: Dict[Tuple, Dict[str, torch.Tensor]] = {}
# Workspace cache for the dense per-bucket path of grouped_selection_attention_packed.
_SEL_PACK_WS: Dict[Tuple, Dict[str, torch.Tensor]] = {}
def _env_int(name: str, default: int) -> int:
try:
v = int(os.getenv(name, str(default)))
return v
except Exception:
return default
def _env_int_bounded(name: str, default: int, min_val: int = 0, max_val: int = 10**8) -> int:
"""Read integer from environment with bounds checking to prevent excessive memory allocation."""
try:
v = int(os.getenv(name, str(default)))
if v < min_val:
return min_val
if v > max_val:
# Log warning if value exceeds max
import warnings
warnings.warn(f"{name}={v} exceeds maximum {max_val}, clamping to {max_val}")
return max_val
return v
except Exception:
return default
def clear_varlen_workspaces() -> None:
    """Release all cached varlen packing workspaces (optional memory cleanup)."""
    _VARLEN_WS.clear()
def clear_selection_pack_workspaces() -> None:
    """Release all cached selection-pack workspaces (optional memory cleanup)."""
    _SEL_PACK_WS.clear()
def _get_varlen_workspace(
    device: torch.device,
    dtype_q: torch.dtype,
    dtype_k: torch.dtype,
    dtype_v: torch.dtype,
    h: int,
    d_k: int,
    d_v: int,
    cap_N: int,
    cap_total_k: int,
) -> dict[str, torch.Tensor]:
    """Return a cached varlen packing workspace with at least the requested capacity.

    Workspaces are keyed by (device, dtypes, heads, head dims) and grown on
    demand; callers are expected to slice the returned buffers down to the
    sizes they actually use.
    """
    key = (str(device), dtype_q, dtype_k, dtype_v, h, d_k, d_v)
    ws = _VARLEN_WS.get(key)
    big_enough = ws is not None and (
        ws["q"].shape[0] >= cap_N
        and ws["k"].shape[0] >= cap_total_k
        and ws["v"].shape[0] >= cap_total_k
        and ws["cuq"].numel() >= (cap_N + 1)
        and ws["cuk"].numel() >= (cap_N + 1)
    )
    if not big_enough:
        # Optional env-driven pre-sizing to avoid growth reallocations on long
        # runs; bounded (1M rows / 100M packed K/V tokens) to cap memory use.
        n_rows = max(cap_N, _env_int_bounded("NSA_VARLEN_RESERVE_N", 0, 0, 10**6), 1)
        n_tok = max(cap_total_k, _env_int_bounded("NSA_VARLEN_RESERVE_K", 0, 0, 10**8), 1)
        ws = {
            "q": torch.empty((n_rows, h, d_k), dtype=dtype_q, device=device),
            "k": torch.empty((n_tok, h, d_k), dtype=dtype_k, device=device),
            "v": torch.empty((n_tok, h, d_v), dtype=dtype_v, device=device),
            "cuq": torch.empty((n_rows + 1,), dtype=torch.int32, device=device),
            "cuk": torch.empty((n_rows + 1,), dtype=torch.int32, device=device),
        }
        _VARLEN_WS[key] = ws
    return ws
def batched_causal_attention_compressed(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K_cmp: torch.Tensor,  # [B,G,S_cmp,Dk]
    V_cmp: torch.Tensor,  # [B,G,S_cmp,Dv]
    l: int,
    d: int,
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """
    Compressed branch attention with per-row causal mask derived from emission schedule.

    We cannot rely on is_causal=True because S_q != S_kv and the number of
    visible compressed tokens varies per query position t. Instead, each query
    t attends exactly to the first num_cmp(t) compressed tokens.

    Fix vs. previous revision: removed dead code that built a [S,S_cmp]
    disallowed-column mask (``col >= num_cmp.view(S, 1)``) and immediately
    discarded it, plus a redundant ``out[:, t] = 0.0`` on the already
    zero-initialized output.
    """
    B, S, G, h, Dk = Q.shape
    S_cmp = K_cmp.shape[2]
    device = Q.device
    # num_cmp(t) = 0 if t+1 < l else floor((t+1 - l) / d) + 1, clamped to S_cmp
    tpos = torch.arange(S, device=device)
    num_cmp = torch.where(tpos + 1 < l, 0, ((tpos + 1 - l) // d) + 1).clamp(max=S_cmp)
    # Parity-first: exact per-t attention using attention_bgh over the visible prefix.
    out = torch.zeros((B, S, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
    log("cmp.begin", B=B, S=S, S_cmp=int(S_cmp), l=l, d=d)
    for t in range(S):
        L = int(num_cmp[t].item())
        if L <= 0:
            # Nothing emitted yet for this row; keep the zero output.
            continue
        q_t = Q[:, t]
        k_t = K_cmp[:, :, :L, :]
        v_t = V_cmp[:, :, :L, :]
        out[:, t] = attention_bgh(q_t, k_t, v_t, causal=True)
        log("cmp.step", t=int(t), L=L)
    return out
def sliding_window_attention(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S,Dk]
    V: torch.Tensor,  # [B,G,S,Dv]
    w: int,
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """Causal sliding-window attention: query t attends to keys [t-w+1 .. t].

    An additive float mask (-inf on disallowed positions) is used instead of a
    boolean mask so all SDPA backends/dtypes produce finite outputs.
    """
    B, S, G, h, Dk = Q.shape
    Dv = V.shape[-1]
    # Degenerate cases (no window / no keys / no queries) yield all zeros.
    if w <= 0 or K.shape[2] == 0 or S == 0:
        return torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
    device = Q.device
    # Banded causal visibility: key j is visible from query t iff t-w+1 <= j <= t.
    q_pos = torch.arange(S, device=device).view(S, 1)
    k_pos = torch.arange(S, device=device).view(1, S)
    visible = (k_pos <= q_pos) & (k_pos >= (q_pos - (w - 1)))  # [S,S]
    additive = torch.full((S, S), float("-inf"), dtype=Q.dtype, device=device)
    additive.masked_fill_(visible, 0.0)
    # Fold (G,h) into a single head axis for SDPA: [B, G*h, S, D*].
    q_sdpa = Q.reshape(B, S, G * h, Dk).transpose(1, 2).contiguous()
    k_sdpa = K.unsqueeze(2).expand(B, G, h, S, Dk).reshape(B, G * h, S, Dk).contiguous()
    v_sdpa = V.unsqueeze(2).expand(B, G, h, S, Dv).reshape(B, G * h, S, Dv).contiguous()
    # Broadcast the 2D additive mask across batch and heads.
    mask = additive.view(1, 1, S, S).expand(B, G * h, S, S)
    attn_out = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask)
    return attn_out.transpose(1, 2).reshape(B, S, G, h, Dv)
def grouped_selection_attention(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S_kv,Dk]
    V: torch.Tensor,  # [B,G,S_kv,Dv]
    ranges: torch.Tensor,  # [B,S,G,n,2]
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """Reference (sequential) selection attention with exact per-row gathers.

    For each (b,t,g) the half-open ranges are expanded into a flat index list
    and a single-query SDPA is run over the gathered K/V. Rows with no
    selected keys produce zeros.

    NOTE(review): SDPA is invoked with is_causal=True on a 1-query x L-key
    problem; the packed and masked reference paths in this file use the same
    convention, so it is preserved here — confirm before changing.

    Fix vs. previous revision: removed a dead ``K.shape[2]`` expression whose
    value was discarded.
    """
    B, S, G, h, Dk = Q.shape
    # Path 1: exact sequential-equivalence gather per (b,t,g)
    out = torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=V.device)
    for b in range(B):
        for t in range(S):
            for g in range(G):
                # Build the exact gather index list from the non-empty ranges.
                idxs = []
                for i in range(ranges.shape[3]):
                    s0 = int(ranges[b, t, g, i, 0].item())
                    e0 = int(ranges[b, t, g, i, 1].item())
                    if e0 > s0:
                        idxs.append(torch.arange(s0, e0, device=V.device))
                if idxs:
                    idx = torch.cat(idxs)
                    k = K[b, g, idx]  # [L,Dk]
                    v = V[b, g, idx]  # [L,Dv]
                    q = Q[b, t, g]  # [h,Dk]
                    # Expand per-head kv and add query-length dim for SDPA
                    q_btgh = q.unsqueeze(0).unsqueeze(2)  # [1,h,1,Dk]
                    k_btgh = (
                        k.unsqueeze(0).unsqueeze(0).expand(1, q.shape[0], k.shape[0], k.shape[1])
                    )  # [1,h,L,Dk]
                    v_btgh = (
                        v.unsqueeze(0).unsqueeze(0).expand(1, q.shape[0], v.shape[0], v.shape[1])
                    )  # [1,h,L,Dv]
                    q_btgh = q_btgh.contiguous()
                    k_btgh = k_btgh.contiguous()
                    v_btgh = v_btgh.contiguous()
                    attn = F.scaled_dot_product_attention(
                        q_btgh, k_btgh, v_btgh, is_causal=True
                    )  # [1,h,1,Dv]
                    out[b, t, g] = attn.squeeze(0).squeeze(1)  # [h,Dv]
                    log("sel.step", b=int(b), t=int(t), g=int(g), L=int(k.shape[0]))
                else:
                    # No selected keys for this row; emit zeros.
                    out[b, t, g] = 0.0
                    log("sel.step", b=int(b), t=int(t), g=int(g), L=0)
    return out
def sliding_window_attention_masked(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S,Dk]
    V: torch.Tensor,  # [B,G,S,Dv]
    w: int,
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """Memory-friendly masked sliding semantics.

    Each row t attends only to the *first* position of its window, i.e. key
    index max(0, t-w+1). With a single allowed key per row, attention reduces
    to returning that key's V directly (broadcast across heads).
    """
    B, S, G, h, Dk = Q.shape
    Dv = V.shape[-1]
    if w <= 0 or K.shape[2] == 0:
        return torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
    dev = Q.device
    # Per-row window start: clamp(t - (w-1), 0).
    window_start = (torch.arange(S, device=dev) - (w - 1)).clamp_min(0)  # [S]
    # Gather V at the window start for every (b, g, t).
    gather_idx = window_start.view(1, 1, S, 1).expand(B, G, S, 1)  # [B,G,S,1]
    picked = torch.gather(V, 2, gather_idx.expand(B, G, S, Dv))  # [B,G,S,Dv]
    # [B,G,S,Dv] -> [B,S,G,1,Dv] -> broadcast over heads.
    return picked.permute(0, 2, 1, 3).unsqueeze(3).expand(B, S, G, h, Dv)
def batched_causal_attention_compressed_masked(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K_cmp: torch.Tensor,  # [B,G,S_cmp,Dk]
    V_cmp: torch.Tensor,  # [B,G,S_cmp,Dv]
    l: int,
    d: int,
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """Memory-friendly masked compressed semantics.

    Any row whose emission schedule has produced at least one compressed token
    attends solely to token 0 (i.e. returns V_cmp[:, :, 0]); rows with none
    yet emit zeros.
    """
    B, S, G, h, Dk = Q.shape
    S_cmp = K_cmp.shape[2]
    Dv = V_cmp.shape[-1]
    if S_cmp == 0:
        return torch.zeros((B, S, G, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device)
    t_idx = torch.arange(S, device=Q.device)
    # Emission schedule: num_cmp(t) = 0 until t+1 >= l, then 1 + (t+1-l)//d, capped at S_cmp.
    emitted = torch.where(t_idx + 1 < l, 0, ((t_idx + 1 - l) // d) + 1).clamp(min=0, max=S_cmp)
    has_tokens = (emitted > 0).view(1, S, 1, 1, 1).expand(B, S, G, h, 1)
    # First compressed value broadcast over (S, h).
    first_v = V_cmp[:, :, 0, :].unsqueeze(1).unsqueeze(3).expand(B, S, G, h, Dv)
    return torch.where(has_tokens, first_v, torch.zeros_like(first_v))
def grouped_selection_attention_packed(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S_kv,Dk]
    V: torch.Tensor,  # [B,G,S_kv,Dv]
    ranges: torch.Tensor,  # [B,S,G,n,2]
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """
    Bucketed varlen packing by row length L with parity to the gather path.

    For each (b,t,g), build its flat index list from ranges, bucket rows by
    identical L, and run one SDPA per bucket. During training (or when
    NSA_TRAIN_SAFE_PACK is set) an autograd-safe stack-based packing is used;
    otherwise a reusable workspace minimizes allocations.

    Fixes vs. previous revision:
    - Removed a dead ``K.shape[2]`` expression.
    - Workspace path: freshly allocated buffers are now sliced down to the
      bucket's actual (N, L) before use. Previously the full buffers (sized by
      NSA_SEL_PACK_RESERVE_N/L) were used directly, which broke the
      ``expand(N, h, L, ...)`` below and would have attended over
      uninitialized rows whenever the reserves exceeded the bucket size.
    """
    B, S, G, h, Dk = Q.shape
    device = Q.device
    out = torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=device)
    # Flatten (b,t,g) rows and materialize each row's gather indices.
    rows = []  # list of (b, t, g, idx_tensor[L])
    lengths = []
    for b in range(B):
        for t in range(S):
            for g in range(G):
                idxs = []
                for i in range(ranges.shape[3]):
                    s0 = int(ranges[b, t, g, i, 0].item())
                    e0 = int(ranges[b, t, g, i, 1].item())
                    if e0 > s0:
                        idxs.append(torch.arange(s0, e0, device=device))
                if idxs:
                    idx = torch.cat(idxs)
                else:
                    idx = torch.empty((0,), dtype=torch.long, device=device)
                rows.append((b, t, g, idx))
                lengths.append(idx.numel())
    if not rows:
        return out
    lengths_t = torch.tensor(lengths, device=device)
    unique_L = torch.unique(lengths_t)
    # Enable autograd-safe packing during training or when forced by env.
    use_safe_pack = (
        torch.is_grad_enabled() and (Q.requires_grad or K.requires_grad or V.requires_grad)
    ) or _env_bool("NSA_TRAIN_SAFE_PACK", False)
    for Lval in unique_L.tolist():
        L = int(Lval)
        # Collect row indices belonging to this length bucket.
        bucket_idx = [i for i, Lx in enumerate(lengths) if Lx == L]
        if L == 0 or len(bucket_idx) == 0:
            # Rows with L == 0 keep their zero outputs.
            continue
        N = len(bucket_idx)
        if use_safe_pack:
            # Graph-friendly packing using stack to preserve autograd links.
            map_rows = []
            Q_list = []
            K_list = []
            V_list = []
            for ridx in bucket_idx:
                b, t, g, idx = rows[ridx]
                map_rows.append((b, t, g))
                Q_list.append(Q[b, t, g])  # [h,Dk]
                K_list.append(K[b, g, idx])  # [L,Dk]
                V_list.append(V[b, g, idx])  # [L,Dv]
            Qb = torch.stack(Q_list, dim=0)  # [N,h,Dk]
            Kb = torch.stack(K_list, dim=0)  # [N,L,Dk]
            Vb = torch.stack(V_list, dim=0)  # [N,L,Dv]
            q_btgh = Qb.unsqueeze(1).permute(0, 2, 1, 3)  # [N,h,1,Dk]
            k_btgh = Kb.unsqueeze(1).expand(N, h, L, Dk)
            v_btgh = Vb.unsqueeze(1).expand(N, h, L, V.shape[-1])
            attn = F.scaled_dot_product_attention(q_btgh, k_btgh, v_btgh, is_causal=True)
            Ob = attn.squeeze(2)  # [N,h,Dv]
            for j, (b, t, g) in enumerate(map_rows):
                out[b, t, g] = Ob[j]
        else:
            # Workspace-backed Q, K, V batches to reduce allocations.
            ws_key = (str(device), Q.dtype, K.dtype, V.dtype, h, Dk, V.shape[-1])
            ws = _SEL_PACK_WS.get(ws_key)
            need_new = (
                ws is None or ws["Q"].shape[0] < N or ws["K"].shape[1] < L or ws["V"].shape[1] < L
            )
            if need_new:
                # Env-driven pre-sizing, bounded (max 100K rows, 10K length).
                reserve_N = _env_int_bounded("NSA_SEL_PACK_RESERVE_N", 0, 0, 10**5)
                reserve_L = _env_int_bounded("NSA_SEL_PACK_RESERVE_L", 0, 0, 10**4)
                new_N = max(N, reserve_N)
                new_L = max(L, reserve_L)
                Q_full = torch.empty((new_N, h, Dk), dtype=Q.dtype, device=device)
                K_full = torch.empty((new_N, new_L, Dk), dtype=K.dtype, device=device)
                V_full = torch.empty((new_N, new_L, V.shape[-1]), dtype=V.dtype, device=device)
                _SEL_PACK_WS[ws_key] = {"Q": Q_full, "K": K_full, "V": V_full}
                # BUGFIX: slice the (possibly reserve-sized) buffers down to the
                # bucket's actual (N, L), mirroring the reuse branch below.
                Qb = Q_full[:N]
                Kb = K_full[:N, :L]
                Vb = V_full[:N, :L]
            else:
                Qb = ws["Q"][:N]
                Kb = ws["K"][:N, :L]
                Vb = ws["V"][:N, :L]
            # Populate workspace buffers and perform SDPA (both new and reused workspaces).
            map_rows = []
            for j, ridx in enumerate(bucket_idx):
                b, t, g, idx = rows[ridx]
                Qb[j] = Q[b, t, g]  # [h,Dk]
                Kb[j] = K[b, g, idx]  # [L,Dk]
                Vb[j] = V[b, g, idx]  # [L,Dv]
                map_rows.append((b, t, g))
            # SDPA per bucket: expand per-head.
            q_btgh = Qb.unsqueeze(1)  # [N,1,h,Dk]
            q_btgh = q_btgh.permute(0, 2, 1, 3)  # [N,h,1,Dk]
            k_btgh = Kb.unsqueeze(1).expand(N, h, L, Dk)
            v_btgh = Vb.unsqueeze(1).expand(N, h, L, V.shape[-1])
            attn = F.scaled_dot_product_attention(
                q_btgh, k_btgh, v_btgh, is_causal=True
            )  # [N,h,1,Dv]
            Ob = attn.squeeze(2)  # [N,h,Dv]
            # Scatter back to per-(b,t,g) outputs.
            for j, (b, t, g) in enumerate(map_rows):
                out[b, t, g] = Ob[j]
    return out
def selection_attention_varlen_all(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S_kv,Dk]
    V: torch.Tensor,  # [B,G,S_kv,Dv]
    ranges: torch.Tensor,  # [B,S,G,n,2]
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """
    Fully batched selection attention using varlen packing across all (B,S,G) rows.
    If NSA_SEL_VARLEN_V2 is enabled (default), dispatches to the vectorized v2
    packer. Otherwise uses the legacy v1 path (minimal loops with workspace).

    NOTE(review): past the NSA_SEL_VARLEN_FORCE_PARITY early return below,
    ``_parity`` is always False, so every ``causal=_parity`` in this function
    is effectively causal=False — confirm before simplifying.
    """
    # Optional v2 vectorized packer
    if os.getenv("NSA_SEL_VARLEN_V2", "1").lower() in ("1", "true", "yes", "on"):
        return selection_attention_varlen_all_v2(Q, K, V, ranges)
    B, S, G, h, Dk = Q.shape
    # Parity override: when enabled, force causal=True to match packed reference
    _parity = os.getenv("NSA_SEL_VARLEN_FORCE_PARITY", "0").lower() in ("1", "true", "yes", "on")
    if _parity:
        # Force exact parity by delegating to the packed reference
        return grouped_selection_attention_packed(Q, K, V, ranges)
    device = Q.device
    Dv = V.shape[-1]
    out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
    # Build row list and lengths from ranges (sum of segment lengths)
    rows: list[tuple[int, int, int]] = []
    lens: list[int] = []
    for b in range(B):
        for t in range(S):
            for g in range(G):
                L = 0
                for i in range(ranges.shape[3]):
                    s0 = int(ranges[b, t, g, i, 0].item())
                    e0 = int(ranges[b, t, g, i, 1].item())
                    if e0 > s0:
                        L += e0 - s0
                # Rows with zero selected keys are skipped and stay zero in `out`.
                if L > 0:
                    rows.append((b, t, g))
                    lens.append(L)
    N = len(rows)
    if N == 0:
        return out
    total_k = int(sum(lens))
    # Workspace-backed packing (buffers are sliced to the needed sizes)
    ws = _get_varlen_workspace(
        device,
        dtype_q=Q.dtype,
        dtype_k=K.dtype,
        dtype_v=V.dtype,
        h=h,
        d_k=Dk,
        d_v=Dv,
        cap_N=N,
        cap_total_k=total_k,
    )
    q_pack = ws["q"][:N]
    k_pack = ws["k"][:total_k]
    v_pack = ws["v"][:total_k]
    cuq = ws["cuq"][: N + 1]
    cuk = ws["cuk"][: N + 1]
    # Fill cu_seqlens (cumulative row offsets for varlen attention)
    cuq.zero_()
    cuk.zero_()
    # Pack per row: copy each row's query and its selected K/V segments contiguously
    write_pos = 0
    for i, (b, t, g) in enumerate(rows):
        # q for row
        q_pack[i] = Q[b, t, g]
        # iterate segments for this row
        for j in range(ranges.shape[3]):
            s0 = int(ranges[b, t, g, j, 0].item())
            e0 = int(ranges[b, t, g, j, 1].item())
            if e0 <= s0:
                continue
            seg_k = K[b, g, s0:e0]  # [Lseg,Dk]
            seg_v = V[b, g, s0:e0]  # [Lseg,Dv]
            Lseg = e0 - s0
            # Assign using explicit expand_as to match target slice shape and avoid view pitfalls
            _kslice = k_pack[write_pos : write_pos + Lseg]
            _vslice = v_pack[write_pos : write_pos + Lseg]
            _kslice.copy_(seg_k[:, None, :].expand_as(_kslice))
            _vslice.copy_(seg_v[:, None, :].expand_as(_vslice))
            write_pos += Lseg
        # Each packed row carries exactly one query; K offsets grow by the row length.
        cuq[i + 1] = cuq[i] + 1
        cuk[i + 1] = cuk[i] + lens[i]
    # Try FA‑2 varlen if available and supported. Default non-causal semantics;
    # optionally force parity with packed path via NSA_SEL_VARLEN_FORCE_PARITY.
    ok, _ = fa2_supported_verbose(device, Q.dtype, Dk)
    if ok and is_flash_varlen_available():
        try:
            o_pack = attention_fa2_varlen(
                q_pack,
                k_pack,
                v_pack,
                cuq,
                cuk,
                max_seqlen_q=1,
                max_seqlen_k=max(lens),
                causal=_parity,
            )  # [N,h,Dv]
            # Scatter back
            for i, (b, t, g) in enumerate(rows):
                out[b, t, g] = o_pack[i]
            return out
        except Exception:
            # Best-effort: fall through to the dense bucketed fallback below.
            pass
    # Dense batch per fixed L bucket as fallback
    buckets: dict[int, list[int]] = {}
    for i, L in enumerate(lens):
        buckets.setdefault(L, []).append(i)
    for L, idxs in buckets.items():
        if L <= 0 or len(idxs) == 0:
            continue
        Nb = len(idxs)
        Qb = torch.empty((Nb, h, Dk), dtype=Q.dtype, device=device)
        Kb = torch.empty((Nb, L, Dk), dtype=K.dtype, device=device)
        Vb = torch.empty((Nb, L, Dv), dtype=V.dtype, device=device)
        tgt: list[tuple[int, int, int]] = []
        for j, irow in enumerate(idxs):
            b, t, g = rows[irow]
            Qb[j] = Q[b, t, g]
            # Rebuild fixed-length K/V for this row from ranges
            write = 0
            for rj in range(ranges.shape[3]):
                s0 = int(ranges[b, t, g, rj, 0].item())
                e0 = int(ranges[b, t, g, rj, 1].item())
                if e0 <= s0:
                    continue
                Lseg = e0 - s0
                Kb[j, write : write + Lseg] = K[b, g, s0:e0]
                Vb[j, write : write + Lseg] = V[b, g, s0:e0]
                write += Lseg
            tgt.append((b, t, g))
        # Batched dense fallback for this bucket. Default non-causal; optionally force parity.
        try:
            q_rows = Qb.unsqueeze(1)  # [Nb,1,h,Dk]
            k_rows = Kb.unsqueeze(2).expand(Nb, L, h, Dk)  # [Nb,L,h,Dk]
            v_rows = Vb.unsqueeze(2).expand(Nb, L, h, Dv)  # [Nb,L,h,Dv]
            Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=_parity).squeeze(
                1
            )  # [Nb,h,Dv]
            for i, (b, t, g) in enumerate(tgt):
                out[b, t, g] = Ob[i]
        except Exception:
            # Final fallback: per-row SDPA
            for j, (b, t, g) in enumerate(tgt):
                q_btgh = Qb[j].unsqueeze(0).unsqueeze(0)  # [1,1,h,Dk]
                k_btgh = Kb[j].unsqueeze(0).unsqueeze(0)  # [1,1,L,Dk]
                v_btgh = Vb[j].unsqueeze(0).unsqueeze(0)  # [1,1,L,Dv]
                out[b, t, g] = attention_bgh(q_btgh, k_btgh, v_btgh, causal=_parity)[0, 0]
    return out
def selection_attention_varlen_all_v2(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S_kv,Dk]
    V: torch.Tensor,  # [B,G,S_kv,Dv]
    ranges: torch.Tensor,  # [B,S,G,n,2]
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """
    Vectorized v2 varlen selection packer with FA‑2 varlen fast path and dense fallback.
    - Eliminates Python loops for packing by using a difference-array mask to build per-row
      allowed indices and flat-select K/V tokens.
    - Uses causal=False for single‑query rows.
    - Env: NSA_SEL_VARLEN_MIN_L to bypass on tiny rows (falls back to packed path).

    NOTE(review): ``ranges[..., 0].to(torch.int64)`` returns the same view when
    `ranges` is already int64, so the in-place ``clamp_`` below may write
    through to the caller's tensor — confirm whether that mutation is intended.
    """
    B, S, G, h, Dk = Q.shape
    # Parity override: when enabled, force causal=True to match packed reference
    _parity = os.getenv("NSA_SEL_VARLEN_FORCE_PARITY", "0").lower() in ("1", "true", "yes", "on")
    if _parity:
        # Force exact parity by delegating to the packed reference
        return grouped_selection_attention_packed(Q, K, V, ranges)
    device = Q.device
    Dv = V.shape[-1]
    S_kv = K.shape[2]
    out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
    if S_kv == 0:
        return out
    # Build allowed mask [B,S,G,S_kv] from the ranges via a difference array:
    # +1 at each segment start, -1 at each end, cumsum > 0 marks allowed keys.
    n = ranges.shape[3]
    starts = ranges[..., 0].to(torch.int64).clamp_(0, S_kv)
    ends = ranges[..., 1].to(torch.int64).clamp_(0, S_kv)
    BSG = B * S * G
    starts_f = starts.reshape(BSG, n)
    ends_f = ends.reshape(BSG, n)
    diff = torch.zeros((BSG, S_kv + 1), dtype=torch.int32, device=device)
    one = torch.ones_like(starts_f, dtype=diff.dtype, device=device)
    diff.scatter_add_(1, starts_f, one)
    diff.scatter_add_(1, ends_f, -one)
    allowed = diff[:, :-1].cumsum(dim=1).gt(0)  # [BSG,S_kv]
    lens_flat = allowed.sum(dim=1, dtype=torch.int32)  # [BSG]
    row_mask = lens_flat.gt(0)
    if not torch.any(row_mask):
        return out
    try:
        min_L = int(os.getenv("NSA_SEL_VARLEN_MIN_L", "0"))
    except Exception:
        min_L = 0
    # Tiny-row bypass: if no row reaches min_L, the packed path is cheaper.
    if min_L > 0 and int(lens_flat.max().item()) < min_L:
        return grouped_selection_attention_packed(Q, K, V, ranges)
    idx_rows = torch.nonzero(row_mask, as_tuple=False).squeeze(1)  # [N]
    N = int(idx_rows.numel())
    # Recover (b,t,g) coordinates for each kept flat row index (for scatter-back)
    b_idx = idx_rows // (S * G)
    rem = idx_rows % (S * G)
    t_idx = rem // G
    g_idx = rem % G
    # Pack Q rows
    Q_rows = Q.reshape(B * S * G, h, Dk)[idx_rows]
    # Map rows to their (b,g) pair to select the matching K/V slab
    bg_map = (
        torch.arange(B, device=device).view(B, 1, 1) * G
        + torch.arange(G, device=device).view(1, 1, G)
    ).expand(B, S, G)
    bg_rows = bg_map.reshape(B * S * G)[idx_rows]
    K_bg = K.reshape(B * G, S_kv, Dk)[bg_rows]
    V_bg = V.reshape(B * G, S_kv, Dv)[bg_rows]
    allowed_rows = allowed[idx_rows]
    total_k = int(lens_flat[row_mask].sum().item())
    # Boolean-mask selection flattens in row-major order, matching cuk below.
    sel_k = K_bg[allowed_rows]  # [total_k, Dk]
    sel_v = V_bg[allowed_rows]  # [total_k, Dv]
    lens_sel = lens_flat[row_mask]  # [N]
    # Workspace-backed packing
    ws = _get_varlen_workspace(
        device,
        dtype_q=Q.dtype,
        dtype_k=K.dtype,
        dtype_v=V.dtype,
        h=h,
        d_k=Dk,
        d_v=Dv,
        cap_N=N,
        cap_total_k=total_k,
    )
    q_pack = ws["q"][:N]
    k_pack = ws["k"][:total_k]
    v_pack = ws["v"][:total_k]
    cuq = ws["cuq"][: N + 1]
    cuk = ws["cuk"][: N + 1]
    q_pack.copy_(Q_rows)
    # K/V are shared across heads: broadcast each selected token over h.
    k_pack.copy_(sel_k.unsqueeze(1).expand(total_k, h, Dk))
    v_pack.copy_(sel_v.unsqueeze(1).expand(total_k, h, Dv))
    cuq.copy_(torch.arange(0, N + 1, device=device, dtype=torch.int32))
    cuk[0] = 0
    torch.cumsum(lens_sel.to(torch.int32), dim=0, out=cuk[1:])
    # FA‑2 varlen (non-causal)
    ok, _why = fa2_supported_verbose(device, Q.dtype, Dk)
    max_len = int(lens_sel.max().item())
    if ok and is_flash_varlen_available():
        try:
            o_pack = attention_fa2_varlen(
                q_pack,
                k_pack,
                v_pack,
                cuq,
                cuk,
                max_seqlen_q=1,
                max_seqlen_k=max_len,
                causal=_parity,
            )
            out[b_idx, t_idx, g_idx] = o_pack
            return out
        except Exception:
            # Fall through to the masked-SDPA fallback below.
            pass
    # Correctness-first fallback: masked SDPA over an allowed key mask
    # This path matches the non-causal packed reference exactly and avoids
    # potential packing/indexing pitfalls in dense-bucket fallbacks.
    try:
        return grouped_selection_attention_masked(Q, K, V, ranges)
    except Exception:
        pass
    # Legacy dense fallback by length buckets (kept as a final fallback)
    starts = cuk[:-1].to(torch.int64)
    ends = cuk[1:].to(torch.int64)
    Ls = (ends - starts).to(torch.int64)
    for L in torch.unique(Ls).tolist():
        if L <= 0:
            continue
        sel = (Ls == L).nonzero(as_tuple=False).squeeze(1)
        if sel.numel() == 0:
            continue
        Nb = int(sel.numel())
        Qb = q_pack[sel]
        k_rows = torch.empty((Nb, L, h, Dk), dtype=K.dtype, device=device)
        v_rows = torch.empty((Nb, L, h, Dv), dtype=V.dtype, device=device)
        for j in range(Nb):
            s0 = int(starts[sel[j]].item())
            e0 = int(ends[sel[j]].item())
            k_rows[j] = k_pack[s0:e0]
            v_rows[j] = v_pack[s0:e0]
        try:
            Ob = attention_fa2_dense_batch(Qb.unsqueeze(1), k_rows, v_rows, causal=_parity).squeeze(1)
        except Exception:
            Ob = torch.empty((Nb, h, Dv), dtype=V.dtype, device=device)
            for j in range(Nb):
                Ob[j] = attention_bgh(Qb[j].unsqueeze(0), k_rows[j].unsqueeze(0), v_rows[j].unsqueeze(0), causal=_parity)[
                    0
                ]
        out[b_idx[sel], t_idx[sel], g_idx[sel]] = Ob
    return out
def grouped_selection_attention_masked(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S_kv,Dk]
    V: torch.Tensor,  # [B,G,S_kv,Dv]
    ranges: torch.Tensor,  # [B,S,G,n,2]
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """
    Fully batched selection attention using an additive -inf mask.
    Vectorized ranges→mask construction via prefix-sum trick (no Python loops).

    Fix vs. previous revision: the start/end clamping is now out-of-place.
    ``ranges[..., 0].to(torch.int64)`` returns the *same* view when `ranges`
    is already int64, so the previous in-place ``clamp_`` silently mutated the
    caller's `ranges` tensor whenever it held out-of-bounds values.
    """
    B, S, G, h, Dk = Q.shape
    S_kv = K.shape[2]
    device = Q.device
    if S_kv == 0:
        return torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=device)
    # Vectorized allowed mask [B,S,G,S_kv] from ranges using a difference array:
    # +1 at each segment start, -1 at each end; cumsum > 0 marks allowed keys.
    n = ranges.shape[3]
    starts = ranges[..., 0].to(torch.int64).clamp(0, S_kv)  # [B,S,G,n]
    ends = ranges[..., 1].to(torch.int64).clamp(0, S_kv)  # [B,S,G,n]
    BSG = B * S * G
    starts_f = starts.reshape(BSG, n)
    ends_f = ends.reshape(BSG, n)
    diff = torch.zeros((BSG, S_kv + 1), dtype=torch.int32, device=device)
    one = torch.ones_like(starts_f, dtype=diff.dtype, device=device)
    diff.scatter_add_(1, starts_f, one)
    diff.scatter_add_(1, ends_f, -one)
    allowed = diff[:, :-1].cumsum(dim=1).gt(0).reshape(B, S, G, S_kv)
    # Detect rows with no allowed keys (all False along key dimension)
    row_has_any = allowed.any(dim=-1)  # [B,S,G]
    row_empty = ~row_has_any
    # Prevent SDPA from seeing an all-−inf row which can produce NaNs.
    # For originally empty rows, force a single safe key (index 0) to True,
    # run SDPA, then zero their outputs afterward to preserve semantics.
    if row_empty.any():
        allowed_safe = allowed.clone()
        flat = allowed_safe.view(B * S * G, S_kv)
        row_empty_flat = row_empty.reshape(B * S * G)
        if S_kv > 0:
            flat[row_empty_flat, 0] = True
        allowed_safe = flat.view_as(allowed_safe)
    else:
        allowed_safe = allowed
    # Prepare SDPA tensors: [B,G*h,S, D*] and mask [B,G*h,S,S_kv]
    Qf = Q.reshape(B, S, G * h, Dk).transpose(1, 2).contiguous()  # [B,G*h,S,Dk]
    Kf = K.unsqueeze(2).expand(-1, -1, h, -1, -1).reshape(B, G * h, S_kv, Dk).contiguous()
    Vf = V.unsqueeze(2).expand(-1, -1, h, -1, -1).reshape(B, G * h, S_kv, V.shape[-1]).contiguous()
    # Build additive mask in float32 for numerical stability with -inf
    zeros = torch.zeros((B, G * h, S, S_kv), dtype=torch.float32, device=device)
    neg_inf = torch.full((B, G * h, S, S_kv), float("-inf"), dtype=torch.float32, device=device)
    Mf = torch.where(
        allowed_safe.transpose(1, 2)  # [B,G,S,S_kv]
        .unsqueeze(2)
        .expand(-1, -1, h, -1, -1)
        .reshape(B, G * h, S, S_kv),
        zeros,
        neg_inf,
    ).contiguous()
    Of = F.scaled_dot_product_attention(Qf, Kf, Vf, attn_mask=Mf)  # [B,G*h,S,Dv]
    Of = Of.transpose(1, 2).reshape(B, S, G, h, V.shape[-1])
    # Zero outputs for originally empty rows to preserve semantics
    if row_empty.any():
        Of = torch.where(row_has_any.unsqueeze(-1).unsqueeze(-1), Of, torch.zeros_like(Of))
    return Of
# ===== FA-2 integration scaffolding (M1) =====
def _env_bool(name: str, default: bool = False) -> bool:
v = os.getenv(name, "1" if default else "0").lower()
return v in ("1", "true", "yes", "on")
def _is_sm89(device: torch.device) -> bool:
"""Return True if running on CUDA device with SM 8.9 (Ada/RTX 4090)."""
if device.type != "cuda":
return False
try:
cap = torch.cuda.get_device_capability(device)
return cap == (8, 9)
except Exception:
return False
def _fa2_forced() -> bool:
    """True when FA-2 usage is explicitly forced via the NSA_FA2_FORCE env flag."""
    return _env_bool("NSA_FA2_FORCE", False)
def sliding_window_attention_fa2(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S,Dk]
    V: torch.Tensor,  # [B,G,S,Dv]
    w: int,
    min_len_for_fa2: int = 16,
) -> torch.Tensor:
    """
    Planned FA-2 path for sliding with safe fallbacks.
    Currently falls back to masked SDPA to preserve numerics until FA-2 is wired.

    Gating order: SM 8.9 guard → sliding-semantics policy → threshold/length
    checks → capability check; each gate falls back to sliding_window_attention.
    NOTE(review): the final exception handler at the bottom falls back to
    sliding_window_attention_masked, whose semantics differ from the
    sliding_window_attention used by every earlier gate — confirm intended.
    """
    B, S, G, h, Dk = Q.shape
    device = Q.device
    # Policy: sliding FA-2 is disabled by default due to API semantics
    # limitation (causal mask assumes start at 0). Allow only if explicitly
    # enabled via NSA_ALLOW_SLIDING_FA2 or forced flags.
    allow_sliding_fa2 = _env_bool("NSA_ALLOW_SLIDING_FA2", False)
    # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced
    if _is_sm89(device) and not _fa2_forced():
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="win", reason="sm89_guard", forced=bool(_fa2_forced()))
        return sliding_window_attention(Q, K, V, w)
    # Policy guard
    if not allow_sliding_fa2 and not (
        _env_bool("NSA_FA2_FORCE_VARLEN", False) or _env_bool("NSA_FA2_FORCE_DENSE", False)
    ):
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="win", reason="unsupported_sliding_semantics", forced=False)
        return sliding_window_attention(Q, K, V, w)
    # Compute effective per-row window lengths and buckets
    lengths = compute_sliding_lengths(S, w, device)
    max_len = int(lengths.max().item()) if lengths.numel() > 0 else 0
    # Allow override via env
    try:
        min_len_for_fa2 = int(os.getenv("NSA_FA2_MIN_LEN_WIN", str(min_len_for_fa2)))
    except Exception:
        pass
    # Disable sentinel: non-positive threshold disables FA‑2 entirely for this branch
    if min_len_for_fa2 <= 0:
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="win", reason="disabled_threshold")
        return sliding_window_attention(Q, K, V, w)
    buckets = build_length_buckets(lengths)
    if buckets:
        log("fa2.win.buckets", n=len(buckets), max_len=max_len)
    # Build cu_seqlens per bucket (for future FA-2 varlen call)
    for idx in buckets:
        blens = lengths[idx]
        _ = build_cu_seqlens_for_buckets(blens)
    # Small-length auto-switch to masked SDPA
    if max_len < min_len_for_fa2:
        if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
            log(
                "fa2.gate_skip",
                branch="win",
                reason="below_min_len",
                max_len=int(max_len),
                min_len=int(min_len_for_fa2),
            )
        return sliding_window_attention(Q, K, V, w)
    # Capability check
    ok, why = fa2_supported_verbose(device, Q.dtype, Dk)
    if not ok or not is_flash_varlen_available():
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="win", reason=why, has_varlen=is_flash_varlen_available())
        return sliding_window_attention(Q, K, V, w)
    # Attempt FA-2 across all rows using varlen first, then dense per-bucket. Fallback to masked SDPA on error.
    try:
        B, S, G, h, Dk = Q.shape
        Dv = V.shape[-1]
        use_timing = os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes")
        force_varlen = _env_bool("NSA_FA2_FORCE_VARLEN", False)
        force_dense = _env_bool("NSA_FA2_FORCE_DENSE", False)
        force_win_dense = _env_bool("NSA_WIN_FORCE_DENSE", False)
        # Log histogram of lengths
        if buckets:
            uniq, counts = torch.unique(lengths, return_counts=True)
            log("fa2.win.hist", uniq=uniq.tolist(), counts=counts.tolist())
        # Try a single varlen call across all rows
        if (is_flash_varlen_available() and not (force_dense or force_win_dense)) or force_varlen:
            # Enumerate every (b,t,g) row; each row's key count comes from its window length.
            rows = []
            len_rows = []
            for t in range(S):
                L = int(lengths[t].item())
                for b in range(B):
                    for g in range(G):
                        rows.append((b, t, g))
                        len_rows.append(L)
            N = len(rows)
            if N > 0 and max_len >= 1:
                use_safe_pack = (
                    torch.is_grad_enabled()
                    and (Q.requires_grad or K.requires_grad or V.requires_grad)
                ) or _env_bool("NSA_TRAIN_SAFE_PACK", False)
                if use_safe_pack:
                    # Autograd-safe packing via stack/cat to preserve graph links
                    q_pack = torch.stack([Q[b, t, g] for (b, t, g) in rows], dim=0)  # [N,h,Dk]
                    k_rows = []
                    v_rows = []
                    for i, (b, t, g) in enumerate(rows):
                        L = len_rows[i]
                        if L > 0:
                            # Window for query t spans [max(0, t+1-w), t+1).
                            start = max(0, (t + 1) - w)
                            end = t + 1
                            seg_k = K[b, g, start:end].unsqueeze(1).expand(-1, h, -1)  # [L,h,Dk]
                            seg_v = V[b, g, start:end].unsqueeze(1).expand(-1, h, -1)  # [L,h,Dv]
                            k_rows.append(seg_k)
                            v_rows.append(seg_v)
                    total_k = int(sum(len_rows))
                    if total_k > 0:
                        k_pack = torch.cat(k_rows, dim=0)
                        v_pack = torch.cat(v_rows, dim=0)
                    else:
                        k_pack = torch.zeros((0, h, Dk), dtype=K.dtype, device=K.device)
                        v_pack = torch.zeros((0, h, Dv), dtype=V.dtype, device=V.device)
                    cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
                    lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
                    cuk = torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0)
                else:
                    total_k = int(sum(len_rows))
                    ws = _get_varlen_workspace(
                        Q.device, Q.dtype, K.dtype, V.dtype, h, Dk, Dv, N, total_k
                    )
                    q_pack = ws["q"][:N]
                    k_pack = ws["k"][:total_k]
                    v_pack = ws["v"][:total_k]
                    # Build cumulative sequence lengths for Q and K
                    cuq = ws["cuq"][: N + 1]
                    cuq.copy_(torch.arange(0, N + 1, device=Q.device, dtype=torch.int32))
                    lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
                    cuk = ws["cuk"][: N + 1]
                    torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0, out=cuk)
                    # Fill packs
                    write_pos = 0
                    for i, (b, t, g) in enumerate(rows):
                        L = len_rows[i]
                        q_pack[i] = Q[b, t, g]
                        if L > 0:
                            start = max(0, (t + 1) - w)
                            end = t + 1
                            seg_k = K[b, g, start:end]  # [L,Dk]
                            seg_v = V[b, g, start:end]  # [L,Dv]
                            assert (write_pos + L) <= total_k, "varlen K/V pack overflow"
                            k_pack[write_pos : write_pos + L] = seg_k.unsqueeze(1).expand(L, h, Dk)
                            v_pack[write_pos : write_pos + L] = seg_v.unsqueeze(1).expand(L, h, Dv)
                            write_pos += L
                # Optional integrity checks (debug only)
                if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
                    try:
                        assert cuq.numel() == (N + 1), "cuq length mismatch"
                        assert cuk.numel() == (N + 1), "cuk length mismatch"
                        assert int(cuk[-1].item()) == int(total_k), "cuk total_k mismatch"
                        if total_k > 0 and N > 0:
                            # Spot-check a few packed rows against a direct re-gather.
                            probe = [0, N // 2, N - 1] if N >= 3 else [0]
                            for i in probe:
                                L_i = int(len_rows[i])
                                b_i, t_i, g_i = rows[i]
                                s_i = int(max(0, (t_i + 1) - w))
                                e_i = int(t_i + 1)
                                if L_i > 0:
                                    ks = k_pack[cuk[i] : cuk[i + 1]]  # [L,h,Dk]
                                    kv = K[b_i, g_i, s_i:e_i].unsqueeze(1).expand(-1, h, -1)
                                    if ks.shape != kv.shape:
                                        log(
                                            "warn.fa2_win_pack_shape",
                                            row=i,
                                            ks=ks.shape,
                                            kv=kv.shape,
                                        )
                                    else:
                                        md = float((ks - kv).abs().max().item())
                                        if md > 1e-3:
                                            log(
                                                "warn.fa2_win_pack_mismatch",
                                                row=i,
                                                L=L_i,
                                                max_diff=md,
                                            )
                    except Exception:
                        pass
                if use_timing:
                    t0 = time.perf_counter()
                o_pack = attention_fa2_varlen(
                    q_pack,
                    k_pack,
                    v_pack,
                    cuq,
                    cuk,
                    max_seqlen_q=1,
                    max_seqlen_k=max_len,
                    causal=False,
                )  # [N,h,Dv]
                # Non-finite output means the FA-2 call misbehaved; fall back.
                if not torch.isfinite(o_pack).all():
                    log("warn.fa2_win_varlen_nonfinite")
                    return sliding_window_attention(Q, K, V, w)
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.win.varlen_all", N=int(N), total_k=int(total_k), ms=dt)
                # Scatter back
                out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
                for i, (b, t, g) in enumerate(rows):
                    out[b, t, g] = o_pack[i]
                return out
        # Dense per-bucket path (also reached when N == 0 or max_len < 1 above).
        out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
        for idx in buckets:
            if idx.numel() == 0:
                continue
            L = int(lengths[idx[0]].item())
            # Collect rows for this bucket
            rows_q = []  # [N,h,Dk]
            rows_k = []  # [N,L,Dk]
            rows_v = []  # [N,L,Dv]
            tgt = []
            for t in idx.tolist():
                start = max(0, (t + 1) - w)
                end = t + 1
                for b in range(B):
                    for g in range(G):
                        rows_q.append(Q[b, t, g])
                        rows_k.append(K[b, g, start:end])
                        rows_v.append(V[b, g, start:end])
                        tgt.append((b, t, g))
            if not rows_q:
                continue
            N = len(rows_q)
            Qb = torch.stack(rows_q, dim=0)  # [N,h,Dk]
            Kb = torch.stack(rows_k, dim=0)  # [N,L,Dk]
            Vb = torch.stack(rows_v, dim=0)  # [N,L,Dv]
            if is_flash_varlen_available() and not (force_dense or force_win_dense):
                # Pack varlen (constant L here, but use API for generality)
                q_pack = Qb  # [N,h,Dk]
                k_pack = Kb.reshape(N * L, Dk).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dk)
                v_pack = Vb.reshape(N * L, Dv).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dv)
                cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
                cuk = torch.arange(0, (N + 1) * L, step=L, device=Q.device, dtype=torch.int32)
                if use_timing:
                    t0 = time.perf_counter()
                o_pack = attention_fa2_varlen(
                    q_pack,
                    k_pack,
                    v_pack,
                    cuq,
                    cuk,
                    max_seqlen_q=1,
                    max_seqlen_k=L,
                    causal=False,
                )  # [N,h,Dv]
                if not torch.isfinite(o_pack).all():
                    log("warn.fa2_win_bucket_nonfinite")
                    return sliding_window_attention(Q, K, V, w)
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.win.bucket", path="varlen", L=L, N=int(N), ms=dt)
                Ob = o_pack  # [N,h,Dv]
            else:
                q_rows = Qb.unsqueeze(1)  # [N,1,h,Dk]
                k_rows = Kb.unsqueeze(2).expand(N, L, h, Dk)
                v_rows = Vb.unsqueeze(2).expand(N, L, h, Dv)
                if use_timing:
                    t0 = time.perf_counter()
                Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False).squeeze(
                    1
                )  # [N,h,Dv]
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.win.bucket", path="dense", L=L, N=int(N), ms=dt)
            for i, (b, t, g) in enumerate(tgt):
                out[b, t, g] = Ob[i]
        return out
    except Exception as e:
        log("warn.fa2_unexpected_fallback", branch="win", error=str(e)[:100])
        # NOTE(review): this falls back to the *masked* sliding variant, unlike
        # all the gate-skip paths above — confirm this asymmetry is intended.
        return sliding_window_attention_masked(Q, K, V, w)
def compressed_attention_fa2(
    Q: torch.Tensor,  # [B,S,G,h,Dk] queries per (batch, position, group, head)
    K_cmp: torch.Tensor,  # [B,G,S_cmp,Dk] compressed keys
    V_cmp: torch.Tensor,  # [B,G,S_cmp,Dv] compressed values
    l: int,
    d: int,
    min_len_for_fa2: int = 16,
) -> torch.Tensor:
    """
    Compressed-branch attention with an optional FlashAttention-2 fast path.

    Each query row (b, t, g) attends over the first ``num_cmp[t]`` compressed
    KV entries (a per-position prefix of ``K_cmp``/``V_cmp``), where
    ``num_cmp`` is derived from ``S``, ``l``, ``d`` and ``S_cmp`` via
    ``compute_compressed_lengths``.

    Path selection (first match wins):
      1. Gate off FA-2 — return the masked SDPA reference
         (``batched_causal_attention_compressed_masked``) — on SM 8.9 devices
         (unless forced), a non-positive or unmet length threshold
         (env ``NSA_FA2_MIN_LEN_CMP`` overrides ``min_len_for_fa2``), or an
         unsupported device/dtype/head-dim combination.
      2. Preferred: a single varlen FA-2 call over all rows with positive
         length. K/V rows are head-expanded and packed flat; under autograd
         (or ``NSA_TRAIN_SAFE_PACK``) packing uses stack/cat so gradients
         flow, otherwise a grow-on-demand workspace is filled in-place.
      3. Fallback: per-length buckets, each dispatched to varlen FA-2 or a
         dense batched FA-2 call.
      4. Any exception or non-finite output falls back to the masked SDPA
         reference to preserve numerics.

    Returns:
        Tensor of shape [B,S,G,h,Dv]; rows with zero compressed length are 0.
    """
    B, S, G, h, Dk = Q.shape
    device = Q.device
    # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced via env.
    if _is_sm89(device) and not _fa2_forced():
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="cmp", reason="sm89_guard", forced=bool(_fa2_forced()))
        return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
    S_cmp = K_cmp.shape[2]
    if S_cmp == 0:
        # No compressed KV exists yet: the attention output is defined as zeros.
        return torch.zeros((B, S, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
    # num_cmp[t] = number of compressed entries visible to queries at position t.
    num_cmp = compute_compressed_lengths(S, l, d, S_cmp, device)
    max_len = int(num_cmp.max().item()) if num_cmp.numel() > 0 else 0
    try:
        # Env override of the minimum-length threshold; ignore malformed values.
        min_len_for_fa2 = int(os.getenv("NSA_FA2_MIN_LEN_CMP", str(min_len_for_fa2)))
    except Exception:
        pass
    # Disable sentinel: a non-positive threshold turns FA-2 off entirely for this branch.
    if min_len_for_fa2 <= 0:
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="cmp", reason="disabled_threshold")
        return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
    buckets = build_length_buckets(num_cmp)
    if buckets:
        log("fa2.cmp.buckets", n=len(buckets), max_len=max_len)
        for idx in buckets:
            blens = num_cmp[idx]
            # NOTE(review): cu_seqlens are built and immediately discarded here;
            # presumably a validity/warm-up pass or dead code — confirm intent.
            _ = build_cu_seqlens_for_buckets(blens)
    if max_len < min_len_for_fa2:
        # All rows are shorter than the FA-2 threshold: use the SDPA reference.
        if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
            log(
                "fa2.gate_skip",
                branch="cmp",
                reason="below_min_len",
                max_len=int(max_len),
                min_len=int(min_len_for_fa2),
            )
        return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
    ok, why = fa2_supported_verbose(device, Q.dtype, Dk)
    if not ok or not is_flash_varlen_available():
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="cmp", reason=why, has_varlen=is_flash_varlen_available())
        return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
    try:
        Dv = V_cmp.shape[-1]
        use_timing = os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes")
        # Log histogram of per-position compressed lengths (debug visibility).
        if buckets:
            uniq, counts = torch.unique(num_cmp, return_counts=True)
            log("fa2.cmp.hist", uniq=uniq.tolist(), counts=counts.tolist())
        # Preferred path: a single varlen call across all rows with L > 0.
        force_varlen = _env_bool("NSA_FA2_FORCE_VARLEN", False)
        force_dense = _env_bool("NSA_FA2_FORCE_DENSE", False)
        if ((is_flash_varlen_available() and not force_dense) or force_varlen) and max_len >= 1:
            # Enumerate (b, t, g) rows with positive length, and their lengths.
            rows = []
            len_rows = []
            for t in range(S):
                L = int(num_cmp[t].item())
                for b in range(B):
                    for g in range(G):
                        if L > 0:
                            rows.append((b, t, g))
                            len_rows.append(L)
            N = len(rows)
            if N > 0:
                total_k = int(sum(len_rows))
                # Safe (autograd-friendly) packing when gradients are needed,
                # or when explicitly requested via NSA_TRAIN_SAFE_PACK.
                use_safe_pack = (
                    torch.is_grad_enabled()
                    and (Q.requires_grad or K_cmp.requires_grad or V_cmp.requires_grad)
                ) or _env_bool("NSA_TRAIN_SAFE_PACK", False)
                if use_safe_pack:
                    # Differentiable packing via stack/cat (keeps graph intact).
                    q_pack = torch.stack([Q[b, t, g] for (b, t, g) in rows], dim=0)
                    k_rows = []
                    v_rows = []
                    for (b, t, g), L in zip(rows, len_rows):
                        if L > 0:
                            seg_k = K_cmp[b, g, :L]
                            seg_v = V_cmp[b, g, :L]
                            k_rows.append(seg_k.unsqueeze(1).expand(-1, h, -1))  # [L,h,Dk]
                            v_rows.append(seg_v.unsqueeze(1).expand(-1, h, -1))  # [L,h,Dv]
                    if total_k > 0:
                        k_pack = torch.cat(k_rows, dim=0)
                        v_pack = torch.cat(v_rows, dim=0)
                    else:
                        k_pack = torch.zeros((0, h, Dk), dtype=K_cmp.dtype, device=K_cmp.device)
                        v_pack = torch.zeros((0, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device)
                    # cu_seqlens: one query per row; cumulative K lengths per row.
                    cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
                    lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
                    cuk = torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0)
                else:
                    # In-place packing into a reusable workspace (no autograd).
                    ws = _get_varlen_workspace(
                        Q.device, Q.dtype, K_cmp.dtype, V_cmp.dtype, h, Dk, Dv, N, total_k
                    )
                    q_pack = ws["q"][:N]
                    k_pack = ws["k"][:total_k]
                    v_pack = ws["v"][:total_k]
                    cuq = ws["cuq"][: N + 1]
                    cuq.copy_(torch.arange(0, N + 1, device=Q.device, dtype=torch.int32))
                    lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
                    cuk = ws["cuk"][: N + 1]
                    torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0, out=cuk)
                    write_pos = 0
                    for i, (b, t, g) in enumerate(rows):
                        L = len_rows[i]
                        q_pack[i] = Q[b, t, g]
                        if L > 0:
                            seg_k = K_cmp[b, g, :L]
                            seg_v = V_cmp[b, g, :L]
                            assert (write_pos + L) <= total_k, "varlen cmp K/V pack overflow"
                            # Broadcast each KV segment across the h heads.
                            k_pack[write_pos : write_pos + L] = seg_k.unsqueeze(1).expand(L, h, Dk)
                            v_pack[write_pos : write_pos + L] = seg_v.unsqueeze(1).expand(L, h, Dv)
                            write_pos += L
                if use_timing:
                    t0 = time.perf_counter()
                # Visibility is encoded in the per-row prefix lengths, so the
                # kernel itself runs non-causal with a single query per row.
                o_pack = attention_fa2_varlen(
                    q_pack,
                    k_pack,
                    v_pack,
                    cuq,
                    cuk,
                    max_seqlen_q=1,
                    max_seqlen_k=max_len,
                    causal=False,
                )  # [N,h,Dv]
                if not torch.isfinite(o_pack).all():
                    # Numerical safety net: revert to the SDPA reference.
                    log("warn.fa2_cmp_varlen_nonfinite")
                    return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.cmp.varlen_all", N=int(N), total_k=int(total_k), ms=dt)
                # Scatter packed outputs back to [B,S,G,h,Dv] (zero for L==0 rows).
                out = torch.zeros((B, S, G, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device)
                for i, (b, t, g) in enumerate(rows):
                    out[b, t, g] = o_pack[i]
                return out
        # Fallback: process equal-length buckets one at a time.
        out = torch.zeros((B, S, G, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device)
        for idx in buckets:
            if idx.numel() == 0:
                continue
            # All positions in a bucket share one compressed length L.
            L = int(num_cmp[idx[0]].item())
            rows_q = []  # per-row queries, each [h,Dk]
            rows_k = []  # per-row key prefixes, each [L,Dk]
            rows_v = []  # per-row value prefixes, each [L,Dv]
            tgt = []  # (b, t, g) destinations for scatter-back
            for t in idx.tolist():
                if L <= 0:
                    continue
                for b in range(B):
                    for g in range(G):
                        rows_q.append(Q[b, t, g])
                        rows_k.append(K_cmp[b, g, :L])
                        rows_v.append(V_cmp[b, g, :L])
                        tgt.append((b, t, g))
            if not rows_q:
                continue
            N = len(rows_q)
            Qb = torch.stack(rows_q, dim=0)  # [N,h,Dk]
            Kb = torch.stack(rows_k, dim=0)  # [N,L,Dk]
            Vb = torch.stack(rows_v, dim=0)  # [N,L,Dv]
            if is_flash_varlen_available() and not force_dense:
                # Constant-L varlen packing (uses the varlen API for generality).
                q_pack = Qb
                k_pack = Kb.reshape(N * L, Dk).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dk)
                v_pack = Vb.reshape(N * L, Dv).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dv)
                cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
                cuk = torch.arange(0, (N + 1) * L, step=L, device=Q.device, dtype=torch.int32)
                if use_timing:
                    t0 = time.perf_counter()
                o_pack = attention_fa2_varlen(
                    q_pack,
                    k_pack,
                    v_pack,
                    cuq,
                    cuk,
                    max_seqlen_q=1,
                    max_seqlen_k=L,
                    causal=False,
                )  # [N,h,Dv]
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.cmp.bucket", path="varlen", L=L, N=int(N), ms=dt)
                Ob = o_pack
            else:
                # Dense batched fallback: expand K/V across heads explicitly.
                q_rows = Qb.unsqueeze(1)
                k_rows = Kb.unsqueeze(2).expand(N, L, h, Dk)
                v_rows = Vb.unsqueeze(2).expand(N, L, h, Dv)
                if use_timing:
                    t0 = time.perf_counter()
                # NOTE(review): this dense path passes causal=True while the
                # varlen path above uses causal=False with max_seqlen_q=1;
                # confirm attention_fa2_dense_batch's mask alignment makes
                # these equivalent for single-query rows.
                Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=True).squeeze(1)
                if not torch.isfinite(Ob).all():
                    return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.cmp.bucket", path="dense", L=L, N=int(N), ms=dt)
            for i, (b, t, g) in enumerate(tgt):
                out[b, t, g] = Ob[i]
        return out
    except Exception as e:
        # Last-resort fallback: never fail the forward pass on FA-2 issues.
        log("warn.fa2_unexpected_fallback", branch="cmp", error=str(e)[:100])
        return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
def sliding_window_attention_fa2_decode(
    q_t: torch.Tensor, K_win: torch.Tensor, V_win: torch.Tensor, w: int
) -> torch.Tensor:
    """Single-step (decode) sliding-window attention with an FA-2 fast path.

    Attends the decode query over the last ``min(w, S_kv)`` cached KV
    positions. Falls back to the SDPA-based ``attention_bgh`` whenever FA-2
    is gated off (SM 8.9 guard, unsupported device/dtype, window below the
    ``NSA_FA2_MIN_LEN_WIN`` threshold) or yields non-finite values.
    """
    B, G, h, Dk = q_t.shape
    cache_len = K_win.shape[2]
    eff_w = min(w, cache_len)

    def _zeros() -> torch.Tensor:
        # Empty window: defined as a zero output.
        return torch.zeros((B, G, h, V_win.shape[-1]), dtype=V_win.dtype, device=V_win.device)

    def _sdpa_window() -> torch.Tensor:
        # SDPA reference over the last eff_w cached positions.
        lo = cache_len - eff_w
        return attention_bgh(q_t, K_win[:, :, lo:cache_len], V_win[:, :, lo:cache_len], causal=True)

    audit = os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes")
    # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced
    if _is_sm89(q_t.device) and not _fa2_forced():
        if audit:
            log(
                "fa2.gate_skip",
                branch="win.decode",
                reason="sm89_guard",
                forced=bool(_fa2_forced()),
            )
        return _zeros() if eff_w == 0 else _sdpa_window()
    if eff_w == 0:
        return _zeros()
    # CPU or unsupported configuration: direct SDPA for parity.
    ok, why = fa2_supported_verbose(q_t.device, q_t.dtype, Dk)
    if not ok:
        if audit:
            log("fa2.gate_skip", branch="win.decode", reason=why)
        return _sdpa_window()
    # Small-window auto-switch: FA-2 is not worthwhile below this length.
    try:
        min_len = int(os.getenv("NSA_FA2_MIN_LEN_WIN", "16"))
    except Exception:
        min_len = 16
    min_len = max(min_len, 1)
    if eff_w < min_len:
        if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
            log(
                "fa2.gate_skip",
                branch="win.decode",
                reason="below_min_len",
                win_len=int(eff_w),
                min_len=int(min_len),
            )
        return _sdpa_window()
    lo = cache_len - eff_w
    k = K_win[:, :, lo:cache_len]
    v = V_win[:, :, lo:cache_len]
    n_rows = B * G
    # Flatten (B,G) into a row batch; broadcast K/V across the h heads.
    q_rows = q_t.reshape(n_rows, h, Dk).unsqueeze(1)  # [N,1,h,Dk]
    k_rows = k.reshape(n_rows, eff_w, Dk).unsqueeze(2).expand(n_rows, eff_w, h, Dk)
    v_rows = v.reshape(n_rows, eff_w, v.shape[-1]).unsqueeze(2).expand(n_rows, eff_w, h, v.shape[-1])
    try:
        o = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False)  # [N,1,h,Dv]
        o = o.squeeze(1).reshape(B, G, h, -1)
        if torch.isfinite(o).all():
            return o
        # Numerical safety net: fall back to the SDPA reference.
        return attention_bgh(q_t, k, v, causal=True)
    except Exception as e:
        log("warn.fa2_unexpected_fallback", branch="win.decode", error=str(e)[:100])
        return attention_bgh(q_t, k, v, causal=True)
def compressed_attention_fa2_decode(
    q_t: torch.Tensor, K_cmp: torch.Tensor, V_cmp: torch.Tensor, L: int
) -> torch.Tensor:
    """Single-step (decode) attention over the first ``L`` compressed KV entries.

    Tries an FA-2 dense-batch fast path; falls back to the SDPA-based
    ``attention_bgh`` whenever FA-2 is gated off (SM 8.9 guard, unsupported
    device/dtype, ``L`` below the ``NSA_FA2_MIN_LEN_CMP`` threshold) or when
    the FA-2 output is non-finite.

    Args:
        q_t: [B,G,h,Dk] decode-step queries.
        K_cmp: [B,G,S_cmp,Dk] compressed keys; only the first L entries are used.
        V_cmp: [B,G,S_cmp,Dv] compressed values; only the first L entries are used.
        L: number of valid compressed entries; non-positive L yields zeros.

    Returns:
        [B,G,h,Dv] attention output.
    """
    if L <= 0:
        # Nothing to attend over yet: defined as a zero output.
        B, G, h, _ = q_t.shape
        return torch.zeros((B, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
    B, G, h, Dk = q_t.shape
    # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced
    if _is_sm89(q_t.device) and not _fa2_forced():
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log(
                "fa2.gate_skip",
                branch="cmp.decode",
                reason="sm89_guard",
                forced=bool(_fa2_forced()),
            )
        return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True)
    ok, why = fa2_supported_verbose(q_t.device, q_t.dtype, Dk)
    if not ok:
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="cmp.decode", reason=why)
        return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True)
    # Small-length auto-switch: FA-2 is not worthwhile below this length.
    try:
        min_len = int(os.getenv("NSA_FA2_MIN_LEN_CMP", "16"))
    except Exception:
        min_len = 16
    if min_len < 1:
        min_len = 1
    if L < min_len:
        if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
            log(
                "fa2.gate_skip",
                branch="cmp.decode",
                reason="below_min_len",
                L=int(L),
                min_len=int(min_len),
            )
        return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True)
    k = K_cmp[:, :, :L]
    v = V_cmp[:, :, :L]
    N = B * G
    # Flatten (B,G) into a row batch; broadcast K/V across the h heads.
    q_rows = q_t.reshape(N, h, Dk).unsqueeze(1)  # [N,1,h,Dk]
    k_rows = k.reshape(N, L, Dk).unsqueeze(2).expand(N, L, h, Dk)
    v_rows = v.reshape(N, L, v.shape[-1]).unsqueeze(2).expand(N, L, h, v.shape[-1])
    try:
        o = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False)  # [N,1,h,Dv]
        o = o.squeeze(1).reshape(B, G, h, -1)
        if not torch.isfinite(o).all():
            # Numerical safety net: fall back to the SDPA reference path.
            return attention_bgh(q_t, k, v, causal=True)
        return o
    except Exception as e:
        # Fix: log the unexpected fallback (previously swallowed silently),
        # matching the diagnostics of the sliding-window decode path.
        log("warn.fa2_unexpected_fallback", branch="cmp.decode", error=str(e)[:100])
        return attention_bgh(q_t, k, v, causal=True)