# dragon/nsa_utils.py
# NSA (Native Sparse Attention) utilities: compression, selection, and
# sliding-window attention branches.
from typing import Optional
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
# Compile flex_attention once at import time so all calls below use the fused path.
flex_attention = torch.compile(flex_attention)
from flash_attn import flash_attn_func
import triton
import triton.language as tl
import triton.testing
from fla.ops.nsa.parallel import ParallelNSAFunction
from fla.ops.utils.pooling import mean_pooling
from fla.ops.nsa.parallel import parallel_nsa_topk
def compression_attention(q, k_cmp, v_cmp, block_mask):
    """Run flex attention of queries against compressed (pooled) keys/values.

    Args:
        q: queries laid out as (batch, seq, heads, dim).
        k_cmp, v_cmp: compressed keys/values, (batch, cmp_seq, kv_heads, dim);
            GQA is enabled, so kv_heads may be smaller than heads.
        block_mask: precomputed flex-attention BlockMask limiting visibility.

    Returns:
        Tuple of (attention output back in (batch, seq, heads, dim) layout,
        per-query log-sum-exp as returned by flex_attention).
    """
    # flex_attention wants (batch, heads, seq, dim), hence the transposes
    # on the way in and out.
    q_t, k_t, v_t = (t.transpose(1, 2) for t in (q, k_cmp, v_cmp))
    out, lse = flex_attention(
        q_t,
        k_t,
        v_t,
        block_mask=block_mask,
        enable_gqa=True,
        return_lse=True,
    )
    return out.transpose(1, 2), lse
