from typing import Optional

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
flex_attention = torch.compile(flex_attention)
from flash_attn import flash_attn_func
import triton
import triton.language as tl
import triton.testing
from fla.ops.nsa.parallel import ParallelNSAFunction, parallel_nsa_topk
from fla.ops.utils.pooling import mean_pooling


def compression_attention(q, k_cmp, v_cmp, block_mask):
    """Compressed-branch attention. Returns both the output and the LSE so the
    selection branch can score candidate blocks without recomputing logits."""
    o_cmp, lse_cmp = flex_attention(
        q.transpose(1, 2),       # [B, M, H, D] -> [B, H, M, D]
        k_cmp.transpose(1, 2),
        v_cmp.transpose(1, 2),
        block_mask=block_mask,
        enable_gqa=True,
        return_lse=True,
    )
    return o_cmp.transpose(1, 2), lse_cmp
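
# The LSE returned above is consumed by `parallel_nsa_topk` in `nsa_func` below:
# it normalizes per-block attention mass so compressed-block scores stay
# comparable across blocks when picking the top-T selection blocks.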

# Autotune configurations for the forward kernel
_sel_attn_fwd_configs = [
    triton.Config({}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8]
]

# Autotune configurations for the backward preprocess kernel.
# num_stages/num_warps belong in the Config's launch options, not in the
# meta-parameter dict where the original placed them.
_sel_attn_bwd_preprocess_configs = [
    triton.Config({'BLOCK_M': block_m}, num_stages=num_stages, num_warps=num_warps, num_ctas=1)
    for block_m, num_stages, num_warps in [
        (16, 1, 4), (32, 1, 4),
        (16, 2, 4), (32, 2, 4),
        (16, 1, 8), (32, 1, 8),
    ]
]

# Autotune configurations for the main backward kernel
_sel_attn_bwd_configs = [
    triton.Config({}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8]
]


# The decorators below were missing; the configs above are unused and the
# kernel cannot be launched without @triton.jit. The autotune key is an
# editorial assumption: retune whenever the problem shape changes.
@triton.autotune(configs=_sel_attn_fwd_configs, key=['M', 'N', 'D', 'T'])
@triton.jit
def _sel_attn_fwd_kernel(
    Q: tl.tensor,
    K: tl.tensor,
    V: tl.tensor,
    Top_idx: tl.tensor,
    softmax_scale: tl.constexpr,
    causal: tl.constexpr,
    Out: tl.tensor,
    Lse: tl.tensor,
    stride_qb, stride_qh, stride_qm, stride_qd,
    stride_kb, stride_kg, stride_kn, stride_kd,
    stride_vb, stride_vg, stride_vn, stride_vd,
    stride_tb, stride_tg, stride_tm, stride_tt,
    stride_ob, stride_oh, stride_om, stride_od,
    stride_lb, stride_lh, stride_lm,
    B: tl.constexpr,
    H: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    D: tl.constexpr,
    T: tl.constexpr,
    DP: tl.constexpr,
    SELECTION_BLOCK_SIZE: tl.constexpr,
    HEADS_PER_GROUP: tl.constexpr,
    OFFSET_M: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    stride_hg = stride_qh * HEADS_PER_GROUP
    b = tl.program_id(0)
    m = tl.program_id(1) + OFFSET_M
    g = tl.program_id(2)
    # Base pointers
    q_base = Q + b * stride_qb + m * stride_qm + g * stride_hg
    k_base = K + b * stride_kb + g * stride_kg
    v_base = V + b * stride_vb + g * stride_vg
    t_base = Top_idx + b * stride_tb + m * stride_tm + g * stride_tg
    o_base = Out + b * stride_ob + m * stride_om + g * stride_hg
    l_base = Lse + b * stride_lb + m * stride_lm + g * stride_lh * HEADS_PER_GROUP
    # Offsets
    offs_h = tl.arange(0, BLOCK_H)
    mask_h = offs_h < HEADS_PER_GROUP
    offs_d = tl.arange(0, DP)
    mask_d = offs_d < D
    offs_n = tl.arange(0, SELECTION_BLOCK_SIZE)
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_blck = tl.load(q_ptrs, mask=mask_h[:, None] & mask_d[None, :], other=0.0)  # keep in float16 for the dot
    max_log = tl.full([BLOCK_H], float('-inf'), dtype=tl.float32)
    # Start the running denominator at 1.0 so the epilogue never divides by
    # zero when no block survives the causal filter.
    sum_exp = tl.full([BLOCK_H], 1.0, dtype=tl.float32)
    accum = tl.zeros([BLOCK_H, DP], dtype=tl.float32)
    # Python max() does not trace on Triton tensors; use tl.maximum.
    max_col = tl.maximum(0, N - M + m) if causal else N
    for idx in range(T):
        # NOTE: Ideally we load top_idx outside the loop; this can be done with
        # a gather, which will be supported in future versions of Triton.
        top = tl.load(t_base + idx * stride_tt)
        col = top * SELECTION_BLOCK_SIZE
        col = tl.multiple_of(col, SELECTION_BLOCK_SIZE)
        # `&` instead of Python `and`, which cannot combine Triton tensors.
        if not causal or ((col >= 0) & (col <= max_col)):
            cols = col + offs_n
            mask_n = cols < N
            k_ptrs = k_base + offs_d[:, None] * stride_kd + cols[None, :] * stride_kn
            k_blck = tl.load(k_ptrs, mask=mask_d[:, None] & mask_n[None, :], other=0.0)
            v_ptrs = v_base + cols[:, None] * stride_vn + offs_d[None, :] * stride_vd
            v_blck = tl.load(v_ptrs, mask=mask_d[None, :] & mask_n[:, None], other=0.0).to(tl.float32)
            # NOTE: The multiplication by softmax_scale could move outside the loop.
            qk = tl.dot(q_blck, k_blck) * softmax_scale  # [BH, BN]
            # Mask both future positions and out-of-range columns; the original
            # only masked the former, letting zero-padded columns leak into the
            # non-causal softmax.
            valid_mask = (cols <= max_col) & mask_n
            qk = tl.where(valid_mask[None, :], qk, float('-inf'))
            # stable streaming log-sum-exp
            new_max = tl.maximum(max_log, tl.max(qk, axis=1))  # [BH]
            exp_qk = tl.math.exp(qk - new_max[:, None])        # [BH, BN]
            sum_qk = tl.sum(exp_qk, axis=1)                    # [BH]
            alpha = tl.math.exp(max_log - new_max)             # [BH]
            sum_exp = sum_exp * alpha + sum_qk                 # [BH]
            accum = accum * alpha[:, None]                     # [BH, DP]
            accum = tl.dot(exp_qk, v_blck, accum)              # [BH, DP]
            max_log = new_max
    # epilogue
    fin_log = max_log + tl.math.log(sum_exp)  # [BH]
    out_vals = accum / sum_exp[:, None]       # [BH, DP]
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, out_vals, mask=mask_h[:, None] & mask_d[None, :])
    l_ptrs = l_base + offs_h * stride_lh
    tl.store(l_ptrs, fin_log, mask=mask_h)
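
# For reference, the streaming-softmax recurrence the loop above implements,
# per query row with running max m, denominator s, and accumulator a:
#   m' = max(m, max_j qk_j)
#   s' = s * exp(m - m') + sum_j exp(qk_j - m')
#   a' = a * exp(m - m') + exp(qk - m') @ V
# yielding out = a / s and lse = m + log(s) in the epilogue.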


@triton.autotune(configs=_sel_attn_bwd_preprocess_configs, key=['M', 'D'])
@triton.jit
def _sel_attn_bwd_preprocess_kernel(
    Out,
    DOut,
    Delta,
    stride_ob, stride_oh, stride_om, stride_od,
    stride_dob, stride_doh, stride_dom, stride_dod,
    stride_db, stride_dh, stride_dm,
    B: tl.constexpr,
    H: tl.constexpr,
    M: tl.constexpr,
    D: tl.constexpr,
    DP: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # program indices
    m = tl.program_id(0)
    bh = tl.program_id(1)
    b = bh // H
    h = bh % H
    # Base pointers
    o_base = Out + b * stride_ob + h * stride_oh
    do_base = DOut + b * stride_dob + h * stride_doh
    # Offsets
    offs_m = m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, DP)
    o_ptrs = o_base + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    do_ptrs = do_base + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod
    mask = (offs_m[:, None] < M) & (offs_d[None, :] < D)
    o = tl.load(o_ptrs, mask=mask, other=0.0).to(tl.float32)
    do = tl.load(do_ptrs, mask=mask, other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    delta_ptr = Delta + b * stride_db + h * stride_dh + offs_m * stride_dm
    tl.store(delta_ptr, delta, mask=offs_m < M)
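
# Delta[b, h, m] = sum_d(Out * dOut) is the standard attention-backward
# precomputation: with it, dS = P * (dP - Delta) can be formed blockwise in the
# main backward kernel without materializing the softmax probabilities.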


# restore_value restores the atomic accumulators to their pre-launch contents
# between benchmark runs so autotuning does not corrupt dK/dV (assumes a Triton
# version with this autotuner option). The key is the same assumption as above.
@triton.autotune(configs=_sel_attn_bwd_configs, key=['M', 'N', 'D', 'T'], restore_value=['DK', 'DV'])
@triton.jit
def _sel_attn_bwd_kernel(
    Q: tl.tensor,
    K: tl.tensor,
    V: tl.tensor,
    Top_idx: tl.tensor,
    Lse: tl.tensor,
    DOut: tl.tensor,
    Delta: tl.tensor,
    softmax_scale: tl.constexpr,
    causal: tl.constexpr,
    DQ: tl.tensor,
    DK: tl.tensor,
    DV: tl.tensor,
    stride_qb, stride_qh, stride_qm, stride_qd,
    stride_kb, stride_kg, stride_kn, stride_kd,
    stride_vb, stride_vg, stride_vn, stride_vd,
    stride_tb, stride_tg, stride_tm, stride_tt,
    stride_ob, stride_oh, stride_om, stride_od,
    stride_lb, stride_lh, stride_lm,
    B: tl.constexpr,
    H: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    D: tl.constexpr,
    T: tl.constexpr,
    DP: tl.constexpr,
    SELECTION_BLOCK_SIZE: tl.constexpr,
    HEADS_PER_GROUP: tl.constexpr,
    OFFSET_M: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    # NOTE: Should we move this outside?
    stride_hg = stride_qh * HEADS_PER_GROUP
    b = tl.program_id(0)
    m = tl.program_id(1) + OFFSET_M
    g = tl.program_id(2)
    # Base pointers
    q_base = Q + b * stride_qb + m * stride_qm + g * stride_hg
    k_base = K + b * stride_kb + g * stride_kg
    v_base = V + b * stride_vb + g * stride_vg
    t_base = Top_idx + b * stride_tb + m * stride_tm + g * stride_tg
    l_base = Lse + b * stride_lb + m * stride_lm + g * stride_lh * HEADS_PER_GROUP
    do_base = DOut + b * stride_ob + m * stride_om + g * stride_hg
    d_base = Delta + b * stride_lb + m * stride_lm + g * stride_lh * HEADS_PER_GROUP
    dq_base = DQ + b * stride_qb + m * stride_qm + g * stride_hg
    dk_base = DK + b * stride_kb + g * stride_kg
    dv_base = DV + b * stride_vb + g * stride_vg
    # Offsets
    offs_h = tl.arange(0, BLOCK_H)
    mask_h = offs_h < HEADS_PER_GROUP
    offs_d = tl.arange(0, DP)
    mask_d = offs_d < D
    offs_n = tl.arange(0, SELECTION_BLOCK_SIZE)
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_blck = tl.load(q_ptrs, mask=mask_h[:, None] & mask_d[None, :], other=0.0).to(tl.float32)  # [BH, DP]
    do_ptrs = do_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    do_blck = tl.load(do_ptrs, mask=mask_h[:, None] & mask_d[None, :], other=0.0).to(tl.float32)  # [BH, DP]
    l_ptrs = l_base + offs_h * stride_lh
    l_blck = tl.load(l_ptrs, mask=mask_h, other=0.0)  # [BH]
    d_ptrs = d_base + offs_h * stride_lh
    d_blck = tl.load(d_ptrs, mask=mask_h, other=0.0)  # [BH]
    accum = tl.zeros([BLOCK_H, DP], dtype=tl.float32)
    # 1/ln(2) = 1.44269504: work in base 2 so exp2 can be used below.
    log_scale = softmax_scale * 1.44269504
    max_col = tl.maximum(0, N - M + m) if causal else N
    for idx in range(T):
        # NOTE: Ideally we load top_idx outside the loop; this can be done with
        # a gather, which will be supported in future versions of Triton.
        top = tl.load(t_base + idx * stride_tt)
        col = top * SELECTION_BLOCK_SIZE
        col = tl.multiple_of(col, SELECTION_BLOCK_SIZE)
        if not causal or col <= max_col:
            cols = col + offs_n
            mask_n = cols < N
            k_ptrs = k_base + cols[None, :] * stride_kn + offs_d[:, None] * stride_kd
            k_blck = tl.load(k_ptrs, mask=mask_d[:, None] & mask_n[None, :], other=0.0).to(tl.float32)  # [DP, BN]
            qk = tl.dot(q_blck, k_blck) * log_scale
            # Mask future and out-of-range columns (mirrors the forward kernel).
            valid_mask = (cols <= max_col) & mask_n
            qk = tl.where(valid_mask[None, :], qk, -1e6)
            l2 = l_blck * 1.44269504
            exp_qk = tl.math.exp2(qk - l2[:, None])  # [BH, BN], recovers P = exp(s - lse)
            dv_inc = tl.dot(tl.trans(exp_qk), do_blck)  # [BN, DP]
            dv_ptrs = dv_base + cols[:, None] * stride_vn + offs_d[None, :] * stride_vd
            tl.atomic_add(dv_ptrs, dv_inc.to(tl.float32), mask=mask_d[None, :] & mask_n[:, None], sem="release", scope="gpu")
            v_ptrs = v_base + cols[None, :] * stride_vn + offs_d[:, None] * stride_vd
            v_blck = tl.load(v_ptrs, mask=mask_d[:, None] & mask_n[None, :], other=0.0).to(tl.float32)  # [DP, BN]
            dp = tl.dot(do_blck, v_blck)  # [BH, BN]
            ds = exp_qk * (dp - d_blck[:, None]) * softmax_scale  # [BH, BN]
            accum = tl.dot(ds, tl.trans(k_blck), acc=accum)  # [BH, DP]
            dk_inc = tl.dot(tl.trans(ds), q_blck)  # [BN, DP]
            dk_ptrs = dk_base + cols[:, None] * stride_kn + offs_d[None, :] * stride_kd
            tl.atomic_add(dk_ptrs, dk_inc.to(tl.float32), mask=mask_d[None, :] & mask_n[:, None], sem="release", scope="gpu")
    dq_ptrs = dq_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    tl.store(dq_ptrs, accum, mask=mask_h[:, None] & mask_d[None, :])
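
# Design note: each program owns one (batch, query, group) tile, so dQ rows are
# written exactly once, while dK/dV tiles are shared across queries that select
# the same block and therefore accumulate through fp32 atomic adds.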


class SelectionAttention(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        q, k, v, top_idx,
        selection_block_size,
        softmax_scale=None,
        causal=False,
        return_attn_probs=False,
    ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        B, M, H, D = q.shape
        _, N, G, _ = k.shape
        _, _, _, T = top_idx.shape
        assert q.shape == (B, M, H, D)
        assert k.shape == (B, N, G, D)
        assert v.shape == (B, N, G, D)
        assert top_idx.shape == (B, M, G, T)
        if softmax_scale is None:
            softmax_scale = 1.0 / (D ** 0.5)
        # NOTE: Is it faster to only set the untouched elements?
        # Rows skipped by OFFSET_M must stay zero (out) / -inf (lse).
        out = torch.zeros_like(q)
        lse = torch.full((B, H, M), float('-inf'), device=q.device, dtype=torch.float32)
        DP = triton.next_power_of_2(D)
        HEADS_PER_GROUP = H // G
        OFFSET_M = max(0, M - N) if causal else 0
        BLOCK_H = max(16, HEADS_PER_GROUP)
        grid = (B, M - OFFSET_M, G)
        _sel_attn_fwd_kernel[grid](
            q, k, v, top_idx,
            softmax_scale, causal,
            out, lse,
            q.stride(0), q.stride(2), q.stride(1), q.stride(3),
            k.stride(0), k.stride(2), k.stride(1), k.stride(3),
            v.stride(0), v.stride(2), v.stride(1), v.stride(3),
            top_idx.stride(0), top_idx.stride(2), top_idx.stride(1), top_idx.stride(3),
            out.stride(0), out.stride(2), out.stride(1), out.stride(3),
            lse.stride(0), lse.stride(1), lse.stride(2),
            B, H, M, N, D, T, DP,
            SELECTION_BLOCK_SIZE=selection_block_size,
            HEADS_PER_GROUP=HEADS_PER_GROUP,
            OFFSET_M=OFFSET_M,
            BLOCK_H=BLOCK_H,
        )
        ctx.save_for_backward(q, k, v, top_idx, out, lse)
        ctx.selection_block_size = selection_block_size
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_attn_probs:
            return out, lse
        return out

    @staticmethod
    def backward(ctx, *grad_outputs):
        d_out = grad_outputs[0]
        q, k, v, top_idx, out, lse = ctx.saved_tensors
        B, M, H, D = q.shape
        _, N, G, _ = k.shape
        _, _, _, T = top_idx.shape
        assert d_out.shape == (B, M, H, D)
        selection_block_size = ctx.selection_block_size
        softmax_scale = ctx.softmax_scale
        causal = ctx.causal
        delta = torch.empty_like(lse)
        DP = triton.next_power_of_2(D)
        HEADS_PER_GROUP = H // G
        OFFSET_M = max(0, M - N) if causal else 0
        BLOCK_H = max(16, HEADS_PER_GROUP)

        def grid_preprocess(META):
            return (triton.cdiv(M, META['BLOCK_M']), B * H)

        _sel_attn_bwd_preprocess_kernel[grid_preprocess](
            out, d_out, delta,
            out.stride(0), out.stride(2), out.stride(1), out.stride(3),
            d_out.stride(0), d_out.stride(2), d_out.stride(1), d_out.stride(3),
            delta.stride(0), delta.stride(1), delta.stride(2),
            B, H, M, D, DP,
        )
        dq = torch.empty_like(q)
        # dk/dv are accumulated with atomics, so they must start zeroed and
        # stay in fp32 for precision under atomic accumulation.
        dk = torch.zeros_like(k, dtype=torch.float32)
        dv = torch.zeros_like(v, dtype=torch.float32)
        grid_bwd = (B, M - OFFSET_M, G)
        _sel_attn_bwd_kernel[grid_bwd](
            q, k, v, top_idx, lse,
            d_out, delta,
            softmax_scale, causal,
            dq, dk, dv,
            q.stride(0), q.stride(2), q.stride(1), q.stride(3),
            k.stride(0), k.stride(2), k.stride(1), k.stride(3),
            v.stride(0), v.stride(2), v.stride(1), v.stride(3),
            top_idx.stride(0), top_idx.stride(2), top_idx.stride(1), top_idx.stride(3),
            out.stride(0), out.stride(2), out.stride(1), out.stride(3),
            lse.stride(0), lse.stride(1), lse.stride(2),
            B, H, M, N, D, T, DP,
            SELECTION_BLOCK_SIZE=selection_block_size,
            HEADS_PER_GROUP=HEADS_PER_GROUP,
            OFFSET_M=OFFSET_M,
            BLOCK_H=BLOCK_H,
        )
        return dq, dk.to(k.dtype), dv.to(v.dtype), None, None, None, None, None


def selection_attention(
    q, k, v,
    block_indices, block_count, block_size, scale,
    variant='two-pass',  # 'one-pass': custom Triton kernel above; 'two-pass': FLA backend
    causal=True,
    return_attn_probs=False,
):
    if variant == 'one-pass':
        return SelectionAttention.apply(
            q, k, v, block_indices, block_size, scale, causal, return_attn_probs
        )
    elif variant == 'two-pass':
        # FLA backend for two-pass selection attention
        return ParallelNSAFunction.apply(
            q, k, v, block_indices, block_count, block_size, scale, None
        )
    else:
        raise ValueError(f"Invalid variant: {variant}")
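
# Illustrative usage (shapes as in `nsa_func` below): with q [B, M, H, D],
# k/v [B, N, G, D], and block_indices [B, M, G, T] from parallel_nsa_topk,
# either variant returns the selected-attention output [B, M, H, D]; only
# 'one-pass' can additionally return the LSE via return_attn_probs.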


def nsa_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g_cmp: Optional[torch.Tensor] = None,
    g_slc: Optional[torch.Tensor] = None,
    g_swa: Optional[torch.Tensor] = None,
    block_count: int = 16,
    block_size: int = 64,
    window_size: int = 0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    B, M, H, D = q.shape
    _, N, G, _ = k.shape
    assert g_cmp is not None and g_slc is not None and g_swa is not None, "g_cmp, g_slc, and g_swa are required"
    assert k.shape == (B, N, G, D), f"k shape: {k.shape} must be ({B}, {N}, {G}, {D})"
    assert v.shape == (B, N, G, D), f"v shape: {v.shape} must be ({B}, {N}, {G}, {D})"
    assert g_cmp.shape == (B, M, H), f"g_cmp shape: {g_cmp.shape} must be ({B}, {M}, {H})"
    assert g_slc.shape == (B, M, H), f"g_slc shape: {g_slc.shape} must be ({B}, {M}, {H})"
    assert g_swa.shape == (B, M, H), f"g_swa shape: {g_swa.shape} must be ({B}, {M}, {H})"
    if scale is None:
        scale = D ** -0.5
    # Branch 1: compressed attention over mean-pooled blocks.
    k_cmp, v_cmp = mean_pooling(k, block_size), mean_pooling(v, block_size)

    def cmp_mask(b, h, q_idx, kv_idx):
        # A query may attend a compressed block only once the block is fully in
        # the past (the original comparison pointed the causal mask the wrong way).
        return q_idx >= (kv_idx + 1) * block_size - 1

    block_mask = create_block_mask(cmp_mask, B, H, M, N // block_size)
    o_cmp, lse_cmp = compression_attention(q, k_cmp, v_cmp, block_mask)
    # Branch 2: top-T block selection, scored with the compressed-branch LSE.
    block_indices = parallel_nsa_topk(
        q=q,
        k=k_cmp,
        lse=lse_cmp,
        block_counts=block_count,
        block_size=block_size,
        scale=scale,
        cu_seqlens=None,
    )
    o_slc = selection_attention(
        q, k, v, block_indices, block_count, block_size, scale
    )
    # Branch 3: sliding-window attention. The default window_size=0 yields
    # (-1, 0), i.e. full causal attention in flash_attn.
    o_swa = flash_attn_func(
        q, k, v,
        causal=True,
        window_size=(window_size - 1, 0),
    )
    # Gated combination of the three branches.
    o = o_cmp * g_cmp.unsqueeze(-1) + o_slc * g_slc.unsqueeze(-1) + o_swa * g_swa.unsqueeze(-1)
    return o
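

if __name__ == "__main__":
    # Minimal smoke test; the shapes below are illustrative assumptions, and
    # running it requires a CUDA device with flash-attn and fla installed.
    B, M, H, G, D = 1, 256, 4, 1, 64
    q = torch.randn(B, M, H, D, device='cuda', dtype=torch.float16)
    k = torch.randn(B, M, G, D, device='cuda', dtype=torch.float16)
    v = torch.randn(B, M, G, D, device='cuda', dtype=torch.float16)
    # One sigmoid gate per branch: compression, selection, sliding window.
    g_cmp, g_slc, g_swa = (torch.sigmoid(torch.randn(B, M, H, device='cuda')) for _ in range(3))
    o = nsa_func(q, k, v, g_cmp, g_slc, g_swa, block_count=4, block_size=64, window_size=64)
    print(o.shape)  # expected: torch.Size([1, 256, 4, 64])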