# Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Any, Optional import torch import triton import triton.language as tl try: from .utils import get_num_warps_stages, is_hopper_gpu except ImportError: from ops.utils import get_num_warps_stages, is_hopper_gpu IS_HOPPER_GPU = is_hopper_gpu() @triton.autotune( configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]], key=['HEAD_DIM', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D', 'BLOCK_SIZE_H', 'BLOCK_SIZE_T'], ) @triton.jit def forward_kernel_orig( q_ptr, # Q: n x h x d k_ptr, # K: n x kh x d v_ptr, # V: n x kh x d t_ptr, # topk_idx: kh x n x k o_ptr, # O: n x h x d lse_ptr, # LSE: h x n # seqlens cu_seqlens_q, cu_seqlens_k, # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, HEAD_DIM, TOPK, block_size, # sm_scale sm_scale, # stride stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, stride_th, stride_tn, stride_tk, stride_on, stride_oh, stride_od, stride_lh, stride_ln, # META parameters # q loop num num_q_loop: tl.constexpr, num_k_loop: tl.constexpr, MAX_SEQ_LEN: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_D: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_T: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid = tl.program_id(0) Q = MAX_SEQ_LEN // num_q_loop HK = NUM_KV_HEADS // num_k_loop # decompose pid into (batch, kv-head chunk, query chunk) indices pid_b = pid // (HK * Q) pid_kh_chunk = (pid % (HK * Q)) // Q # each program handles num_k_loop KV heads pid_q = pid % Q # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start if pid_q * num_q_loop >= q_len: return real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop) for kh_offset in range(num_k_loop): pid_kh = pid_kh_chunk * num_k_loop + kh_offset pid_h = pid_kh * NUM_SHARE_Q_HEADS for j in range(real_q_loop): pid_q_j = pid_q * num_q_loop + j # init topk idx pointer off_t = tl.arange(0, BLOCK_SIZE_T) t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) # count valid topk blocks: ignore -1 padding and blocks beyond this query's causal limit real_topk = tl.sum( tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // block_size), 1, 0), axis=0, ) # init qkv pointer q_ptrs = tl.make_block_ptr( base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), strides=(stride_qh, stride_qd), offsets=(0, 0), block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), order=(1, 0), ) k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(HEAD_DIM, k_len), strides=(stride_kd, stride_kn), offsets=(0, 0), block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), order=(0, 1), ) v_ptrs = 
tl.make_block_ptr( base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, shape=(k_len, HEAD_DIM), strides=(stride_vn, stride_vd), offsets=(0, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) # load q q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") # init statistics off_h = tl.arange(0, BLOCK_SIZE_H) off_k = tl.arange(0, BLOCK_SIZE_K) m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32) # sparse attention for i in range(real_topk): # get current block start index c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K t_ptr_j = t_ptr_j + stride_tk # load k k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero") # compute qk qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf")) # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K] qk += tl.dot(q, k) * qk_scale # compute m_ij and l_ij m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) p = tl.exp2(qk - m_ij[:, None]) l_ij = tl.sum(p, axis=1) # scale acc_o acc_o_scale = tl.exp2(m_i - m_ij) acc_o = acc_o * acc_o_scale[:, None] # load v and update acc_o v = tl.load(tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero") p = p.to(v.dtype) acc_o += tl.dot(p, v) # update statistics m_i = m_ij lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) # final scale acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] # save output o_ptrs = tl.make_block_ptr( base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh, shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), strides=(stride_oh, stride_od), offsets=(0, 0), block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), order=(1, 0), ) tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) # save lse lse_ptrs = lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS) @triton.autotune( configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]], key=['HEAD_DIM', 'BLOCK_SIZE_O', 'BLOCK_SIZE_D'], ) @triton.jit def backward_sum_o_do( o_ptr, # O: n x h x d do_ptr, # dO: n x h x d delta_ptr, # D: h x n o_len, HEAD_DIM, stride_on, stride_oh, stride_od, stride_don, stride_doh, stride_dod, stride_dh, stride_dn, BLOCK_SIZE_O: tl.constexpr, BLOCK_SIZE_D: tl.constexpr, ): pid_n = tl.program_id(0) pid_h = tl.program_id(1) off_o = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) off_d = tl.arange(0, BLOCK_SIZE_D) o = tl.load( o_ptr + off_o[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), other=0, ).to(tl.float32) do = tl.load( do_ptr + off_o[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), other=0, ).to(tl.float32) delta = tl.sum(o * do, axis=1) tl.store(delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len) @triton.autotune( configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]], key=['BLOCK_SIZE_N', 'BLOCK_SIZE_K', 'BLOCK_SIZE_R'], ) @triton.jit def count_kernel( x_ptr, # [num_kv_heads, total_len, topk] y_ptr, # [num_kv_heads, total_blocks] cu_seqlens, # [batch_size + 1] cu_seqblocks, # [batch_size + 1] topk, stride_xh, stride_xn, stride_xk, stride_yh, stride_yn, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_R: 
tl.constexpr, ): pid_h = tl.program_id(0) pid_b = tl.program_id(1) # get start and len after rmpad seq_start = tl.load(cu_seqlens + pid_b) seq_len = tl.load(cu_seqlens + pid_b + 1) - seq_start blocks_start = tl.load(cu_seqblocks + pid_b) num_blocks = tl.load(cu_seqblocks + pid_b + 1) - blocks_start # load x off_k = tl.arange(0, BLOCK_SIZE_K) off_n = tl.arange(0, BLOCK_SIZE_N) x_ptr = x_ptr + pid_h * stride_xh + seq_start * stride_xn x_ptrs = x_ptr + off_n[:, None] * stride_xn + off_k[None, :] * stride_xk # init y y = tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32) # loop for i in range(0, seq_len, BLOCK_SIZE_N): x = tl.load( x_ptrs, mask=(off_n < seq_len - i)[:, None] & (off_k < topk)[None, :], other=-1, ) x = tl.ravel(x) y += tl.histogram(x, BLOCK_SIZE_R) x_ptrs += BLOCK_SIZE_N * stride_xn # store result off_r = tl.arange(0, BLOCK_SIZE_R) y_ptr = y_ptr + pid_h * stride_yh + blocks_start * stride_yn y_ptrs = y_ptr + off_r * stride_yn tl.store(y_ptrs, y.to(y_ptr.dtype.element_ty), mask=off_r < num_blocks) def count_query( topk_idx: torch.Tensor, cu_seqlens: torch.Tensor, cu_seqblocks: torch.Tensor, block_size: int, ): num_kv_heads, total_len, topk = topk_idx.shape seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqblocks = cu_seqblocks[1:] - cu_seqblocks[:-1] batch_size = seqlens.shape[0] BLOCK_SIZE_K = triton.next_power_of_2(topk) BLOCK_SIZE_N = triton.next_power_of_2(4096 // BLOCK_SIZE_K) BLOCK_SIZE_R = triton.next_power_of_2(seqblocks.max().item() + 2) active_query_count = torch.zeros(num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device) grid = (num_kv_heads, batch_size) count_kernel[grid]( topk_idx, active_query_count, cu_seqlens, cu_seqblocks, topk, topk_idx.stride(0), topk_idx.stride(1), topk_idx.stride(2), active_query_count.stride(0), active_query_count.stride(1), BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_R=BLOCK_SIZE_R, # num_warps=4, # num_stages=3, ) return active_query_count @triton.autotune( configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]], key=['topk', 'BLOCK_SIZE_N', 'BLOCK_SIZE_T'], ) @triton.jit def pad_topk_idx_kernel( t_ptr, p_ptr, cu_seqlens, topk, stride_th, stride_tn, stride_tk, stride_pb, stride_ph, stride_pn, stride_pk, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_T: tl.constexpr, ): pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid_n = tl.program_id(2) # get q start and len after rmpad q_start = tl.load(cu_seqlens + pid_b) q_len = tl.load(cu_seqlens + pid_b + 1) - q_start if BLOCK_SIZE_N * pid_n >= q_len: return # init prts t_ptrs = tl.make_block_ptr( base=t_ptr + pid_h * stride_th + q_start * stride_tn, shape=(q_len, topk), strides=(stride_tn, stride_tk), offsets=(pid_n * BLOCK_SIZE_N, 0), block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), order=(1, 0), ) p_ptrs = tl.make_block_ptr( base=p_ptr + pid_b * stride_pb + pid_h * stride_ph, shape=(q_len, topk), strides=(stride_pn, stride_pk), offsets=(pid_n * BLOCK_SIZE_N, 0), block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), order=(1, 0), ) # load and save idxs = tl.load(t_ptrs, boundary_check=(0, 1)) tl.store(p_ptrs, idxs, boundary_check=(0, 1)) @triton.autotune( configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]], key=['BLOCK_SIZE_N'], ) @triton.jit def save_topk_idx_kernel( p_ptr, t_ptr, cu_seqblocks, cu_topk_q_count, n_len, stride_pb, stride_ph, stride_pn, stride_th, stride_tn, stride_ch, stride_cn, BLOCK_SIZE_N: tl.constexpr, ): pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid_n = tl.program_id(2) # get q start and len after rmpad q_block_start = tl.load(cu_seqblocks + 
pid_b) q_block_end = tl.load(cu_seqblocks + pid_b + 1) c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn) c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn) c_len = c_end - c_start if c_len <= 0: return if pid_n * BLOCK_SIZE_N >= c_len: return # init ptrs p_ptrs = tl.make_block_ptr( base=p_ptr + pid_b * stride_pb + pid_h * stride_ph + (n_len - c_len) * stride_pn, shape=(c_len,), strides=(stride_pn,), offsets=(pid_n * BLOCK_SIZE_N,), block_shape=(BLOCK_SIZE_N,), order=(0,), ) t_ptrs = tl.make_block_ptr( base=t_ptr + pid_h * stride_th + c_start * stride_tn, shape=(c_len,), strides=(stride_tn,), offsets=(pid_n * BLOCK_SIZE_N,), block_shape=(BLOCK_SIZE_N,), order=(0,), ) # load and save idxs = tl.load(p_ptrs, boundary_check=(0,)) tl.store(t_ptrs, idxs, boundary_check=(0,)) def reorder_topk_idx( topk_idx: torch.Tensor, cu_topk_q_count: torch.Tensor, cu_seqlens: torch.Tensor, cu_seqblocks: torch.Tensor, block_size: int, ): num_kv_heads, total_len, topk = topk_idx.shape batch_size = cu_seqlens.shape[0] - 1 seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk] pad_topk_idx = torch.full( (batch_size, num_kv_heads, max_seqlen, topk), fill_value=-1, device=topk_idx.device, dtype=torch.int32, ) BLOCK_SIZE_T = triton.next_power_of_2(topk) BLOCK_SIZE_N = min(triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T)) grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N)) pad_topk_idx_kernel[grid]( topk_idx, pad_topk_idx, cu_seqlens, topk, topk_idx.stride(0), topk_idx.stride(1), topk_idx.stride(2), pad_topk_idx.stride(0), pad_topk_idx.stride(1), pad_topk_idx.stride(2), pad_topk_idx.stride(3), BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_T=BLOCK_SIZE_T, ) # argsort pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk pad_topk_q_idx = pad_topk_q_idx.to(torch.int32) # save the result with padding removed topk_q_idx = torch.full( (num_kv_heads, cu_topk_q_count[:, -1].max().item()), fill_value=-1, device=topk_idx.device, dtype=torch.int32, ) max_len = (cu_topk_q_count[:, cu_seqblocks][:, 1:] - cu_topk_q_count[:, cu_seqblocks][:, :-1]).max().item() BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192) grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N)) save_topk_idx_kernel[grid]( pad_topk_q_idx, topk_q_idx, cu_seqblocks, cu_topk_q_count, pad_topk_q_idx.shape[-1], pad_topk_q_idx.stride(0), pad_topk_q_idx.stride(1), pad_topk_q_idx.stride(2), topk_q_idx.stride(0), topk_q_idx.stride(1), cu_topk_q_count.stride(0), cu_topk_q_count.stride(1), BLOCK_SIZE_N=BLOCK_SIZE_N, ) return topk_q_idx @triton.autotune( configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]], key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D'], ) @triton.jit def backward_dkdv( q_ptr, # Q: n x qh x d k_ptr, # K: n x kh x d v_ptr, # V: n x kh x d tq_ptr, # topk_q_idx: kh x N lse_ptr, # LSE: qh x n d_ptr, # Delta: qh x n do_ptr, dk_ptr, # DK: sh x n x kh x d dv_ptr, # DV: sh x n x kh x d # seqlens cu_seqlens_q, # [batch_size + 1] cu_seqlens_k, # [batch_size + 1] cu_seqblocks, # [batch_size + 1] cu_topk_q_count, # [kh, total_blocks] # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, HEAD_DIM, TOPK, # sm_scale sm_scale, # stride stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, stride_tqh, stride_tqn, stride_ctqh, stride_ctqn, stride_lh, 
stride_ln, stride_dh, stride_dn, stride_don, stride_doh, stride_dod, stride_dks, stride_dkn, stride_dkh, stride_dkd, stride_dvs, stride_dvn, stride_dvh, stride_dvd, # META parameters BLOCK_SIZE_Q: tl.constexpr, # q block size BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_D: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid_kh = pid_h // NUM_SHARE_Q_HEADS pid_sh = pid_h % NUM_SHARE_Q_HEADS pid_k = tl.program_id(2) # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start if BLOCK_SIZE_K * pid_k >= k_len: return # get topk_q_idx b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence act_q_start = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn) act_q_end = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn) act_q_len = act_q_end - act_q_start tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn # init pointers k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(k_len, HEAD_DIM), strides=(stride_kn, stride_kd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) dk_ptrs = tl.make_block_ptr( base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, shape=(k_len, HEAD_DIM), strides=(stride_dkn, stride_dkd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) v_ptrs = tl.make_block_ptr( base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, shape=(k_len, HEAD_DIM), strides=(stride_vn, stride_vd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) dv_ptrs = tl.make_block_ptr( base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, shape=(k_len, HEAD_DIM), strides=(stride_dvn, stride_dvd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) # offsets off_q = tl.arange(0, BLOCK_SIZE_Q) off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K off_d = tl.arange(0, BLOCK_SIZE_D) # load k v and keep in SRAM k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") # init dk dv dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) # init ptrs q_ptrs = q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd do_ptrs = do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh # loop for q blocks for i in range(0, act_q_len, BLOCK_SIZE_Q): # load idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to(tl.int32) q = tl.load( q_ptrs + idx_q[:, None] * stride_qn, mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], other=0, ) do = tl.load( do_ptrs + idx_q[:, None] * stride_don, mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], other=0, ) lse = tl.load( lse_ptrs + idx_q[:, None] * stride_ln, mask=(off_q < act_q_len - i)[:, None], other=0, ) d = tl.load( d_ptrs + idx_q[:, None] * stride_dn, mask=(off_q < act_q_len - i)[:, None], other=0, ) # compute qk qk = 
tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf")) qk += tl.dot(q, k.T) * qk_scale # compute p, ds p = tl.exp2(qk - lse) dp = tl.dot(do, v.T) ds = sm_scale * p * (dp - d) # cast dtype p = p.to(do.dtype) ds = ds.to(q.dtype) # update dk and dv dk += tl.dot(ds.T, q) dv += tl.dot(p.T, do) # save dk dv tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) @triton.autotune( configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]], key=['HEAD_DIM', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D', 'BLOCK_SIZE_H', 'BLOCK_SIZE_T'], ) @triton.jit def backward_dq( q_ptr, # Q: n x qh x d k_ptr, # K: n x kh x d v_ptr, # V: n x kh x d t_ptr, # topk_idx: kh x n x k lse_ptr, # LSE: qh x n d_ptr, # Delta: qh x n do_ptr, dq_ptr, # seqlens cu_seqlens_q, cu_seqlens_k, # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, HEAD_DIM, TOPK, # q loop num num_q_loop, # sm_scale sm_scale, # stride stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, stride_th, stride_tn, stride_tk, stride_lh, stride_ln, stride_dh, stride_dn, stride_don, stride_doh, stride_dod, stride_dqn, stride_dqh, stride_dqd, # META parameters BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_D: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_T: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_b = tl.program_id(0) pid_kh = tl.program_id(1) pid_q = tl.program_id(2) pid_h = pid_kh * NUM_SHARE_Q_HEADS # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start if pid_q * num_q_loop >= q_len: return real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop) for j in range(real_q_loop): pid_q_j = pid_q * num_q_loop + j # init topk idx pointer off_t = tl.arange(0, BLOCK_SIZE_T) t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) real_topk = tl.sum( tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0), axis=0, ) # init pointers q_ptrs = tl.make_block_ptr( base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), strides=(stride_qh, stride_qd), offsets=(0, 0), block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), order=(1, 0), ) dq_ptrs = tl.make_block_ptr( base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh, shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), strides=(stride_dqh, stride_dqd), offsets=(0, 0), block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), order=(1, 0), ) k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(k_len, HEAD_DIM), strides=(stride_kn, stride_kd), offsets=(0, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) v_ptrs = tl.make_block_ptr( base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, shape=(HEAD_DIM, k_len), strides=(stride_vd, stride_vn), offsets=(0, 0), block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), order=(0, 1), ) do_ptrs = tl.make_block_ptr( base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh, shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), strides=(stride_doh, stride_dod), offsets=(0, 0), block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), order=(1, 0), ) d_ptrs = tl.make_block_ptr( base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh, 
shape=(NUM_SHARE_Q_HEADS, 1), strides=(stride_dh, stride_dn), offsets=(0, 0), block_shape=(BLOCK_SIZE_H, 1), order=(1, 0), ) lse_ptrs = tl.make_block_ptr( base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh, shape=(NUM_SHARE_Q_HEADS, 1), strides=(stride_lh, stride_ln), offsets=(0, 0), block_shape=(BLOCK_SIZE_H, 1), order=(1, 0), ) # offsets off_k = tl.arange(0, BLOCK_SIZE_K) # load q, do, lse, delta, and keep in SRAM q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") # init dq dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32) # sparse for i in range(real_topk): # get current block start index c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K t_ptr_j = t_ptr_j + stride_tk # load k = tl.load(tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero") v = tl.load(tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero") # compute qk qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf")) # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K] qk += tl.dot(q, tl.trans(k)) * qk_scale # compute p, ds p = tl.exp2(qk - lse) dp = tl.dot(do, v) ds = sm_scale * p * (dp - d) # cast dtype ds = ds.to(q.dtype) # update dq dq += tl.dot(ds, k) # save dq tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) def _topk_sparse_attention_fwd( q: torch.Tensor, # [total_len, num_q_heads, head_dim] k: torch.Tensor, # [total_len, num_k_heads, head_dim] v: torch.Tensor, # [total_len, num_k_heads, head_dim] topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] block_size: int, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, sm_scale: float, ): # dtype check assert k.dtype == q.dtype and v.dtype == q.dtype assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 assert block_size in {32, 64, 128, 256} # shape q_len, num_q_heads, head_dim = q.shape k_len, num_k_heads, head_dim = k.shape v_len, num_v_heads, head_dim = v.shape batch_size = cu_seqlens_q.shape[0] - 1 # assert q_len == k_len and k_len == v_len topk = topk_idx.shape[-1] assert topk_idx.shape[0] == num_k_heads assert topk_idx.shape[1] == q_len # gqa assert num_k_heads == num_v_heads assert num_q_heads % num_k_heads == 0 num_share_q_heads = num_q_heads // num_k_heads # output tensor o = torch.zeros_like(q) lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device) # launch kernel num_q_loop = num_k_loop = 1 BLOCK_SIZE_K = triton.next_power_of_2(block_size) BLOCK_SIZE_D = triton.next_power_of_2(head_dim) BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) BLOCK_SIZE_T = triton.next_power_of_2(topk) def grid(meta): grid = ( batch_size * triton.cdiv(num_k_heads, num_k_loop) * triton.cdiv(max_seqlen_q, num_q_loop), ) return grid num_warps, num_stages = get_num_warps_stages(head_dim, block_size, IS_HOPPER_GPU) forward_kernel_orig[grid]( q, k, v, topk_idx, o, lse, cu_seqlens_q, cu_seqlens_k, num_k_heads, num_share_q_heads, head_dim, topk, block_size, # num_q_loop, sm_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), topk_idx.stride(0), topk_idx.stride(1), topk_idx.stride(2), o.stride(0), o.stride(1), 
o.stride(2), lse.stride(0), lse.stride(1), num_q_loop=num_q_loop, num_k_loop=num_k_loop, MAX_SEQ_LEN=max_seqlen_q, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_D=BLOCK_SIZE_D, BLOCK_SIZE_H=BLOCK_SIZE_H, BLOCK_SIZE_T=BLOCK_SIZE_T, # num_warps=num_warps, # num_stages=num_stages, ) return o, lse def _topk_sparse_attention_bwd( o: torch.Tensor, do: torch.Tensor, lse: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, topk_idx: torch.Tensor, block_size: int, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, sm_scale: float, ): assert block_size in {32, 64, 128, 256} q_len, num_q_heads, head_dim = q.shape k_len, num_k_heads, head_dim = k.shape v_len, num_v_heads, head_dim = v.shape o_len, num_o_heads, head_dim = o.shape num_share_q_heads = num_q_heads // num_k_heads topk = topk_idx.shape[-1] # compute D delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) BLOCK_SIZE_O = 256 BLOCK_SIZE_D = triton.next_power_of_2(head_dim) num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads) backward_sum_o_do[grid]( o, do, delta, o_len, head_dim, o.stride(0), o.stride(1), o.stride(2), do.stride(0), do.stride(1), do.stride(2), delta.stride(0), delta.stride(1), BLOCK_SIZE_O=BLOCK_SIZE_O, BLOCK_SIZE_D=BLOCK_SIZE_D, # num_warps=num_warps, # num_stages=num_stages, ) # count active queries for each key block, shape: (num_k_heads, total_k_blocks) seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqblocks = torch.ceil(seqlens / block_size).to(torch.int32) cu_seqblocks = torch.cat( [ torch.zeros(1, dtype=torch.int32, device=topk_idx.device), torch.cumsum(seqblocks, dim=0), ] ).to(torch.int32) topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size) cu_topk_q_count = torch.cat( [ torch.zeros(topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device), torch.cumsum(topk_q_count, dim=-1), ], dim=-1, ).to(torch.int32) # active query idx for each key block # how to get active query idx for sequence b, head h, kv block i? 
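# A sketch of the answer, inferred from how backward_dkdv consumes these tensors (the original source leaves the question open):
# the active query indices for sequence b, head h, kv block i are the contiguous slice
#   topk_q_idx[h, cu_topk_q_count[h, cu_seqblocks[b] + i] : cu_topk_q_count[h, cu_seqblocks[b] + i + 1]]
# where cu_seqblocks maps a (sequence, block) pair to a global block id and cu_topk_q_count is the per-head prefix sum of how many queries selected each block.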
topk_q_idx = reorder_topk_idx(topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size) # compute dk dv dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) batch_size = cu_seqlens_q.shape[0] - 1 BLOCK_SIZE_K = triton.next_power_of_2(block_size) BLOCK_SIZE_Q = 64 BLOCK_SIZE_D = triton.next_power_of_2(head_dim) num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K)) backward_dkdv[grid]( q, k, v, topk_q_idx, lse, delta, do, dk, dv, cu_seqlens_q, cu_seqlens_k, cu_seqblocks, cu_topk_q_count, num_k_heads, num_share_q_heads, head_dim, topk, sm_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), topk_q_idx.stride(0), topk_q_idx.stride(1), cu_topk_q_count.stride(0), cu_topk_q_count.stride(1), lse.stride(0), lse.stride(1), delta.stride(0), delta.stride(1), do.stride(0), do.stride(1), do.stride(2), dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3), dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), BLOCK_SIZE_Q=BLOCK_SIZE_Q, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_D=BLOCK_SIZE_D, # num_warps=num_warps, # num_stages=num_stages, ) dk = dk.sum(0) dv = dv.sum(0) # compute dq dq = torch.zeros_like(q) num_q_loop = max_seqlen_q // 32768 + 1 # process multiple queries in one kernel instance if the sequence length is very long grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) BLOCK_SIZE_K = block_size BLOCK_SIZE_D = triton.next_power_of_2(head_dim) BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) BLOCK_SIZE_T = triton.next_power_of_2(topk) num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) backward_dq[grid]( q, k, v, topk_idx, lse, delta, do, dq, cu_seqlens_q, cu_seqlens_k, num_k_heads, num_share_q_heads, head_dim, topk, num_q_loop, sm_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), topk_idx.stride(0), topk_idx.stride(1), topk_idx.stride(2), lse.stride(0), lse.stride(1), delta.stride(0), delta.stride(1), do.stride(0), do.stride(1), do.stride(2), dq.stride(0), dq.stride(1), dq.stride(2), BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_D=BLOCK_SIZE_D, BLOCK_SIZE_H=BLOCK_SIZE_H, BLOCK_SIZE_T=BLOCK_SIZE_T, # num_warps=num_warps, # num_stages=num_stages, ) return dq, dk, dv class TopkSparseAttention(torch.autograd.Function): @staticmethod def forward( ctx, q: torch.Tensor, # [total_len, num_q_heads, head_dim] k: torch.Tensor, # [total_len, num_k_heads, head_dim] v: torch.Tensor, # [total_len, num_k_heads, head_dim] topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] block_size: int, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, sm_scale=None, ): # dtype check assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 assert q.dtype == k.dtype and k.dtype == v.dtype assert topk_idx.dtype == torch.int32 assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 # softmax scale if sm_scale is None: sm_scale = 1 / math.sqrt(q.shape[-1]) o, lse = _topk_sparse_attention_fwd( q, k, v, topk_idx, block_size, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale, ) ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx) ctx.sm_scale = sm_scale 
ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.block_size = block_size return o @staticmethod def backward(ctx, do: torch.Tensor, *args) -> Any: q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors max_seqlen_q = ctx.max_seqlen_q max_seqlen_k = ctx.max_seqlen_k sm_scale = ctx.sm_scale block_size = ctx.block_size assert block_size in {32, 64, 128, 256} dq, dk, dv = _topk_sparse_attention_bwd( o, do, lse, q, k, v, topk_idx, block_size, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale, ) return dq, dk, dv, None, None, None, None, None, None, None def topk_sparse_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, topk_idx: torch.Tensor, block_size: int, cu_seqlens: torch.Tensor, softmax_scale: Optional[float] = None, ) -> torch.Tensor: """Top-k sparse attention, varlen version, implemented in Triton. Args: q (torch.Tensor): shape [total_len, num_q_heads, head_dim] k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. block_size (int): key/value block size. cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_varlen_func. softmax_scale (Optional[float], optional): Defaults to None, meaning 1/sqrt(head_dim). Returns: torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] """ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() return TopkSparseAttention.apply( q, k, v, topk_idx, block_size, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, softmax_scale, )
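# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API): builds random varlen
# inputs and a toy causal top-k pattern, then runs forward and backward once.
# All shapes, lengths, and the way topk_idx is constructed here are
# illustrative assumptions; running it requires a CUDA GPU with Triton.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda"
    num_q_heads, num_kv_heads, head_dim = 8, 2, 128
    block_size, topk = 64, 16
    # two sequences packed back to back, lengths 1000 and 1536
    seqlens = torch.tensor([1000, 1536], dtype=torch.int32, device=device)
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32, device=device), torch.cumsum(seqlens, 0)]
    ).to(torch.int32)
    total_len = int(cu_seqlens[-1])
    q = torch.randn(total_len, num_q_heads, head_dim, dtype=torch.bfloat16, device=device, requires_grad=True)
    k = torch.randn(total_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device, requires_grad=True)
    v = torch.randn(total_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device, requires_grad=True)
    # toy top-k pattern: every query selects its first `topk` causal key blocks;
    # block indices are relative to the start of each sequence, -1 marks padding
    topk_idx = torch.full((num_kv_heads, total_len, topk), -1, dtype=torch.int32, device=device)
    for b in range(seqlens.shape[0]):
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        pos_block = (torch.arange(end - start, device=device) // block_size)[:, None]  # block id of each query
        cand = torch.arange(topk, device=device)[None, :]  # candidate block ids 0..topk-1
        idx = torch.where(cand <= pos_block, cand, torch.full_like(cand, -1))
        topk_idx[:, start:end] = idx.to(torch.int32)[None]
    out = topk_sparse_attention(q, k, v, topk_idx, block_size, cu_seqlens)
    out.sum().backward()
    print(out.shape, q.grad.shape)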