# Autotune configurations for the forward kernel
# (all tiling is fixed by constexpr arguments, so only num_warps is swept)
_sel_attn_fwd_configs = [
    triton.Config({}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8]
]
# Autotune configurations for the backward preprocess kernel
# (sweeps the BLOCK_M row tile together with pipeline depth and warp count)
_sel_attn_bwd_preprocess_configs = [
    triton.Config({'BLOCK_M': 16, 'num_stages': 1, 'num_warps': 4}, num_ctas=1),
    triton.Config({'BLOCK_M': 32, 'num_stages': 1, 'num_warps': 4}, num_ctas=1),
    triton.Config({'BLOCK_M': 16, 'num_stages': 2, 'num_warps': 4}, num_ctas=1),
    triton.Config({'BLOCK_M': 32, 'num_stages': 2, 'num_warps': 4}, num_ctas=1),
    triton.Config({'BLOCK_M': 16, 'num_stages': 1, 'num_warps': 8}, num_ctas=1),
    triton.Config({'BLOCK_M': 32, 'num_stages': 1, 'num_warps': 8}, num_ctas=1),
]
# Autotune configurations for the main backward kernel
_sel_attn_bwd_configs = [
    triton.Config({}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8]
]
@triton.autotune( # Decorate the kernel
    configs=_sel_attn_fwd_configs,
    key=['M', 'N', 'D', 'SELECTION_BLOCK_SIZE', 'T', 'HEADS_PER_GROUP', 'causal'],
)
@triton.jit
def _sel_attn_fwd_kernel(
    Q: tl.tensor,             # queries, strides describe (B, M, H, D)
    K: tl.tensor,             # keys, strides describe (B, N, G, D)
    V: tl.tensor,             # values, strides describe (B, N, G, D)
    Top_idx: tl.tensor,       # selected block indices, (B, M, G, T)
    softmax_scale: tl.constexpr,
    causal: tl.constexpr,
    Out: tl.tensor,           # output, (B, M, H, D)
    Lse: tl.tensor,           # per-head log-sum-exp, (B, H, M), natural log
    stride_qb, stride_qh, stride_qm, stride_qd,
    stride_kb, stride_kg, stride_kn, stride_kd,
    stride_vb, stride_vg, stride_vn, stride_vd,
    stride_tb, stride_tg, stride_tm, stride_tt,
    stride_ob, stride_oh, stride_om, stride_od,
    stride_lb, stride_lh, stride_lm,
    B: tl.constexpr,          # batch size
    H: tl.constexpr,          # number of query heads
    M: tl.constexpr,          # query sequence length
    N: tl.constexpr,          # key/value sequence length
    D: tl.constexpr,          # head dimension
    T: tl.constexpr,          # number of selected blocks per query
    DP: tl.constexpr,         # D padded up to the next power of two
    SELECTION_BLOCK_SIZE: tl.constexpr,
    HEADS_PER_GROUP: tl.constexpr,   # H // G
    OFFSET_M: tl.constexpr,   # first query row covered by the launch grid
    BLOCK_H: tl.constexpr,    # head-tile size (>= 16 so tl.dot is legal)
):
    # One program computes attention for all HEADS_PER_GROUP query heads of a
    # single (batch b, query position m, kv-group g), running an online
    # softmax over the T selected kv blocks.
    stride_hg = stride_qh * HEADS_PER_GROUP
    b = tl.program_id(0)
    m = tl.program_id(1) + OFFSET_M
    g = tl.program_id(2)
    # Base pointers
    q_base = Q + b * stride_qb + m * stride_qm + g * stride_hg
    k_base = K + b * stride_kb + g * stride_kg
    v_base = V + b * stride_vb + g * stride_vg
    t_base = Top_idx + b * stride_tb + m * stride_tm + g * stride_tg
    o_base = Out + b * stride_ob + m * stride_om + g * stride_hg
    l_base = Lse + b * stride_lb + m * stride_lm + g * stride_lh * HEADS_PER_GROUP
    # Offsets
    offs_h = tl.arange(0, BLOCK_H)
    mask_h = offs_h < HEADS_PER_GROUP
    offs_d = tl.arange(0, DP)
    mask_d = offs_d < D
    offs_n = tl.arange(0, SELECTION_BLOCK_SIZE)
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_blck = tl.load(q_ptrs, mask=mask_h[:, None] & mask_d[None, :], other=0.0) # Keep as float16
    # Online-softmax state: running max, running sum of exponentials, and the
    # unnormalized output accumulator. sum_exp starts at 1.0 so a query that
    # attends to no block yields out = 0 / lse = -inf instead of NaN (the 1.0
    # is annihilated by alpha = exp(-inf - new_max) = 0 on the first real block).
    max_log = tl.full([BLOCK_H], float('-inf'), dtype=tl.float32)
    sum_exp = tl.full([BLOCK_H], 1.0, dtype=tl.float32)
    accum = tl.zeros([BLOCK_H, DP], dtype=tl.float32)
    # 1/ln(2) = 1.44269504
    # log_scale = softmax_scale * 1.44269504
    # Highest kv index visible to query m; queries/keys are right-aligned when
    # M != N (query m sees keys up to N - M + m).
    max_col = max(0, N - M + m) if causal else N
    for idx in range(T):
        # NOTE: Ideally we load top_idx outside the loop, this can be done with a gather which will
        # supported in future versions of Triton
        top = tl.load(t_base + idx * stride_tt)
        col = top * SELECTION_BLOCK_SIZE
        col = tl.multiple_of(col, SELECTION_BLOCK_SIZE)
        # Skip blocks that start entirely beyond the causal horizon.
        if not causal or (col <= max_col and col >= 0):
            cols = col + offs_n
            mask_n = cols < N
            k_ptrs = k_base + offs_d[:, None] * stride_kd + cols[None, :] * stride_kn
            k_blck = tl.load(k_ptrs, mask=mask_d[:, None] & mask_n[None, :], other=0.0)
            v_ptrs = v_base + cols[:, None] * stride_vn + offs_d[None, :] * stride_vd
            v_blck = tl.load(v_ptrs, mask=mask_d[None, :] & mask_n[:, None], other=0.0).to(tl.float32)
            # qk = tl.dot(q_blck, k_blck) * log_scale # [BH, BN]
            qk = tl.dot(q_blck, k_blck) * softmax_scale # [BH, BN]
            # NOTE: We can move the multiplication by softmax_scale outside the loop
            # Mask the positions inside this block that lie in the future.
            causal_mask = cols <= max_col
            qk = tl.where(causal_mask[None, :], qk, float('-inf'))
            # stable mx-log-sum-exp
            new_max = tl.maximum(max_log, tl.max(qk, axis=1)) # [BH]
            # exp_qk = tl.math.exp2(qk - new_max[:, None]) # [BH, BN]
            exp_qk = tl.math.exp(qk - new_max[:, None]) # [BH, BN]
            sum_qk = tl.sum(exp_qk, axis=1) # [BH]
            # alpha rescales previously accumulated state to the new max
            # alpha = tl.math.exp2(max_log - new_max) # [BH]
            alpha = tl.math.exp(max_log - new_max) # [BH]
            sum_exp = sum_exp * alpha + sum_qk # [BH]
            accum = accum * alpha[:, None] # [BH, DP]
            accum = tl.dot(exp_qk, v_blck, accum) # [BH, DP]
            max_log = new_max
    # epilog
    # fin_log = max_log + tl.math.log2(sum_exp) # [BH]
    # fin_log *= 0.69314718
    # Final natural-log LSE and softmax-normalized output.
    fin_log = max_log + tl.math.log(sum_exp) # [BH]
    out_vals = accum / sum_exp[:, None] # [BH, DP]
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, out_vals, mask=mask_h[:, None] & mask_d[None, :])
    l_ptrs = l_base + offs_h * stride_lh
    tl.store(l_ptrs, fin_log, mask=mask_h)
