|
|
from __future__ import annotations |
|
|
import os |
|
|
import time |
|
|
from typing import Dict, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from nsa.core.debug import log |
|
|
from nsa.core.packing import ( |
|
|
build_cu_seqlens_for_buckets, |
|
|
build_length_buckets, |
|
|
compute_compressed_lengths, |
|
|
compute_sliding_lengths, |
|
|
) |
|
|
from nsa.kernels.flash_wrappers import ( |
|
|
attention_bgh, |
|
|
attention_fa2_dense_batch, |
|
|
attention_fa2_varlen, |
|
|
fa2_supported, |
|
|
fa2_supported_verbose, |
|
|
is_flash_varlen_available, |
|
|
) |
|
|
|
|
|
|
|
|
_VARLEN_WS: Dict[Tuple, Dict[str, torch.Tensor]] = {} |
|
|
_SEL_PACK_WS: Dict[Tuple, Dict[str, torch.Tensor]] = {} |
|
|
|
|
|
|
|
|
def _env_int(name: str, default: int) -> int: |
|
|
try: |
|
|
v = int(os.getenv(name, str(default))) |
|
|
return v |
|
|
except Exception: |
|
|
return default |
|
|
|
|
|
|
|
|
def _env_int_bounded(name: str, default: int, min_val: int = 0, max_val: int = 10**8) -> int: |
|
|
"""Read integer from environment with bounds checking to prevent excessive memory allocation.""" |
|
|
try: |
|
|
v = int(os.getenv(name, str(default))) |
|
|
if v < min_val: |
|
|
return min_val |
|
|
if v > max_val: |
|
|
|
|
|
import warnings |
|
|
|
|
|
warnings.warn(f"{name}={v} exceeds maximum {max_val}, clamping to {max_val}") |
|
|
return max_val |
|
|
return v |
|
|
except Exception: |
|
|
return default |
|
|
|
|
|
|
|
|
def clear_varlen_workspaces() -> None: |
|
|
"""Optional memory cleanup: free varlen packing workspaces.""" |
|
|
_VARLEN_WS.clear() |
|
|
|
|
|
|
|
|
def clear_selection_pack_workspaces() -> None: |
|
|
"""Optional memory cleanup: free selection pack workspaces.""" |
|
|
_SEL_PACK_WS.clear() |
|
|
|
|
|
|
|
|
def _get_varlen_workspace( |
|
|
device: torch.device, |
|
|
dtype_q: torch.dtype, |
|
|
dtype_k: torch.dtype, |
|
|
dtype_v: torch.dtype, |
|
|
h: int, |
|
|
d_k: int, |
|
|
d_v: int, |
|
|
cap_N: int, |
|
|
cap_total_k: int, |
|
|
) -> dict[str, torch.Tensor]: |
|
|
key = (str(device), dtype_q, dtype_k, dtype_v, h, d_k, d_v) |
|
|
ws = _VARLEN_WS.get(key) |
|
|
need_new = ws is None |
|
|
if not need_new: |
|
|
q, k, v = ws["q"], ws["k"], ws["v"] |
|
|
cuq, cuk = ws["cuq"], ws["cuk"] |
|
|
need_new = ( |
|
|
q.shape[0] < cap_N |
|
|
or k.shape[0] < cap_total_k |
|
|
or v.shape[0] < cap_total_k |
|
|
or cuq.numel() < (cap_N + 1) |
|
|
or cuk.numel() < (cap_N + 1) |
|
|
) |
|
|
if need_new: |
|
|
|
|
|
|
|
|
reserve_N = _env_int_bounded("NSA_VARLEN_RESERVE_N", 0, 0, 10**6) |
|
|
reserve_K = _env_int_bounded("NSA_VARLEN_RESERVE_K", 0, 0, 10**8) |
|
|
new_N = max(cap_N, reserve_N, 1) |
|
|
new_K = max(cap_total_k, reserve_K, 1) |
|
|
ws = { |
|
|
"q": torch.empty((new_N, h, d_k), dtype=dtype_q, device=device), |
|
|
"k": torch.empty((new_K, h, d_k), dtype=dtype_k, device=device), |
|
|
"v": torch.empty((new_K, h, d_v), dtype=dtype_v, device=device), |
|
|
"cuq": torch.empty((new_N + 1,), dtype=torch.int32, device=device), |
|
|
"cuk": torch.empty((new_N + 1,), dtype=torch.int32, device=device), |
|
|
} |
|
|
_VARLEN_WS[key] = ws |
|
|
return ws |
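
# Workspace reuse sketch (illustrative): a first call with cap_N=8 and
# cap_total_k=64 allocates q:(8, h, d_k), k/v:(64, h, *), cuq/cuk:(9,).
# Subsequent calls with smaller or equal caps reuse the same buffers via
# slicing, so steady-state packing avoids per-call allocations.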
|
|
|
|
|
|
|
|
def batched_causal_attention_compressed( |
|
|
Q: torch.Tensor, |
|
|
K_cmp: torch.Tensor, |
|
|
V_cmp: torch.Tensor, |
|
|
l: int, |
|
|
d: int, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Compressed branch attention with per-row causal mask derived from emission schedule. |
|
|
We cannot rely on is_causal=True due to S_q != S_kv and variable allowed lengths per t. |
|
|
""" |
|
|
B, S, G, h, Dk = Q.shape |
|
|
S_cmp = K_cmp.shape[2] |
|
|
device = Q.device |
|
|
|
|
|
|
|
|
tpos = torch.arange(S, device=device) |
|
|
num_cmp = torch.where(tpos + 1 < l, 0, ((tpos + 1 - l) // d) + 1).clamp(max=S_cmp) |
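    # Emission-schedule example: with l=4, d=2, position t=7 has t+1 = 8 >= l,
    # so num_cmp[7] = ((8 - 4) // 2) + 1 = 3 compressed tokens are visible.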
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out = torch.zeros((B, S, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device) |
|
|
log("cmp.begin", B=B, S=S, S_cmp=int(S_cmp), l=l, d=d) |
|
|
for t in range(S): |
|
|
L = int(num_cmp[t].item()) |
|
|
if L <= 0: |
|
|
out[:, t] = 0.0 |
|
|
continue |
|
|
q_t = Q[:, t] |
|
|
k_t = K_cmp[:, :, :L, :] |
|
|
v_t = V_cmp[:, :, :L, :] |
|
|
out[:, t] = attention_bgh(q_t, k_t, v_t, causal=True) |
|
|
log("cmp.step", t=int(t), L=L) |
|
|
return out |
|
|
|
|
|
|
|
|
def sliding_window_attention( |
|
|
Q: torch.Tensor, |
|
|
K: torch.Tensor, |
|
|
V: torch.Tensor, |
|
|
w: int, |
|
|
) -> torch.Tensor: |
|
|
    """Sliding-window attention over the last w tokens via one masked SDPA call."""
    B, S, G, h, Dk = Q.shape
|
|
|
|
|
if w <= 0 or K.shape[2] == 0 or S == 0: |
|
|
return torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=V.device) |
|
|
device = Q.device |
|
|
|
|
|
row = torch.arange(S, device=device).view(S, 1) |
|
|
col = torch.arange(S, device=device).view(1, S) |
|
|
allowed = (col <= row) & (col >= (row - (w - 1))) |
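    # Example: with w=3, row t=4 allows columns 2..4 (col <= 4 and
    # col >= 4 - (3 - 1) = 2).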
|
|
|
|
|
|
|
|
Mf2d = torch.full((S, S), float("-inf"), dtype=Q.dtype, device=device) |
|
|
Mf2d.masked_fill_(allowed, 0.0) |
|
|
|
|
|
Qf = Q.reshape(B, S, G * h, Dk).transpose(1, 2).contiguous() |
|
|
Kf = K.unsqueeze(2).expand(B, G, h, S, Dk).reshape(B, G * h, S, Dk).contiguous() |
|
|
Vf = ( |
|
|
V.unsqueeze(2) |
|
|
.expand(B, G, h, S, V.shape[-1]) |
|
|
.reshape(B, G * h, S, V.shape[-1]) |
|
|
.contiguous() |
|
|
) |
|
|
|
|
|
Mf = Mf2d.view(1, 1, S, S).expand(B, G * h, S, S) |
|
|
Of = F.scaled_dot_product_attention(Qf, Kf, Vf, attn_mask=Mf) |
|
|
Of = Of.transpose(1, 2).reshape(B, S, G, h, V.shape[-1]) |
|
|
return Of |
|
|
|
|
|
|
|
|
def grouped_selection_attention( |
|
|
Q: torch.Tensor, |
|
|
K: torch.Tensor, |
|
|
V: torch.Tensor, |
|
|
ranges: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
    """Reference (naive) selection attention: one SDPA per (b, t, g) row over gathered tokens."""
    B, S, G, h, Dk = Q.shape
|
|
|
|
|
|
|
|
out = torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=V.device) |
|
|
for b in range(B): |
|
|
for t in range(S): |
|
|
for g in range(G): |
|
|
|
|
|
idxs = [] |
|
|
for i in range(ranges.shape[3]): |
|
|
s0 = int(ranges[b, t, g, i, 0].item()) |
|
|
e0 = int(ranges[b, t, g, i, 1].item()) |
|
|
if e0 > s0: |
|
|
idxs.append(torch.arange(s0, e0, device=V.device)) |
|
|
if idxs: |
|
|
idx = torch.cat(idxs) |
|
|
k = K[b, g, idx] |
|
|
v = V[b, g, idx] |
|
|
q = Q[b, t, g] |
|
|
|
|
|
q_btgh = q.unsqueeze(0).unsqueeze(2) |
|
|
k_btgh = ( |
|
|
k.unsqueeze(0).unsqueeze(0).expand(1, q.shape[0], k.shape[0], k.shape[1]) |
|
|
) |
|
|
v_btgh = ( |
|
|
v.unsqueeze(0).unsqueeze(0).expand(1, q.shape[0], v.shape[0], v.shape[1]) |
|
|
) |
|
|
q_btgh = q_btgh.contiguous() |
|
|
k_btgh = k_btgh.contiguous() |
|
|
v_btgh = v_btgh.contiguous() |
|
|
attn = F.scaled_dot_product_attention( |
|
|
q_btgh, k_btgh, v_btgh, is_causal=True |
|
|
) |
|
|
out[b, t, g] = attn.squeeze(0).squeeze(1) |
|
|
log("sel.step", b=int(b), t=int(t), g=int(g), L=int(k.shape[0])) |
|
|
else: |
|
|
out[b, t, g] = 0.0 |
|
|
log("sel.step", b=int(b), t=int(t), g=int(g), L=0) |
|
|
return out |
|
|
|
|
|
|
|
|
def sliding_window_attention_masked( |
|
|
Q: torch.Tensor, |
|
|
K: torch.Tensor, |
|
|
V: torch.Tensor, |
|
|
w: int, |
|
|
) -> torch.Tensor: |
|
|
|
|
|
|
|
|
    """
    Masked-SDPA fallback for the sliding branch; numerically equivalent to
    sliding_window_attention (additive -inf mask, single SDPA call).
    """
    B, S, G, h, Dk = Q.shape
    if w <= 0 or K.shape[2] == 0:
        return torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=V.device)
    # Delegate to the vectorized masked-SDPA implementation above.
    return sliding_window_attention(Q, K, V, w)
|
|
|
|
|
|
|
|
def batched_causal_attention_compressed_masked( |
|
|
Q: torch.Tensor, |
|
|
K_cmp: torch.Tensor, |
|
|
V_cmp: torch.Tensor, |
|
|
l: int, |
|
|
d: int, |
|
|
) -> torch.Tensor: |
|
|
|
|
|
    """
    Vectorized masked-SDPA equivalent of batched_causal_attention_compressed:
    builds an additive -inf mask from the emission schedule and runs one SDPA
    call over all rows.
    """
    B, S, G, h, Dk = Q.shape
    S_cmp = K_cmp.shape[2]
    device = Q.device
    if S_cmp == 0:
        return torch.zeros((B, S, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
    tpos = torch.arange(S, device=device)
    num_cmp = torch.where(tpos + 1 < l, 0, ((tpos + 1 - l) // d) + 1).clamp(min=0, max=S_cmp)
    col = torch.arange(S_cmp, device=device).view(1, S_cmp)
    allowed = col < num_cmp.view(S, 1)
    # Rows with num_cmp == 0 would see an all -inf mask row (NaN softmax);
    # unmask key 0 for them here and zero their outputs after the SDPA call.
    row_has_any = num_cmp.gt(0)
    allowed[~row_has_any, 0] = True
    Mf2d = torch.full((S, S_cmp), float("-inf"), dtype=Q.dtype, device=device)
    Mf2d.masked_fill_(allowed, 0.0)
    Qf = Q.reshape(B, S, G * h, Dk).transpose(1, 2).contiguous()
    Kf = K_cmp.unsqueeze(2).expand(-1, -1, h, -1, -1).reshape(B, G * h, S_cmp, Dk).contiguous()
    Vf = V_cmp.unsqueeze(2).expand(-1, -1, h, -1, -1).reshape(B, G * h, S_cmp, V_cmp.shape[-1]).contiguous()
    Of = F.scaled_dot_product_attention(Qf, Kf, Vf, attn_mask=Mf2d.view(1, 1, S, S_cmp))
    Of = Of.transpose(1, 2).reshape(B, S, G, h, V_cmp.shape[-1])
    return torch.where(row_has_any.view(1, S, 1, 1, 1), Of, torch.zeros_like(Of))
|
|
|
|
|
|
|
|
def grouped_selection_attention_packed( |
|
|
Q: torch.Tensor, |
|
|
K: torch.Tensor, |
|
|
V: torch.Tensor, |
|
|
ranges: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Bucketed varlen packing by row length L with parity to gather path. |
|
|
For each (b,t,g), build its flat index list from ranges, bucket rows |
|
|
by identical L, and run one SDPA per bucket. |
|
|
""" |
|
|
B, S, G, h, Dk = Q.shape |
|
|
|
|
device = Q.device |
|
|
|
|
|
out = torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=device) |
|
|
|
|
|
rows = [] |
|
|
lengths = [] |
|
|
for b in range(B): |
|
|
for t in range(S): |
|
|
for g in range(G): |
|
|
idxs = [] |
|
|
for i in range(ranges.shape[3]): |
|
|
s0 = int(ranges[b, t, g, i, 0].item()) |
|
|
e0 = int(ranges[b, t, g, i, 1].item()) |
|
|
if e0 > s0: |
|
|
idxs.append(torch.arange(s0, e0, device=device)) |
|
|
if idxs: |
|
|
idx = torch.cat(idxs) |
|
|
else: |
|
|
idx = torch.empty((0,), dtype=torch.long, device=device) |
|
|
rows.append((b, t, g, idx)) |
|
|
lengths.append(idx.numel()) |
|
|
if not rows: |
|
|
return out |
|
|
lengths_t = torch.tensor(lengths, device=device) |
|
|
unique_L = torch.unique(lengths_t) |
|
|
|
|
|
use_safe_pack = ( |
|
|
torch.is_grad_enabled() and (Q.requires_grad or K.requires_grad or V.requires_grad) |
|
|
) or _env_bool("NSA_TRAIN_SAFE_PACK", False) |
|
|
|
|
|
for Lval in unique_L.tolist(): |
|
|
L = int(Lval) |
|
|
|
|
|
bucket_idx = [i for i, Lx in enumerate(lengths) if Lx == L] |
|
|
if L == 0 or len(bucket_idx) == 0: |
|
|
|
|
|
continue |
|
|
N = len(bucket_idx) |
|
|
if use_safe_pack: |
|
|
|
|
|
map_rows = [] |
|
|
Q_list = [] |
|
|
K_list = [] |
|
|
V_list = [] |
|
|
for ridx in bucket_idx: |
|
|
b, t, g, idx = rows[ridx] |
|
|
map_rows.append((b, t, g)) |
|
|
Q_list.append(Q[b, t, g]) |
|
|
K_list.append(K[b, g, idx]) |
|
|
V_list.append(V[b, g, idx]) |
|
|
Qb = torch.stack(Q_list, dim=0) |
|
|
Kb = torch.stack(K_list, dim=0) |
|
|
Vb = torch.stack(V_list, dim=0) |
|
|
q_btgh = Qb.unsqueeze(1).permute(0, 2, 1, 3) |
|
|
k_btgh = Kb.unsqueeze(1).expand(N, h, L, Dk) |
|
|
v_btgh = Vb.unsqueeze(1).expand(N, h, L, V.shape[-1]) |
|
|
attn = F.scaled_dot_product_attention(q_btgh, k_btgh, v_btgh, is_causal=True) |
|
|
Ob = attn.squeeze(2) |
|
|
for j, (b, t, g) in enumerate(map_rows): |
|
|
out[b, t, g] = Ob[j] |
|
|
else: |
|
|
|
|
|
ws_key = (str(device), Q.dtype, K.dtype, V.dtype, h, Dk, V.shape[-1]) |
|
|
ws = _SEL_PACK_WS.get(ws_key) |
|
|
need_new = ( |
|
|
ws is None or ws["Q"].shape[0] < N or ws["K"].shape[1] < L or ws["V"].shape[1] < L |
|
|
) |
|
|
if need_new: |
|
|
|
|
|
|
|
|
reserve_N = _env_int_bounded("NSA_SEL_PACK_RESERVE_N", 0, 0, 10**5) |
|
|
reserve_L = _env_int_bounded("NSA_SEL_PACK_RESERVE_L", 0, 0, 10**4) |
|
|
new_N = max(N, reserve_N) |
|
|
new_L = max(L, reserve_L) |
|
|
Qb = torch.empty((new_N, h, Dk), dtype=Q.dtype, device=device) |
|
|
Kb = torch.empty((new_N, new_L, Dk), dtype=K.dtype, device=device) |
|
|
Vb = torch.empty((new_N, new_L, V.shape[-1]), dtype=V.dtype, device=device) |
|
|
_SEL_PACK_WS[ws_key] = {"Q": Qb, "K": Kb, "V": Vb} |
|
|
else: |
|
|
Qb = _SEL_PACK_WS[ws_key]["Q"][:N] |
|
|
Kb = _SEL_PACK_WS[ws_key]["K"][:N, :L] |
|
|
Vb = _SEL_PACK_WS[ws_key]["V"][:N, :L] |
|
|
|
|
|
map_rows = [] |
|
|
for j, ridx in enumerate(bucket_idx): |
|
|
b, t, g, idx = rows[ridx] |
|
|
Qb[j] = Q[b, t, g] |
|
|
Kb[j] = K[b, g, idx] |
|
|
Vb[j] = V[b, g, idx] |
|
|
map_rows.append((b, t, g)) |
|
|
|
|
|
q_btgh = Qb.unsqueeze(1) |
|
|
q_btgh = q_btgh.permute(0, 2, 1, 3) |
|
|
k_btgh = Kb.unsqueeze(1).expand(N, h, L, Dk) |
|
|
v_btgh = Vb.unsqueeze(1).expand(N, h, L, V.shape[-1]) |
|
|
attn = F.scaled_dot_product_attention( |
|
|
q_btgh, k_btgh, v_btgh, is_causal=True |
|
|
) |
|
|
Ob = attn.squeeze(2) |
|
|
|
|
|
for j, (b, t, g) in enumerate(map_rows): |
|
|
out[b, t, g] = Ob[j] |
|
|
return out |
|
|
|
|
|
|
|
|
def selection_attention_varlen_all( |
|
|
Q: torch.Tensor, |
|
|
K: torch.Tensor, |
|
|
V: torch.Tensor, |
|
|
ranges: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Fully batched selection attention using varlen packing across all (B,S,G) rows. |
|
|
|
|
|
If NSA_SEL_VARLEN_V2 is enabled (default), dispatches to the vectorized v2 |
|
|
packer. Otherwise uses the legacy v1 path (minimal loops with workspace). |
|
|
""" |
|
|
|
|
|
if os.getenv("NSA_SEL_VARLEN_V2", "1").lower() in ("1", "true", "yes", "on"): |
|
|
return selection_attention_varlen_all_v2(Q, K, V, ranges) |
|
|
B, S, G, h, Dk = Q.shape |
|
|
|
|
|
_parity = os.getenv("NSA_SEL_VARLEN_FORCE_PARITY", "0").lower() in ("1", "true", "yes", "on") |
|
|
if _parity: |
|
|
|
|
|
return grouped_selection_attention_packed(Q, K, V, ranges) |
|
|
device = Q.device |
|
|
Dv = V.shape[-1] |
|
|
out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device) |
|
|
|
|
|
rows: list[tuple[int, int, int]] = [] |
|
|
lens: list[int] = [] |
|
|
for b in range(B): |
|
|
for t in range(S): |
|
|
for g in range(G): |
|
|
L = 0 |
|
|
for i in range(ranges.shape[3]): |
|
|
s0 = int(ranges[b, t, g, i, 0].item()) |
|
|
e0 = int(ranges[b, t, g, i, 1].item()) |
|
|
if e0 > s0: |
|
|
L += e0 - s0 |
|
|
if L > 0: |
|
|
rows.append((b, t, g)) |
|
|
lens.append(L) |
|
|
N = len(rows) |
|
|
if N == 0: |
|
|
return out |
|
|
|
|
|
total_k = int(sum(lens)) |
|
|
|
|
|
ws = _get_varlen_workspace( |
|
|
device, |
|
|
dtype_q=Q.dtype, |
|
|
dtype_k=K.dtype, |
|
|
dtype_v=V.dtype, |
|
|
h=h, |
|
|
d_k=Dk, |
|
|
d_v=Dv, |
|
|
cap_N=N, |
|
|
cap_total_k=total_k, |
|
|
) |
|
|
q_pack = ws["q"][:N] |
|
|
k_pack = ws["k"][:total_k] |
|
|
v_pack = ws["v"][:total_k] |
|
|
cuq = ws["cuq"][: N + 1] |
|
|
cuk = ws["cuk"][: N + 1] |
|
|
|
|
|
cuq.zero_() |
|
|
cuk.zero_() |
|
|
|
|
|
write_pos = 0 |
|
|
for i, (b, t, g) in enumerate(rows): |
|
|
|
|
|
q_pack[i] = Q[b, t, g] |
|
|
|
|
|
for j in range(ranges.shape[3]): |
|
|
s0 = int(ranges[b, t, g, j, 0].item()) |
|
|
e0 = int(ranges[b, t, g, j, 1].item()) |
|
|
if e0 <= s0: |
|
|
continue |
|
|
seg_k = K[b, g, s0:e0] |
|
|
seg_v = V[b, g, s0:e0] |
|
|
Lseg = e0 - s0 |
|
|
|
|
|
_kslice = k_pack[write_pos : write_pos + Lseg] |
|
|
_vslice = v_pack[write_pos : write_pos + Lseg] |
|
|
_kslice.copy_(seg_k[:, None, :].expand_as(_kslice)) |
|
|
_vslice.copy_(seg_v[:, None, :].expand_as(_vslice)) |
|
|
write_pos += Lseg |
|
|
cuq[i + 1] = cuq[i] + 1 |
|
|
cuk[i + 1] = cuk[i] + lens[i] |
|
|
|
|
|
|
|
|
ok, _ = fa2_supported_verbose(device, Q.dtype, Dk) |
|
|
if ok and is_flash_varlen_available(): |
|
|
try: |
|
|
o_pack = attention_fa2_varlen( |
|
|
q_pack, |
|
|
k_pack, |
|
|
v_pack, |
|
|
cuq, |
|
|
cuk, |
|
|
max_seqlen_q=1, |
|
|
max_seqlen_k=max(lens), |
|
|
causal=_parity, |
|
|
) |
|
|
|
|
|
for i, (b, t, g) in enumerate(rows): |
|
|
out[b, t, g] = o_pack[i] |
|
|
return out |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
buckets: dict[int, list[int]] = {} |
|
|
for i, L in enumerate(lens): |
|
|
buckets.setdefault(L, []).append(i) |
|
|
for L, idxs in buckets.items(): |
|
|
if L <= 0 or len(idxs) == 0: |
|
|
continue |
|
|
Nb = len(idxs) |
|
|
Qb = torch.empty((Nb, h, Dk), dtype=Q.dtype, device=device) |
|
|
Kb = torch.empty((Nb, L, Dk), dtype=K.dtype, device=device) |
|
|
Vb = torch.empty((Nb, L, Dv), dtype=V.dtype, device=device) |
|
|
tgt: list[tuple[int, int, int]] = [] |
|
|
for j, irow in enumerate(idxs): |
|
|
b, t, g = rows[irow] |
|
|
Qb[j] = Q[b, t, g] |
|
|
|
|
|
write = 0 |
|
|
for rj in range(ranges.shape[3]): |
|
|
s0 = int(ranges[b, t, g, rj, 0].item()) |
|
|
e0 = int(ranges[b, t, g, rj, 1].item()) |
|
|
if e0 <= s0: |
|
|
continue |
|
|
Lseg = e0 - s0 |
|
|
Kb[j, write : write + Lseg] = K[b, g, s0:e0] |
|
|
Vb[j, write : write + Lseg] = V[b, g, s0:e0] |
|
|
write += Lseg |
|
|
tgt.append((b, t, g)) |
|
|
|
|
|
try: |
|
|
q_rows = Qb.unsqueeze(1) |
|
|
k_rows = Kb.unsqueeze(2).expand(Nb, L, h, Dk) |
|
|
v_rows = Vb.unsqueeze(2).expand(Nb, L, h, Dv) |
|
|
Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=_parity).squeeze( |
|
|
1 |
|
|
) |
|
|
for i, (b, t, g) in enumerate(tgt): |
|
|
out[b, t, g] = Ob[i] |
|
|
except Exception: |
|
|
|
|
|
for j, (b, t, g) in enumerate(tgt): |
|
|
q_btgh = Qb[j].unsqueeze(0).unsqueeze(0) |
|
|
k_btgh = Kb[j].unsqueeze(0).unsqueeze(0) |
|
|
v_btgh = Vb[j].unsqueeze(0).unsqueeze(0) |
|
|
out[b, t, g] = attention_bgh(q_btgh, k_btgh, v_btgh, causal=_parity)[0, 0] |
|
|
return out |
|
|
|
|
|
|
|
|
def selection_attention_varlen_all_v2( |
|
|
Q: torch.Tensor, |
|
|
K: torch.Tensor, |
|
|
V: torch.Tensor, |
|
|
ranges: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Vectorized v2 varlen selection packer with FA‑2 varlen fast path and dense fallback. |
|
|
- Eliminates Python loops for packing by using a difference-array mask to build per-row |
|
|
allowed indices and flat-select K/V tokens. |
|
|
- Uses causal=False for single‑query rows. |
|
|
- Env: NSA_SEL_VARLEN_MIN_L to bypass on tiny rows (falls back to packed path). |
|
|
""" |
|
|
B, S, G, h, Dk = Q.shape |
|
|
|
|
|
_parity = os.getenv("NSA_SEL_VARLEN_FORCE_PARITY", "0").lower() in ("1", "true", "yes", "on") |
|
|
if _parity: |
|
|
|
|
|
return grouped_selection_attention_packed(Q, K, V, ranges) |
|
|
device = Q.device |
|
|
Dv = V.shape[-1] |
|
|
S_kv = K.shape[2] |
|
|
out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device) |
|
|
if S_kv == 0: |
|
|
return out |
|
|
|
|
|
|
|
|
n = ranges.shape[3] |
|
|
starts = ranges[..., 0].to(torch.int64).clamp_(0, S_kv) |
|
|
ends = ranges[..., 1].to(torch.int64).clamp_(0, S_kv) |
|
|
BSG = B * S * G |
|
|
starts_f = starts.reshape(BSG, n) |
|
|
ends_f = ends.reshape(BSG, n) |
|
|
diff = torch.zeros((BSG, S_kv + 1), dtype=torch.int32, device=device) |
|
|
one = torch.ones_like(starts_f, dtype=diff.dtype, device=device) |
|
|
diff.scatter_add_(1, starts_f, one) |
|
|
diff.scatter_add_(1, ends_f, -one) |
|
|
allowed = diff[:, :-1].cumsum(dim=1).gt(0) |
|
|
|
|
|
lens_flat = allowed.sum(dim=1, dtype=torch.int32) |
|
|
row_mask = lens_flat.gt(0) |
|
|
if not torch.any(row_mask): |
|
|
return out |
|
|
try: |
|
|
min_L = int(os.getenv("NSA_SEL_VARLEN_MIN_L", "0")) |
|
|
except Exception: |
|
|
min_L = 0 |
|
|
if min_L > 0 and int(lens_flat.max().item()) < min_L: |
|
|
return grouped_selection_attention_packed(Q, K, V, ranges) |
|
|
|
|
|
idx_rows = torch.nonzero(row_mask, as_tuple=False).squeeze(1) |
|
|
N = int(idx_rows.numel()) |
|
|
|
|
|
b_idx = idx_rows // (S * G) |
|
|
rem = idx_rows % (S * G) |
|
|
t_idx = rem // G |
|
|
g_idx = rem % G |
|
|
|
|
|
|
|
|
Q_rows = Q.reshape(B * S * G, h, Dk)[idx_rows] |
|
|
|
|
|
|
|
|
bg_map = ( |
|
|
torch.arange(B, device=device).view(B, 1, 1) * G |
|
|
+ torch.arange(G, device=device).view(1, 1, G) |
|
|
).expand(B, S, G) |
|
|
bg_rows = bg_map.reshape(B * S * G)[idx_rows] |
|
|
K_bg = K.reshape(B * G, S_kv, Dk)[bg_rows] |
|
|
V_bg = V.reshape(B * G, S_kv, Dv)[bg_rows] |
|
|
allowed_rows = allowed[idx_rows] |
|
|
|
|
|
total_k = int(lens_flat[row_mask].sum().item()) |
|
|
sel_k = K_bg[allowed_rows] |
|
|
sel_v = V_bg[allowed_rows] |
|
|
lens_sel = lens_flat[row_mask] |
|
|
|
|
|
|
|
|
ws = _get_varlen_workspace( |
|
|
device, |
|
|
dtype_q=Q.dtype, |
|
|
dtype_k=K.dtype, |
|
|
dtype_v=V.dtype, |
|
|
h=h, |
|
|
d_k=Dk, |
|
|
d_v=Dv, |
|
|
cap_N=N, |
|
|
cap_total_k=total_k, |
|
|
) |
|
|
q_pack = ws["q"][:N] |
|
|
k_pack = ws["k"][:total_k] |
|
|
v_pack = ws["v"][:total_k] |
|
|
cuq = ws["cuq"][: N + 1] |
|
|
cuk = ws["cuk"][: N + 1] |
|
|
|
|
|
q_pack.copy_(Q_rows) |
|
|
k_pack.copy_(sel_k.unsqueeze(1).expand(total_k, h, Dk)) |
|
|
v_pack.copy_(sel_v.unsqueeze(1).expand(total_k, h, Dv)) |
|
|
cuq.copy_(torch.arange(0, N + 1, device=device, dtype=torch.int32)) |
|
|
cuk[0] = 0 |
|
|
torch.cumsum(lens_sel.to(torch.int32), dim=0, out=cuk[1:]) |
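    # cu_seqlens layout example: per-row lengths [3, 1, 2] give
    # cuk = [0, 3, 4, 6], so row i reads k_pack[cuk[i]:cuk[i+1]]; cuq is a
    # plain arange because every packed row carries exactly one query.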
|
|
|
|
|
|
|
|
ok, _why = fa2_supported_verbose(device, Q.dtype, Dk) |
|
|
max_len = int(lens_sel.max().item()) |
|
|
if ok and is_flash_varlen_available(): |
|
|
try: |
|
|
o_pack = attention_fa2_varlen( |
|
|
q_pack, |
|
|
k_pack, |
|
|
v_pack, |
|
|
cuq, |
|
|
cuk, |
|
|
max_seqlen_q=1, |
|
|
max_seqlen_k=max_len, |
|
|
causal=_parity, |
|
|
) |
|
|
out[b_idx, t_idx, g_idx] = o_pack |
|
|
return out |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
return grouped_selection_attention_masked(Q, K, V, ranges) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
starts = cuk[:-1].to(torch.int64) |
|
|
ends = cuk[1:].to(torch.int64) |
|
|
Ls = (ends - starts).to(torch.int64) |
|
|
for L in torch.unique(Ls).tolist(): |
|
|
if L <= 0: |
|
|
continue |
|
|
sel = (Ls == L).nonzero(as_tuple=False).squeeze(1) |
|
|
if sel.numel() == 0: |
|
|
continue |
|
|
Nb = int(sel.numel()) |
|
|
Qb = q_pack[sel] |
|
|
k_rows = torch.empty((Nb, L, h, Dk), dtype=K.dtype, device=device) |
|
|
v_rows = torch.empty((Nb, L, h, Dv), dtype=V.dtype, device=device) |
|
|
for j in range(Nb): |
|
|
s0 = int(starts[sel[j]].item()) |
|
|
e0 = int(ends[sel[j]].item()) |
|
|
k_rows[j] = k_pack[s0:e0] |
|
|
v_rows[j] = v_pack[s0:e0] |
|
|
try: |
|
|
Ob = attention_fa2_dense_batch(Qb.unsqueeze(1), k_rows, v_rows, causal=_parity).squeeze(1) |
|
|
except Exception: |
|
|
Ob = torch.empty((Nb, h, Dv), dtype=V.dtype, device=device) |
|
|
for j in range(Nb): |
|
|
Ob[j] = attention_bgh(Qb[j].unsqueeze(0), k_rows[j].unsqueeze(0), v_rows[j].unsqueeze(0), causal=_parity)[ |
|
|
0 |
|
|
] |
|
|
out[b_idx[sel], t_idx[sel], g_idx[sel]] = Ob |
|
|
return out |
|
|
|
|
|
|
|
|
def grouped_selection_attention_masked( |
|
|
Q: torch.Tensor, |
|
|
K: torch.Tensor, |
|
|
V: torch.Tensor, |
|
|
ranges: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Fully batched selection attention using an additive -inf mask. |
|
|
Vectorized ranges→mask construction via prefix-sum trick (no Python loops). |
|
|
""" |
|
|
B, S, G, h, Dk = Q.shape |
|
|
S_kv = K.shape[2] |
|
|
device = Q.device |
|
|
if S_kv == 0: |
|
|
return torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=device) |
|
|
|
|
|
|
|
|
n = ranges.shape[3] |
|
|
starts = ranges[..., 0].to(torch.int64).clamp_(0, S_kv) |
|
|
ends = ranges[..., 1].to(torch.int64).clamp_(0, S_kv) |
|
|
BSG = B * S * G |
|
|
starts_f = starts.reshape(BSG, n) |
|
|
ends_f = ends.reshape(BSG, n) |
|
|
diff = torch.zeros((BSG, S_kv + 1), dtype=torch.int32, device=device) |
|
|
one = torch.ones_like(starts_f, dtype=diff.dtype, device=device) |
|
|
diff.scatter_add_(1, starts_f, one) |
|
|
diff.scatter_add_(1, ends_f, -one) |
|
|
allowed = diff[:, :-1].cumsum(dim=1).gt(0).reshape(B, S, G, S_kv) |
|
|
|
|
|
|
|
|
row_has_any = allowed.any(dim=-1) |
|
|
row_empty = ~row_has_any |
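    # An all -inf mask row makes softmax return NaN, so rows with no allowed
    # tokens get key 0 unmasked below and their outputs zeroed after SDPA.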
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if row_empty.any(): |
|
|
allowed_safe = allowed.clone() |
|
|
flat = allowed_safe.view(B * S * G, S_kv) |
|
|
row_empty_flat = row_empty.reshape(B * S * G) |
|
|
if S_kv > 0: |
|
|
flat[row_empty_flat, 0] = True |
|
|
allowed_safe = flat.view_as(allowed_safe) |
|
|
else: |
|
|
allowed_safe = allowed |
|
|
|
|
|
|
|
|
Qf = Q.reshape(B, S, G * h, Dk).transpose(1, 2).contiguous() |
|
|
Kf = K.unsqueeze(2).expand(-1, -1, h, -1, -1).reshape(B, G * h, S_kv, Dk).contiguous() |
|
|
Vf = V.unsqueeze(2).expand(-1, -1, h, -1, -1).reshape(B, G * h, S_kv, V.shape[-1]).contiguous() |
|
|
|
|
|
    # Build the additive mask in the query dtype so fused SDPA kernels accept
    # it alongside half-precision inputs.
    zeros = torch.zeros((B, G * h, S, S_kv), dtype=Qf.dtype, device=device)
    neg_inf = torch.full((B, G * h, S, S_kv), float("-inf"), dtype=Qf.dtype, device=device)
|
|
Mf = torch.where( |
|
|
allowed_safe.transpose(1, 2) |
|
|
.unsqueeze(2) |
|
|
.expand(-1, -1, h, -1, -1) |
|
|
.reshape(B, G * h, S, S_kv), |
|
|
zeros, |
|
|
neg_inf, |
|
|
).contiguous() |
|
|
|
|
|
Of = F.scaled_dot_product_attention(Qf, Kf, Vf, attn_mask=Mf) |
|
|
Of = Of.transpose(1, 2).reshape(B, S, G, h, V.shape[-1]) |
|
|
|
|
|
if row_empty.any(): |
|
|
Of = torch.where(row_has_any.unsqueeze(-1).unsqueeze(-1), Of, torch.zeros_like(Of)) |
|
|
return Of |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _env_bool(name: str, default: bool = False) -> bool: |
|
|
v = os.getenv(name, "1" if default else "0").lower() |
|
|
return v in ("1", "true", "yes", "on") |
|
|
|
|
|
|
|
|
def _is_sm89(device: torch.device) -> bool: |
|
|
"""Return True if running on CUDA device with SM 8.9 (Ada/RTX 4090).""" |
|
|
if device.type != "cuda": |
|
|
return False |
|
|
try: |
|
|
cap = torch.cuda.get_device_capability(device) |
|
|
return cap == (8, 9) |
|
|
except Exception: |
|
|
return False |
|
|
|
|
|
|
|
|
def _fa2_forced() -> bool: |
|
|
"""Return True if FA-2 usage is explicitly forced via env.""" |
|
|
return _env_bool("NSA_FA2_FORCE", False) |
|
|
|
|
|
|
|
|
def sliding_window_attention_fa2( |
|
|
Q: torch.Tensor, |
|
|
K: torch.Tensor, |
|
|
V: torch.Tensor, |
|
|
w: int, |
|
|
min_len_for_fa2: int = 16, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Planned FA-2 path for sliding with safe fallbacks. |
|
|
Currently falls back to masked SDPA to preserve numerics until FA-2 is wired. |
|
|
""" |
|
|
B, S, G, h, Dk = Q.shape |
|
|
device = Q.device |
|
|
|
|
|
|
|
|
|
|
|
allow_sliding_fa2 = _env_bool("NSA_ALLOW_SLIDING_FA2", False) |
|
|
|
|
|
if _is_sm89(device) and not _fa2_forced(): |
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
log("fa2.gate_skip", branch="win", reason="sm89_guard", forced=bool(_fa2_forced())) |
|
|
return sliding_window_attention(Q, K, V, w) |
|
|
|
|
|
if not allow_sliding_fa2 and not ( |
|
|
_env_bool("NSA_FA2_FORCE_VARLEN", False) or _env_bool("NSA_FA2_FORCE_DENSE", False) |
|
|
): |
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
log("fa2.gate_skip", branch="win", reason="unsupported_sliding_semantics", forced=False) |
|
|
return sliding_window_attention(Q, K, V, w) |
|
|
|
|
|
lengths = compute_sliding_lengths(S, w, device) |
|
|
max_len = int(lengths.max().item()) if lengths.numel() > 0 else 0 |
|
|
|
|
|
try: |
|
|
min_len_for_fa2 = int(os.getenv("NSA_FA2_MIN_LEN_WIN", str(min_len_for_fa2))) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
if min_len_for_fa2 <= 0: |
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
log("fa2.gate_skip", branch="win", reason="disabled_threshold") |
|
|
return sliding_window_attention(Q, K, V, w) |
|
|
buckets = build_length_buckets(lengths) |
|
|
if buckets: |
|
|
log("fa2.win.buckets", n=len(buckets), max_len=max_len) |
|
|
|
|
|
for idx in buckets: |
|
|
blens = lengths[idx] |
|
|
_ = build_cu_seqlens_for_buckets(blens) |
|
|
|
|
|
if max_len < min_len_for_fa2: |
|
|
if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"): |
|
|
log( |
|
|
"fa2.gate_skip", |
|
|
branch="win", |
|
|
reason="below_min_len", |
|
|
max_len=int(max_len), |
|
|
min_len=int(min_len_for_fa2), |
|
|
) |
|
|
return sliding_window_attention(Q, K, V, w) |
|
|
|
|
|
ok, why = fa2_supported_verbose(device, Q.dtype, Dk) |
|
|
if not ok or not is_flash_varlen_available(): |
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
log("fa2.gate_skip", branch="win", reason=why, has_varlen=is_flash_varlen_available()) |
|
|
return sliding_window_attention(Q, K, V, w) |
|
|
|
|
|
try: |
|
|
B, S, G, h, Dk = Q.shape |
|
|
Dv = V.shape[-1] |
|
|
use_timing = os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes") |
|
|
force_varlen = _env_bool("NSA_FA2_FORCE_VARLEN", False) |
|
|
force_dense = _env_bool("NSA_FA2_FORCE_DENSE", False) |
|
|
force_win_dense = _env_bool("NSA_WIN_FORCE_DENSE", False) |
|
|
|
|
|
if buckets: |
|
|
uniq, counts = torch.unique(lengths, return_counts=True) |
|
|
log("fa2.win.hist", uniq=uniq.tolist(), counts=counts.tolist()) |
|
|
|
|
|
if (is_flash_varlen_available() and not (force_dense or force_win_dense)) or force_varlen: |
|
|
rows = [] |
|
|
len_rows = [] |
|
|
for t in range(S): |
|
|
L = int(lengths[t].item()) |
|
|
for b in range(B): |
|
|
for g in range(G): |
|
|
rows.append((b, t, g)) |
|
|
len_rows.append(L) |
|
|
N = len(rows) |
|
|
if N > 0 and max_len >= 1: |
|
|
use_safe_pack = ( |
|
|
torch.is_grad_enabled() |
|
|
and (Q.requires_grad or K.requires_grad or V.requires_grad) |
|
|
) or _env_bool("NSA_TRAIN_SAFE_PACK", False) |
|
|
if use_safe_pack: |
|
|
|
|
|
q_pack = torch.stack([Q[b, t, g] for (b, t, g) in rows], dim=0) |
|
|
k_rows = [] |
|
|
v_rows = [] |
|
|
for i, (b, t, g) in enumerate(rows): |
|
|
L = len_rows[i] |
|
|
if L > 0: |
|
|
start = max(0, (t + 1) - w) |
|
|
end = t + 1 |
|
|
seg_k = K[b, g, start:end].unsqueeze(1).expand(-1, h, -1) |
|
|
seg_v = V[b, g, start:end].unsqueeze(1).expand(-1, h, -1) |
|
|
k_rows.append(seg_k) |
|
|
v_rows.append(seg_v) |
|
|
total_k = int(sum(len_rows)) |
|
|
if total_k > 0: |
|
|
k_pack = torch.cat(k_rows, dim=0) |
|
|
v_pack = torch.cat(v_rows, dim=0) |
|
|
else: |
|
|
k_pack = torch.zeros((0, h, Dk), dtype=K.dtype, device=K.device) |
|
|
v_pack = torch.zeros((0, h, Dv), dtype=V.dtype, device=V.device) |
|
|
cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32) |
|
|
lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device) |
|
|
cuk = torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0) |
|
|
else: |
|
|
total_k = int(sum(len_rows)) |
|
|
ws = _get_varlen_workspace( |
|
|
Q.device, Q.dtype, K.dtype, V.dtype, h, Dk, Dv, N, total_k |
|
|
) |
|
|
q_pack = ws["q"][:N] |
|
|
k_pack = ws["k"][:total_k] |
|
|
v_pack = ws["v"][:total_k] |
|
|
|
|
|
cuq = ws["cuq"][: N + 1] |
|
|
cuq.copy_(torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)) |
|
|
lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device) |
|
|
cuk = ws["cuk"][: N + 1] |
|
|
torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0, out=cuk) |
|
|
|
|
|
write_pos = 0 |
|
|
for i, (b, t, g) in enumerate(rows): |
|
|
L = len_rows[i] |
|
|
q_pack[i] = Q[b, t, g] |
|
|
if L > 0: |
|
|
start = max(0, (t + 1) - w) |
|
|
end = t + 1 |
|
|
seg_k = K[b, g, start:end] |
|
|
seg_v = V[b, g, start:end] |
|
|
assert (write_pos + L) <= total_k, "varlen K/V pack overflow" |
|
|
k_pack[write_pos : write_pos + L] = seg_k.unsqueeze(1).expand(L, h, Dk) |
|
|
v_pack[write_pos : write_pos + L] = seg_v.unsqueeze(1).expand(L, h, Dv) |
|
|
write_pos += L |
|
|
|
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
try: |
|
|
assert cuq.numel() == (N + 1), "cuq length mismatch" |
|
|
assert cuk.numel() == (N + 1), "cuk length mismatch" |
|
|
assert int(cuk[-1].item()) == int(total_k), "cuk total_k mismatch" |
|
|
if total_k > 0 and N > 0: |
|
|
probe = [0, N // 2, N - 1] if N >= 3 else [0] |
|
|
for i in probe: |
|
|
L_i = int(len_rows[i]) |
|
|
b_i, t_i, g_i = rows[i] |
|
|
s_i = int(max(0, (t_i + 1) - w)) |
|
|
e_i = int(t_i + 1) |
|
|
if L_i > 0: |
|
|
ks = k_pack[cuk[i] : cuk[i + 1]] |
|
|
kv = K[b_i, g_i, s_i:e_i].unsqueeze(1).expand(-1, h, -1) |
|
|
if ks.shape != kv.shape: |
|
|
log( |
|
|
"warn.fa2_win_pack_shape", |
|
|
row=i, |
|
|
ks=ks.shape, |
|
|
kv=kv.shape, |
|
|
) |
|
|
else: |
|
|
md = float((ks - kv).abs().max().item()) |
|
|
if md > 1e-3: |
|
|
log( |
|
|
"warn.fa2_win_pack_mismatch", |
|
|
row=i, |
|
|
L=L_i, |
|
|
max_diff=md, |
|
|
) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
if use_timing: |
|
|
t0 = time.perf_counter() |
|
|
o_pack = attention_fa2_varlen( |
|
|
q_pack, |
|
|
k_pack, |
|
|
v_pack, |
|
|
cuq, |
|
|
cuk, |
|
|
max_seqlen_q=1, |
|
|
max_seqlen_k=max_len, |
|
|
causal=False, |
|
|
) |
|
|
if not torch.isfinite(o_pack).all(): |
|
|
log("warn.fa2_win_varlen_nonfinite") |
|
|
return sliding_window_attention(Q, K, V, w) |
|
|
if use_timing: |
|
|
dt = (time.perf_counter() - t0) * 1e3 |
|
|
log("fa2.win.varlen_all", N=int(N), total_k=int(total_k), ms=dt) |
|
|
|
|
|
out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device) |
|
|
for i, (b, t, g) in enumerate(rows): |
|
|
out[b, t, g] = o_pack[i] |
|
|
return out |
|
|
out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device) |
|
|
for idx in buckets: |
|
|
if idx.numel() == 0: |
|
|
continue |
|
|
L = int(lengths[idx[0]].item()) |
|
|
|
|
|
rows_q = [] |
|
|
rows_k = [] |
|
|
rows_v = [] |
|
|
tgt = [] |
|
|
for t in idx.tolist(): |
|
|
start = max(0, (t + 1) - w) |
|
|
end = t + 1 |
|
|
for b in range(B): |
|
|
for g in range(G): |
|
|
rows_q.append(Q[b, t, g]) |
|
|
rows_k.append(K[b, g, start:end]) |
|
|
rows_v.append(V[b, g, start:end]) |
|
|
tgt.append((b, t, g)) |
|
|
if not rows_q: |
|
|
continue |
|
|
N = len(rows_q) |
|
|
Qb = torch.stack(rows_q, dim=0) |
|
|
Kb = torch.stack(rows_k, dim=0) |
|
|
Vb = torch.stack(rows_v, dim=0) |
|
|
if is_flash_varlen_available() and not (force_dense or force_win_dense): |
|
|
|
|
|
q_pack = Qb |
|
|
k_pack = Kb.reshape(N * L, Dk).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dk) |
|
|
v_pack = Vb.reshape(N * L, Dv).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dv) |
|
|
cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32) |
|
|
cuk = torch.arange(0, (N + 1) * L, step=L, device=Q.device, dtype=torch.int32) |
|
|
if use_timing: |
|
|
t0 = time.perf_counter() |
|
|
o_pack = attention_fa2_varlen( |
|
|
q_pack, |
|
|
k_pack, |
|
|
v_pack, |
|
|
cuq, |
|
|
cuk, |
|
|
max_seqlen_q=1, |
|
|
max_seqlen_k=L, |
|
|
causal=False, |
|
|
) |
|
|
if not torch.isfinite(o_pack).all(): |
|
|
log("warn.fa2_win_bucket_nonfinite") |
|
|
return sliding_window_attention(Q, K, V, w) |
|
|
if use_timing: |
|
|
dt = (time.perf_counter() - t0) * 1e3 |
|
|
log("fa2.win.bucket", path="varlen", L=L, N=int(N), ms=dt) |
|
|
Ob = o_pack |
|
|
else: |
|
|
q_rows = Qb.unsqueeze(1) |
|
|
k_rows = Kb.unsqueeze(2).expand(N, L, h, Dk) |
|
|
v_rows = Vb.unsqueeze(2).expand(N, L, h, Dv) |
|
|
if use_timing: |
|
|
t0 = time.perf_counter() |
|
|
Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False).squeeze( |
|
|
1 |
|
|
) |
|
|
if use_timing: |
|
|
dt = (time.perf_counter() - t0) * 1e3 |
|
|
log("fa2.win.bucket", path="dense", L=L, N=int(N), ms=dt) |
|
|
for i, (b, t, g) in enumerate(tgt): |
|
|
out[b, t, g] = Ob[i] |
|
|
return out |
|
|
except Exception as e: |
|
|
log("warn.fa2_unexpected_fallback", branch="win", error=str(e)[:100]) |
|
|
return sliding_window_attention_masked(Q, K, V, w) |
|
|
|
|
|
|
|
|
def compressed_attention_fa2( |
|
|
Q: torch.Tensor, |
|
|
K_cmp: torch.Tensor, |
|
|
V_cmp: torch.Tensor, |
|
|
l: int, |
|
|
d: int, |
|
|
min_len_for_fa2: int = 16, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Planned FA-2 path for compressed with safe fallbacks. |
|
|
Currently falls back to masked SDPA to preserve numerics until FA-2 is wired. |
|
|
""" |
|
|
B, S, G, h, Dk = Q.shape |
|
|
device = Q.device |
|
|
|
|
|
if _is_sm89(device) and not _fa2_forced(): |
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
log("fa2.gate_skip", branch="cmp", reason="sm89_guard", forced=bool(_fa2_forced())) |
|
|
return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d) |
|
|
S_cmp = K_cmp.shape[2] |
|
|
if S_cmp == 0: |
|
|
return torch.zeros((B, S, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device) |
|
|
num_cmp = compute_compressed_lengths(S, l, d, S_cmp, device) |
|
|
max_len = int(num_cmp.max().item()) if num_cmp.numel() > 0 else 0 |
|
|
try: |
|
|
min_len_for_fa2 = int(os.getenv("NSA_FA2_MIN_LEN_CMP", str(min_len_for_fa2))) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
if min_len_for_fa2 <= 0: |
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
log("fa2.gate_skip", branch="cmp", reason="disabled_threshold") |
|
|
return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d) |
|
|
buckets = build_length_buckets(num_cmp) |
|
|
if buckets: |
|
|
log("fa2.cmp.buckets", n=len(buckets), max_len=max_len) |
|
|
for idx in buckets: |
|
|
blens = num_cmp[idx] |
|
|
_ = build_cu_seqlens_for_buckets(blens) |
|
|
if max_len < min_len_for_fa2: |
|
|
if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"): |
|
|
log( |
|
|
"fa2.gate_skip", |
|
|
branch="cmp", |
|
|
reason="below_min_len", |
|
|
max_len=int(max_len), |
|
|
min_len=int(min_len_for_fa2), |
|
|
) |
|
|
return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d) |
|
|
ok, why = fa2_supported_verbose(device, Q.dtype, Dk) |
|
|
if not ok or not is_flash_varlen_available(): |
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
log("fa2.gate_skip", branch="cmp", reason=why, has_varlen=is_flash_varlen_available()) |
|
|
return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d) |
|
|
try: |
|
|
Dv = V_cmp.shape[-1] |
|
|
use_timing = os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes") |
|
|
|
|
|
if buckets: |
|
|
uniq, counts = torch.unique(num_cmp, return_counts=True) |
|
|
log("fa2.cmp.hist", uniq=uniq.tolist(), counts=counts.tolist()) |
|
|
|
|
|
force_varlen = _env_bool("NSA_FA2_FORCE_VARLEN", False) |
|
|
force_dense = _env_bool("NSA_FA2_FORCE_DENSE", False) |
|
|
if ((is_flash_varlen_available() and not force_dense) or force_varlen) and max_len >= 1: |
|
|
rows = [] |
|
|
len_rows = [] |
|
|
for t in range(S): |
|
|
L = int(num_cmp[t].item()) |
|
|
for b in range(B): |
|
|
for g in range(G): |
|
|
if L > 0: |
|
|
rows.append((b, t, g)) |
|
|
len_rows.append(L) |
|
|
N = len(rows) |
|
|
if N > 0: |
|
|
total_k = int(sum(len_rows)) |
|
|
use_safe_pack = ( |
|
|
torch.is_grad_enabled() |
|
|
and (Q.requires_grad or K_cmp.requires_grad or V_cmp.requires_grad) |
|
|
) or _env_bool("NSA_TRAIN_SAFE_PACK", False) |
|
|
if use_safe_pack: |
|
|
q_pack = torch.stack([Q[b, t, g] for (b, t, g) in rows], dim=0) |
|
|
k_rows = [] |
|
|
v_rows = [] |
|
|
for (b, t, g), L in zip(rows, len_rows): |
|
|
if L > 0: |
|
|
seg_k = K_cmp[b, g, :L] |
|
|
seg_v = V_cmp[b, g, :L] |
|
|
k_rows.append(seg_k.unsqueeze(1).expand(-1, h, -1)) |
|
|
v_rows.append(seg_v.unsqueeze(1).expand(-1, h, -1)) |
|
|
if total_k > 0: |
|
|
k_pack = torch.cat(k_rows, dim=0) |
|
|
v_pack = torch.cat(v_rows, dim=0) |
|
|
else: |
|
|
k_pack = torch.zeros((0, h, Dk), dtype=K_cmp.dtype, device=K_cmp.device) |
|
|
v_pack = torch.zeros((0, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device) |
|
|
cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32) |
|
|
lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device) |
|
|
cuk = torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0) |
|
|
else: |
|
|
ws = _get_varlen_workspace( |
|
|
Q.device, Q.dtype, K_cmp.dtype, V_cmp.dtype, h, Dk, Dv, N, total_k |
|
|
) |
|
|
q_pack = ws["q"][:N] |
|
|
k_pack = ws["k"][:total_k] |
|
|
v_pack = ws["v"][:total_k] |
|
|
cuq = ws["cuq"][: N + 1] |
|
|
cuq.copy_(torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)) |
|
|
lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device) |
|
|
cuk = ws["cuk"][: N + 1] |
|
|
torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0, out=cuk) |
|
|
write_pos = 0 |
|
|
for i, (b, t, g) in enumerate(rows): |
|
|
L = len_rows[i] |
|
|
q_pack[i] = Q[b, t, g] |
|
|
if L > 0: |
|
|
seg_k = K_cmp[b, g, :L] |
|
|
seg_v = V_cmp[b, g, :L] |
|
|
assert (write_pos + L) <= total_k, "varlen cmp K/V pack overflow" |
|
|
k_pack[write_pos : write_pos + L] = seg_k.unsqueeze(1).expand(L, h, Dk) |
|
|
v_pack[write_pos : write_pos + L] = seg_v.unsqueeze(1).expand(L, h, Dv) |
|
|
write_pos += L |
|
|
if use_timing: |
|
|
t0 = time.perf_counter() |
|
|
o_pack = attention_fa2_varlen( |
|
|
q_pack, |
|
|
k_pack, |
|
|
v_pack, |
|
|
cuq, |
|
|
cuk, |
|
|
max_seqlen_q=1, |
|
|
max_seqlen_k=max_len, |
|
|
causal=False, |
|
|
) |
|
|
if not torch.isfinite(o_pack).all(): |
|
|
log("warn.fa2_cmp_varlen_nonfinite") |
|
|
return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d) |
|
|
if use_timing: |
|
|
dt = (time.perf_counter() - t0) * 1e3 |
|
|
log("fa2.cmp.varlen_all", N=int(N), total_k=int(total_k), ms=dt) |
|
|
out = torch.zeros((B, S, G, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device) |
|
|
for i, (b, t, g) in enumerate(rows): |
|
|
out[b, t, g] = o_pack[i] |
|
|
return out |
|
|
out = torch.zeros((B, S, G, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device) |
|
|
for idx in buckets: |
|
|
if idx.numel() == 0: |
|
|
continue |
|
|
L = int(num_cmp[idx[0]].item()) |
|
|
rows_q = [] |
|
|
rows_k = [] |
|
|
rows_v = [] |
|
|
tgt = [] |
|
|
for t in idx.tolist(): |
|
|
if L <= 0: |
|
|
continue |
|
|
for b in range(B): |
|
|
for g in range(G): |
|
|
rows_q.append(Q[b, t, g]) |
|
|
rows_k.append(K_cmp[b, g, :L]) |
|
|
rows_v.append(V_cmp[b, g, :L]) |
|
|
tgt.append((b, t, g)) |
|
|
if not rows_q: |
|
|
continue |
|
|
N = len(rows_q) |
|
|
Qb = torch.stack(rows_q, dim=0) |
|
|
Kb = torch.stack(rows_k, dim=0) |
|
|
Vb = torch.stack(rows_v, dim=0) |
|
|
if is_flash_varlen_available() and not force_dense: |
|
|
q_pack = Qb |
|
|
k_pack = Kb.reshape(N * L, Dk).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dk) |
|
|
v_pack = Vb.reshape(N * L, Dv).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dv) |
|
|
cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32) |
|
|
cuk = torch.arange(0, (N + 1) * L, step=L, device=Q.device, dtype=torch.int32) |
|
|
if use_timing: |
|
|
t0 = time.perf_counter() |
|
|
o_pack = attention_fa2_varlen( |
|
|
q_pack, |
|
|
k_pack, |
|
|
v_pack, |
|
|
cuq, |
|
|
cuk, |
|
|
max_seqlen_q=1, |
|
|
max_seqlen_k=L, |
|
|
causal=False, |
|
|
) |
|
|
if use_timing: |
|
|
dt = (time.perf_counter() - t0) * 1e3 |
|
|
log("fa2.cmp.bucket", path="varlen", L=L, N=int(N), ms=dt) |
|
|
Ob = o_pack |
|
|
else: |
|
|
q_rows = Qb.unsqueeze(1) |
|
|
k_rows = Kb.unsqueeze(2).expand(N, L, h, Dk) |
|
|
v_rows = Vb.unsqueeze(2).expand(N, L, h, Dv) |
|
|
if use_timing: |
|
|
t0 = time.perf_counter() |
|
|
Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=True).squeeze(1) |
|
|
if not torch.isfinite(Ob).all(): |
|
|
return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d) |
|
|
if use_timing: |
|
|
dt = (time.perf_counter() - t0) * 1e3 |
|
|
log("fa2.cmp.bucket", path="dense", L=L, N=int(N), ms=dt) |
|
|
for i, (b, t, g) in enumerate(tgt): |
|
|
out[b, t, g] = Ob[i] |
|
|
return out |
|
|
except Exception as e: |
|
|
log("warn.fa2_unexpected_fallback", branch="cmp", error=str(e)[:100]) |
|
|
return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d) |
|
|
|
|
|
|
|
|
def sliding_window_attention_fa2_decode( |
|
|
q_t: torch.Tensor, K_win: torch.Tensor, V_win: torch.Tensor, w: int |
|
|
) -> torch.Tensor: |
|
|
    """Single-step (decode) FA-2 sliding attention with SDPA fallbacks."""
    B, G, h, Dk = q_t.shape
|
|
|
|
|
if _is_sm89(q_t.device) and not _fa2_forced(): |
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
log( |
|
|
"fa2.gate_skip", |
|
|
branch="win.decode", |
|
|
reason="sm89_guard", |
|
|
forced=bool(_fa2_forced()), |
|
|
) |
|
|
end = K_win.shape[2] |
|
|
win_len = min(w, end) |
|
|
if win_len == 0: |
|
|
return torch.zeros((B, G, h, V_win.shape[-1]), dtype=V_win.dtype, device=V_win.device) |
|
|
start = end - win_len |
|
|
return attention_bgh(q_t, K_win[:, :, start:end], V_win[:, :, start:end], causal=True) |
|
|
end = K_win.shape[2] |
|
|
win_len = min(w, end) |
|
|
if win_len == 0: |
|
|
return torch.zeros((B, G, h, V_win.shape[-1]), dtype=V_win.dtype, device=V_win.device) |
|
|
|
|
|
ok, why = fa2_supported_verbose(q_t.device, q_t.dtype, Dk) |
|
|
if not ok: |
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
log("fa2.gate_skip", branch="win.decode", reason=why) |
|
|
start = end - win_len |
|
|
return attention_bgh(q_t, K_win[:, :, start:end], V_win[:, :, start:end], causal=True) |
|
|
|
|
|
try: |
|
|
min_len = int(os.getenv("NSA_FA2_MIN_LEN_WIN", "16")) |
|
|
except Exception: |
|
|
min_len = 16 |
|
|
if min_len < 1: |
|
|
min_len = 1 |
|
|
if win_len < min_len: |
|
|
if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"): |
|
|
log( |
|
|
"fa2.gate_skip", |
|
|
branch="win.decode", |
|
|
reason="below_min_len", |
|
|
win_len=int(win_len), |
|
|
min_len=int(min_len), |
|
|
) |
|
|
start = end - win_len |
|
|
return attention_bgh(q_t, K_win[:, :, start:end], V_win[:, :, start:end], causal=True) |
|
|
start = end - win_len |
|
|
k = K_win[:, :, start:end] |
|
|
v = V_win[:, :, start:end] |
|
|
N = B * G |
|
|
q_rows = q_t.reshape(N, h, Dk).unsqueeze(1) |
|
|
k_rows = k.reshape(N, win_len, Dk).unsqueeze(2).expand(N, win_len, h, Dk) |
|
|
v_rows = v.reshape(N, win_len, v.shape[-1]).unsqueeze(2).expand(N, win_len, h, v.shape[-1]) |
|
|
try: |
|
|
o = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False) |
|
|
o = o.squeeze(1).reshape(B, G, h, -1) |
|
|
if not torch.isfinite(o).all(): |
|
|
return attention_bgh(q_t, k, v, causal=True) |
|
|
return o |
|
|
except Exception as e: |
|
|
log("warn.fa2_unexpected_fallback", branch="win.decode", error=str(e)[:100]) |
|
|
return attention_bgh(q_t, k, v, causal=True) |
|
|
|
|
|
|
|
|
def compressed_attention_fa2_decode( |
|
|
q_t: torch.Tensor, K_cmp: torch.Tensor, V_cmp: torch.Tensor, L: int |
|
|
) -> torch.Tensor: |
|
|
    """Single-step (decode) FA-2 compressed attention over the first L tokens, with SDPA fallback."""
    if L <= 0:
|
|
B, G, h, _ = q_t.shape |
|
|
return torch.zeros((B, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device) |
|
|
B, G, h, Dk = q_t.shape |
|
|
|
|
|
if _is_sm89(q_t.device) and not _fa2_forced(): |
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
log( |
|
|
"fa2.gate_skip", |
|
|
branch="cmp.decode", |
|
|
reason="sm89_guard", |
|
|
forced=bool(_fa2_forced()), |
|
|
) |
|
|
return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True) |
|
|
ok, why = fa2_supported_verbose(q_t.device, q_t.dtype, Dk) |
|
|
if not ok: |
|
|
if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"): |
|
|
log("fa2.gate_skip", branch="cmp.decode", reason=why) |
|
|
return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True) |
|
|
try: |
|
|
min_len = int(os.getenv("NSA_FA2_MIN_LEN_CMP", "16")) |
|
|
except Exception: |
|
|
min_len = 16 |
|
|
if min_len < 1: |
|
|
min_len = 1 |
|
|
if L < min_len: |
|
|
if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"): |
|
|
log( |
|
|
"fa2.gate_skip", |
|
|
branch="cmp.decode", |
|
|
reason="below_min_len", |
|
|
L=int(L), |
|
|
min_len=int(min_len), |
|
|
) |
|
|
return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True) |
|
|
k = K_cmp[:, :, :L] |
|
|
v = V_cmp[:, :, :L] |
|
|
N = B * G |
|
|
q_rows = q_t.reshape(N, h, Dk).unsqueeze(1) |
|
|
k_rows = k.reshape(N, L, Dk).unsqueeze(2).expand(N, L, h, Dk) |
|
|
v_rows = v.reshape(N, L, v.shape[-1]).unsqueeze(2).expand(N, L, h, v.shape[-1]) |
|
|
try: |
|
|
o = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False) |
|
|
o = o.squeeze(1).reshape(B, G, h, -1) |
|
|
if not torch.isfinite(o).all(): |
|
|
return attention_bgh(q_t, k, v, causal=True) |
|
|
return o |
|
|
except Exception: |
|
|
return attention_bgh(q_t, k, v, causal=True) |
|
|
|