@triton.autotune(
    configs=_sel_attn_bwd_preprocess_configs,
    key=['M', 'D', 'H'],
)
@triton.jit
def _sel_attn_bwd_preprocess_kernel(
    Out,     # forward output, strides describe (B, M, H, D)
    DOut,    # upstream gradient, same layout as Out
    Delta,   # result: Delta[b, h, m] = sum_d Out * DOut, (B, H, M)
    stride_ob, stride_oh, stride_om, stride_od,
    stride_dob, stride_doh, stride_dom, stride_dod,
    stride_db, stride_dh, stride_dm,
    B: tl.constexpr,          # batch size
    H: tl.constexpr,          # number of query heads
    M: tl.constexpr,          # query sequence length
    D: tl.constexpr,          # head dimension
    DP: tl.constexpr,         # D padded up to the next power of two
    BLOCK_M: tl.constexpr,    # query rows per program (autotuned)
):
    # Standard flash-attention backward preprocess: the per-row dot product
    # delta = <out, d_out> needed by the softmax gradient in the main kernel.
    # program indices
    m = tl.program_id(0)
    bh = tl.program_id(1)
    b = bh // H
    h = bh % H
    # Base pointers
    o_base = Out + b * stride_ob + h * stride_oh
    do_base = DOut + b * stride_dob + h * stride_doh
    # Offsets
    offs_m = m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, DP)
    o_ptrs = o_base + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    do_ptrs = do_base + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod
    mask = (offs_m[:, None] < M) & (offs_d[None, :] < D)
    o = tl.load(o_ptrs, mask=mask, other=0.0).to(tl.float32)
    do = tl.load(do_ptrs, mask=mask, other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    delta_ptr = Delta + b * stride_db + h * stride_dh + offs_m * stride_dm
    tl.store(delta_ptr, delta, mask=offs_m < M)
@triton.autotune(
    configs=_sel_attn_bwd_configs,
    key=['M', 'N', 'D', 'SELECTION_BLOCK_SIZE', 'T', 'HEADS_PER_GROUP', 'causal'],
    reset_to_zero=['DK', 'DV']
)
@triton.jit
def _sel_attn_bwd_kernel(
    Q: tl.tensor,            # queries, (B, M, H, D)
    K: tl.tensor,            # keys, (B, N, G, D)
    V: tl.tensor,            # values, (B, N, G, D)
    Top_idx: tl.tensor,      # selected block indices, (B, M, G, T)
    Lse: tl.tensor,          # forward log-sum-exp, (B, H, M), natural log
    DOut: tl.tensor,         # upstream gradient, (B, M, H, D)
    Delta: tl.tensor,        # rowwise sum(Out * DOut) from the preprocess kernel
    softmax_scale: tl.constexpr,
    causal: tl.constexpr,
    DQ: tl.tensor,           # out: query gradient (stored once per program)
    DK: tl.tensor,           # out: key gradient (accumulated via atomic_add)
    DV: tl.tensor,           # out: value gradient (accumulated via atomic_add)
    stride_qb, stride_qh, stride_qm, stride_qd,
    stride_kb, stride_kg, stride_kn, stride_kd,
    stride_vb, stride_vg, stride_vn, stride_vd,
    stride_tb, stride_tg, stride_tm, stride_tt,
    stride_ob, stride_oh, stride_om, stride_od,
    stride_lb, stride_lh, stride_lm,
    B: tl.constexpr,
    H: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    D: tl.constexpr,
    T: tl.constexpr,
    DP: tl.constexpr,        # D padded up to the next power of two
    SELECTION_BLOCK_SIZE: tl.constexpr,
    HEADS_PER_GROUP: tl.constexpr,
    OFFSET_M: tl.constexpr,  # first query row covered by the launch grid
    BLOCK_H: tl.constexpr,   # head-tile size (>= 16 so tl.dot is legal)
):
    # One program handles one (batch, query row, kv-group): it recomputes the
    # attention probabilities from the saved Lse and scatters dK/dV into the
    # shared buffers with atomic adds (multiple query rows hit the same block).
    # NOTE: Should we move this outside?
    stride_hg = stride_qh * HEADS_PER_GROUP
    b = tl.program_id(0)
    m = tl.program_id(1) + OFFSET_M
    g = tl.program_id(2)
    # Base pointers
    q_base = Q + b * stride_qb + m * stride_qm + g * stride_hg
    k_base = K + b * stride_kb + g * stride_kg
    v_base = V + b * stride_vb + g * stride_vg
    t_base = Top_idx + b * stride_tb + m * stride_tm + g * stride_tg
    l_base = Lse + b * stride_lb + m * stride_lm + g * stride_lh * HEADS_PER_GROUP
    do_base = DOut + b * stride_ob + m * stride_om + g * stride_hg
    # Delta shares Lse's (B, H, M) layout, so it reuses the lse strides.
    d_base = Delta + b * stride_lb + m * stride_lm + g * stride_lh * HEADS_PER_GROUP
    dq_base = DQ + b * stride_qb + m * stride_qm + g * stride_hg
    dk_base = DK + b * stride_kb + g * stride_kg
    dv_base = DV + b * stride_vb + g * stride_vg
    # Offsets
    offs_h = tl.arange(0, BLOCK_H)
    mask_h = offs_h < HEADS_PER_GROUP
    offs_d = tl.arange(0, DP)
    mask_d = offs_d < D
    offs_n = tl.arange(0, SELECTION_BLOCK_SIZE)
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_blck = tl.load(q_ptrs, mask=mask_h[:, None] & mask_d[None, :], other=0.0).to(tl.float32) # [BH, DP]
    do_ptrs = do_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    do_blck = tl.load(do_ptrs, mask=mask_h[:, None] & mask_d[None, :], other=0.0).to(tl.float32) # [BH, DP]
    l_ptrs = l_base + offs_h * stride_lh
    l_blck = tl.load(l_ptrs, mask=mask_h, other=0.0) # [BH]
    d_ptrs = d_base + offs_h * stride_lh
    d_blck = tl.load(d_ptrs, mask=mask_h, other=0.0) # [BH]
    accum = tl.zeros([BLOCK_H, DP], dtype=tl.float32)  # dQ accumulator
    # 1/ln(2) = 1.44269504
    # Base-2 exponentials here: exp2(x * log2(e)) == exp(x), so the recomputed
    # probabilities match the natural-exp softmax of the forward kernel.
    log_scale = softmax_scale * 1.44269504
    max_col = max(0, N - M + m) if causal else N
    for idx in range(T):
        # NOTE: Ideally we load top_idx outside the loop, this can be done with a gather which will
        # supported in future versions of Triton
        top = tl.load(t_base + idx * stride_tt)
        col = top * SELECTION_BLOCK_SIZE
        col = tl.multiple_of(col, SELECTION_BLOCK_SIZE)
        if not causal or col <= max_col:
            cols = col + offs_n
            mask_n = cols < N
            k_ptrs = k_base + cols[None, :] * stride_kn + offs_d[:, None] * stride_kd
            k_blck = tl.load(k_ptrs, mask=mask_d[:, None] & mask_n[None, :], other=0.0).to(tl.float32) # [DP, BN]
            qk = tl.dot(q_blck, k_blck) * log_scale
            causal_mask = cols <= max_col
            # Finite large-negative sentinel (not -inf) so exp2 below
            # underflows to zero without risking NaN.
            qk = tl.where(causal_mask[None, :], qk, -1e6)
            # p = exp2(qk*log2e - lse*log2e) == exp(qk*scale - lse): the
            # softmax probability recomputed from the saved lse.
            l2 = l_blck * 1.44269504
            exp_qk = tl.math.exp2(qk - l2[:, None]) # [BH, BN]
            # dV += p^T @ dO
            dv_inc = tl.dot(tl.trans(exp_qk), do_blck) # [BN, DP]
            dv_ptrs = dv_base + cols[:, None] * stride_vn + offs_d[None, :] * stride_vd
            # [BN, DP]
            tl.atomic_add(dv_ptrs, dv_inc.to(tl.float32), mask=mask_d[None, :] & mask_n[:, None], sem="release", scope="gpu")
            v_ptrs = v_base + cols[None, :] * stride_vn + offs_d[:, None] * stride_vd
            v_blck = tl.load(v_ptrs, mask=mask_d[:, None] & mask_n[None, :], other=0.0).to(tl.float32) # [D, BN]
            # dS = p * (dO @ V - delta) * scale — the softmax Jacobian term.
            dp = tl.dot(do_blck, v_blck) # [BH, BN]
            ds2 = exp_qk * (dp - d_blck[:, None]) # [BH, BN]
            ds = ds2 * softmax_scale
            # dQ accumulates locally; dK is scattered atomically like dV.
            accum = tl.dot(ds, tl.trans(k_blck), acc=accum) # [BH, DP]
            dk_inc = tl.dot(tl.trans(ds), q_blck) # [BN, DP]
            dk_ptrs = dk_base + cols[:, None] * stride_kn + offs_d[None, :] * stride_kd
            tl.atomic_add(dk_ptrs, dk_inc.to(tl.float32), mask=mask_d[None, :] & mask_n[:, None], sem="release", scope="gpu")
    dq_ptrs = dq_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    tl.store(dq_ptrs, accum, mask=mask_h[:, None] & mask_d[None, :])
class SelectionAttention(torch.autograd.Function):
    """Autograd wrapper for the Triton block-sparse "selected" attention kernels.

    Layouts (GQA):
        q:       (B, M, H, D)   query heads
        k, v:    (B, N, G, D)   kv groups; H must be divisible by G
        top_idx: (B, M, G, T)   indices of the selected kv blocks, in units of
                                `selection_block_size`

    forward returns out of shape (B, M, H, D) and, when `return_attn_probs`
    is True, also the per-head log-sum-exp of shape (B, H, M) (natural log,
    float32).
    """

    @staticmethod
    def forward(
        ctx,
        q, k, v, top_idx,
        selection_block_size,
        softmax_scale=None,
        causal=False,
        return_attn_probs=False
    ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        B, M, H, D = q.shape
        _, N, G, _ = k.shape
        _, _, _, T = top_idx.shape
        assert q.shape == (B, M, H, D)
        assert k.shape == (B, N, G, D)
        assert v.shape == (B, N, G, D)
        assert top_idx.shape == (B, M, G, T)
        if softmax_scale is None:
            softmax_scale = 1.0 / (D ** 0.5)
        # Zero/-inf init: rows skipped by the kernel grid (m < OFFSET_M under
        # causal masking) must still hold well-defined values.
        out = torch.zeros_like(q)
        lse = torch.full((B, H, M), float('-inf'), device=q.device, dtype=torch.float32)
        DP = triton.next_power_of_2(D)  # pad head dim to a power of two for tl.arange
        HEADS_PER_GROUP = H // G
        # Under causal masking with M > N the first M - N queries see no keys
        # at all, so the grid skips those rows entirely.
        OFFSET_M = max(0, M - N) if causal else 0
        BLOCK_H = max(16, HEADS_PER_GROUP)  # tl.dot needs tiles of at least 16
        grid = (B, M - OFFSET_M, G)
        _sel_attn_fwd_kernel[grid](
            q, k, v, top_idx,
            softmax_scale, causal,
            out, lse,
            q.stride(0), q.stride(2), q.stride(1), q.stride(3),
            k.stride(0), k.stride(2), k.stride(1), k.stride(3),
            v.stride(0), v.stride(2), v.stride(1), v.stride(3),
            top_idx.stride(0), top_idx.stride(2), top_idx.stride(1), top_idx.stride(3),
            out.stride(0), out.stride(2), out.stride(1), out.stride(3),
            lse.stride(0), lse.stride(1), lse.stride(2),
            B, H, M, N, D, T, DP,
            SELECTION_BLOCK_SIZE=selection_block_size,
            HEADS_PER_GROUP=HEADS_PER_GROUP,
            OFFSET_M=OFFSET_M,
            BLOCK_H=BLOCK_H,
        )
        ctx.save_for_backward(q, k, v, top_idx, out, lse)
        ctx.selection_block_size = selection_block_size
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_attn_probs:
            return out, lse
        else:
            return out

    @staticmethod
    def backward(ctx, *grad_outputs):
        # grad_outputs[0] is d(out); any gradient w.r.t. the lse output (when
        # return_attn_probs=True) is deliberately ignored.
        d_out = grad_outputs[0]
        q, k, v, top_idx, out, lse = ctx.saved_tensors
        B, M, H, D = q.shape
        _, N, G, _ = k.shape
        _, _, _, T = top_idx.shape
        assert d_out.shape == (B, M, H, D)
        selection_block_size = ctx.selection_block_size
        softmax_scale = ctx.softmax_scale
        causal = ctx.causal
        # delta[b, h, m] = sum_d out * d_out, filled by the preprocess kernel
        # over all M rows (its grid does not use OFFSET_M).
        delta = torch.empty_like(lse)
        DP = triton.next_power_of_2(D)
        HEADS_PER_GROUP = H // G
        OFFSET_M = max(0, M - N) if causal else 0
        BLOCK_H = max(16, HEADS_PER_GROUP)

        def grid_preprocess(META):
            return (triton.cdiv(M, META['BLOCK_M']), B * H)

        _sel_attn_bwd_preprocess_kernel[grid_preprocess](
            out, d_out, delta,
            out.stride(0), out.stride(2), out.stride(1), out.stride(3),
            d_out.stride(0), d_out.stride(2), d_out.stride(1), d_out.stride(3),
            delta.stride(0), delta.stride(1), delta.stride(2),
            B, H, M, D, DP,
        )
        # FIX: dq must be zero-initialized. The main backward grid starts at
        # OFFSET_M, so with causal=True and M > N the first OFFSET_M query rows
        # are never written by the kernel; torch.empty_like would leave
        # uninitialized garbage in their gradients.
        dq = torch.zeros_like(q, dtype=q.dtype)
        # dk/dv are accumulated with atomic adds across query programs, so they
        # must start at zero; float32 keeps the accumulation stable.
        dk = torch.zeros_like(k, dtype=torch.float32)
        dv = torch.zeros_like(v, dtype=torch.float32)
        grid_bwd = (B, M - OFFSET_M, G)
        _sel_attn_bwd_kernel[grid_bwd](
            q, k, v, top_idx, lse,
            d_out, delta,
            softmax_scale, causal,
            dq, dk, dv,
            q.stride(0), q.stride(2), q.stride(1), q.stride(3),
            k.stride(0), k.stride(2), k.stride(1), k.stride(3),
            v.stride(0), v.stride(2), v.stride(1), v.stride(3),
            top_idx.stride(0), top_idx.stride(2), top_idx.stride(1), top_idx.stride(3),
            out.stride(0), out.stride(2), out.stride(1), out.stride(3),
            lse.stride(0), lse.stride(1), lse.stride(2),
            B, H, M, N, D, T, DP,
            SELECTION_BLOCK_SIZE=selection_block_size,
            HEADS_PER_GROUP=HEADS_PER_GROUP,
            OFFSET_M=OFFSET_M,
            BLOCK_H=BLOCK_H,
        )
        # One gradient slot per forward input; the non-tensor arguments
        # (block size, scale, flags) get None.
        return dq, dk.to(k.dtype), dv.to(v.dtype), None, None, None, None, None
def selection_attention(
    q, k, v,
    block_indices, block_count, block_size, scale,
    variant='two-pass', # 'one-pass' or 'two-pass'
    causal=True,
    return_attn_probs=False
):
    """Dispatch block-sparse selection attention to one of two backends.

    'one-pass' uses the in-file Triton SelectionAttention kernels (honors
    `causal` and `return_attn_probs`); 'two-pass' delegates to the FLA
    ParallelNSAFunction backend. Any other value raises ValueError.
    """
    if variant == 'two-pass':
        # FLA Backend for two-pass selection attention
        return ParallelNSAFunction.apply(
            q, k, v, block_indices, block_count, block_size, scale, None
        )
    if variant == 'one-pass':
        return SelectionAttention.apply(
            q, k, v, block_indices, block_size, scale, causal, return_attn_probs
        )
    raise ValueError(f"Invalid variant: {variant}")
def nsa_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g_cmp: Optional[torch.Tensor] = None,
    g_slc: Optional[torch.Tensor] = None,
    g_swa: Optional[torch.Tensor] = None,
    block_count: int = 16,
    block_size: int = 64,
    window_size: int = 0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Native Sparse Attention: gated mix of compression, selection and
    sliding-window attention branches.

    Args:
        q: (B, M, H, D) queries.
        k, v: (B, N, G, D) keys/values (GQA; H divisible by G).
        g_cmp, g_slc, g_swa: (B, M, H) per-head gates for the three branches;
            all three are required.
        block_count: number of kv blocks each query selects.
        block_size: compression/selection block size.
            NOTE(review): the block-mask construction below uses N//block_size,
            which assumes N is a multiple of block_size — confirm upstream.
        window_size: sliding-window width. NOTE(review): the default 0 yields
            window_size=(-1, 0), which flash_attn treats as an unbounded left
            window (plain causal attention) — confirm that is intended.
        scale: softmax scale; defaults to D ** -0.5.

    Returns:
        (B, M, H, D) gated combination of the three branch outputs.
    """
    B, M, H, D = q.shape
    _, N, G, _ = k.shape
    assert g_cmp is not None and g_slc is not None and g_swa is not None, "g_cmp, g_slc, and g_swa are required"
    assert k.shape == (B, N, G, D), f"k shape: {k.shape} must be ({B}, {N}, {G}, {D})"
    assert v.shape == (B, N, G, D), f"v shape: {v.shape} must be ({B}, {N}, {G}, {D})"
    assert g_cmp.shape == (B, M, H), f"g_cmp shape: {g_cmp.shape} must be ({B}, {M}, {H})"
    assert g_slc.shape == (B, M, H), f"g_slc shape: {g_slc.shape} must be ({B}, {M}, {H})"
    assert g_swa.shape == (B, M, H), f"g_swa shape: {g_swa.shape} must be ({B}, {M}, {H})"
    if scale is None:
        scale = D ** -0.5
    # Mean-pool keys/values into blocks for the compression branch.
    k_cmp, v_cmp = mean_pooling(k, block_size), mean_pooling(v, block_size)

    def cmp_mask(b, h, q_idx, kv_idx):
        # FIX: causal mask over pooled kv blocks. Query q_idx may attend
        # compressed block kv_idx only once the block's last covered position,
        # (kv_idx + 1) * block_size - 1, is not in the future. The original
        # comparison (q_idx <= ...) was inverted: it permitted exactly the
        # current/future blocks and forbade all past ones, leaking future
        # information into the compression branch.
        return q_idx >= (kv_idx + 1) * block_size - 1
    block_mask = create_block_mask(cmp_mask, B, H, M, N//block_size)
    o_cmp, lse_cmp = compression_attention(q, k_cmp, v_cmp, block_mask)
    # Use the compression-branch scores to pick top blocks for selection.
    # NOTE(review): assumes flex_attention's lse convention matches what
    # parallel_nsa_topk expects — verify against fla's implementation.
    block_indices = parallel_nsa_topk(
        q=q,
        k=k_cmp,
        lse=lse_cmp,
        block_counts=block_count,
        block_size=block_size,
        scale=scale,
        cu_seqlens=None
    )
    o_slc = selection_attention(
        q, k, v, block_indices, block_count, block_size, scale
    )
    # Sliding-window (causal) branch.
    o_swd = flash_attn_func(
        q, k, v,
        causal=True,
        window_size=(window_size-1, 0)
    )
    # Gated combination of the three branches, gates broadcast over D.
    o = o_cmp * g_cmp.unsqueeze(-1) + o_slc * g_slc.unsqueeze(-1) + o_swd * g_swa.unsqueeze(-1)
    return o