| import os |
| import torch |
| import torch.nn.functional as F |
| import triton |
| import triton.language as tl |
| import math |
|
|
| from .common import _attn_fwd_gating, _attn_bwd_preprocess, configs_gating_preset |
| from .flash_attn_bsa_varlen_mask import ( |
| _attn_fwd_bsa_varlen, _attn_fwd_bsa_varlen_align, _attn_bwd_dkdv_bsa_varlen_wrapper, _attn_bwd_dq_bsa_varlen_wrapper, _attn_bwd_dq_bsa_varlen_align_wrapper, |
| configs_fwd_bsa_varlen_preset, configs_fwd_bsa_varlen_align_preset, configs_bwd_dkdv_bsa_varlen_preset, configs_bwd_dq_bsa_varlen_preset, configs_bwd_dq_bsa_varlen_align_preset |
| ) |
|
|
| from .communicate import p2p_communicate |
|
|
| torch._dynamo.config.cache_size_limit = 32 |
|
|
def is_cuda():
    """Return True when Triton's active driver targets the CUDA backend."""
    active_target = triton.runtime.driver.active.get_current_target()
    return active_target.backend == "cuda"
|
|
def supports_tma():
    """Return True on a CUDA backend whose current device has compute capability major >= 9."""
    if not is_cuda():
        return False
    major_capability = torch.cuda.get_device_capability()[0]
    return major_capability >= 9
|
|
# Feature flag: this Triton build exposes the grid-constant TMA descriptor type.
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

# Report at import time which TMA descriptor path will be used.
print(
    "TMA benchmarks will be running with experimental grid constant TMA descriptor."
    if HAS_TMA_DESC
    else "TMA benchmarks will be running without grid constant TMA descriptor."
)
|
|
|
|
| |
class TmaAutoTuneHelper:
    """Create and fill TMA descriptors for Triton kernels.

    With ``HAS_TMA_DESC`` the descriptors live in CPU int8 buffers and are
    handed to kernels wrapped in ``KernelParamWrapper``. Otherwise each
    descriptor is filled into a pinned CPU staging buffer and copied
    (non-blocking) into a CUDA int8 buffer, which is passed to the kernel as is.
    """

    class KernelParamWrapper:
        """Holds a CPU descriptor buffer and exposes its address via ``tma_desc_cpu_ptr``."""

        def __init__(self, desc):
            self.desc = desc

        def tma_desc_cpu_ptr(self):
            return self.desc.data_ptr()

    # Number of int8 elements (bytes) allocated per descriptor buffer.
    TMA_SIZE = 128

    def __init__(self):
        driver_utils = triton.runtime.driver.active.utils
        self.fill_1d_tma_descriptor_inner = driver_utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver_utils.fill_2d_tma_descriptor
        if HAS_TMA_DESC:
            self.descriptors = {}
        else:
            self.cuda_descriptors = {}

    def init_tma_descriptor(self, name):
        """Allocate an uninitialised descriptor buffer registered under ``name``."""
        if HAS_TMA_DESC:
            self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8)
            return
        self.cuda_descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8)

    def _fill_descriptor(self, name, fill_fn, *fill_args):
        """Call ``fill_fn(*fill_args, dst_ptr)`` so it writes the descriptor ``name``."""
        if HAS_TMA_DESC:
            host_desc = self.descriptors[name]
            # The fill routine writes directly into this buffer, which must be 64-byte aligned.
            assert host_desc.data_ptr() % 64 == 0
            fill_fn(*fill_args, host_desc.data_ptr())
            return
        device_desc = self.cuda_descriptors[name]
        # Fill a pinned host copy, then upload it asynchronously to the device buffer.
        staging = torch.empty_like(device_desc, device="cpu", pin_memory=True)
        fill_fn(*fill_args, staging.data_ptr())
        device_desc.copy_(staging, non_blocking=True)

    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
        self._fill_descriptor(name, self.fill_1d_tma_descriptor_inner, ptr, dim, block_dim, element_size)

    def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
        self._fill_descriptor(
            name, self.fill_2d_tma_descriptor_inner, ptr, dim1, dim0, block_dim1, block_dim0, element_size
        )

    def get_tma_descriptor_kernel_param(self, name):
        """Return the object to pass to a kernel for descriptor ``name``."""
        if not HAS_TMA_DESC:
            assert self.cuda_descriptors[name] is not None
            return self.cuda_descriptors[name]
        assert self.descriptors[name] is not None
        return self.KernelParamWrapper(self.descriptors[name])
|
|
|
|
@triton.jit
def create_mask_from_indices_kernel(
    block_indices,
    block_mask,
    stride_bz, stride_bh, stride_bm, stride_bs,
    stride_mz, stride_mh, stride_mm, stride_mn,
    H,
):
    """Scatter one entry of ``block_indices`` into ``block_mask``.

    Each program instance reads the column index stored at (z, h, m, s) of
    ``block_indices`` and writes 1 (cast to the mask's element type) into
    ``block_mask`` at (z, h, m, column).
    """
    # Axis 0 enumerates flattened (batch, head) pairs as z * H + h; axes 1 and 2
    # select the row m and the slot s within that row.
    i_zh, i_m, i_s = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_z, i_h = i_zh // H, i_zh % H

    # Offsets are promoted to int64 before multiplying by strides (avoids 32-bit overflow on large tensors).
    off_b = i_z.to(tl.int64) * stride_bz + i_h.to(tl.int64) * stride_bh + i_m.to(tl.int64) * stride_bm + i_s.to(tl.int64) * stride_bs

    b_i = tl.load(block_indices + off_b)

    # NOTE(review): b_i is used unchecked as a column index; every value must be a valid column of block_mask.
    off_m = i_z.to(tl.int64) * stride_mz + i_h.to(tl.int64) * stride_mh + i_m.to(tl.int64) * stride_mm + b_i.to(tl.int64) * stride_mn

    b_m = 1
    tl.store(block_mask + off_m, b_m.to(block_mask.dtype.element_ty))
|
|
def create_mask_from_indices_triton(
    block_indices,
    N_cols
):
    """Build a dense boolean mask of shape (B, H, N_rows, N_cols) from selected column indices.

    ``block_indices`` is 4-D (B, H, N_rows, S); for every entry the Triton kernel
    sets the corresponding column of the mask row. The kernel does not bounds-check,
    so every index must lie in [0, N_cols).
    """
    batch, heads, n_rows, n_selected = block_indices.shape
    block_mask = torch.zeros(
        (batch, heads, n_rows, N_cols), dtype=torch.bool, device=block_indices.device
    )
    # One program per (batch*head, row, selected slot).
    launch_grid = (batch * heads, n_rows, n_selected)
    create_mask_from_indices_kernel[launch_grid](
        block_indices,
        block_mask,
        *block_indices.stride(),
        *block_mask.stride(),
        heads,
    )
    return block_mask
|
|
|
|
def create_mask_from_indices_varlen(block_indices, N_cols_mask):
    """Build a dense boolean block mask from padded (variable-length) index lists.

    Args:
        block_indices: integer tensor of shape (B, H, M, S); each row lists the
            selected column ids. Entries outside [0, N_cols_mask) are padding
            and are ignored (previously, negative entries wrapped around and
            marked the last columns).
        N_cols_mask: number of columns of the returned mask.

    Returns:
        A contiguous bool tensor of shape (B, H, M, N_cols_mask) that is True at
        every (b, h, m, block_indices[b, h, m, s]) with a valid index.
    """
    B, H, M, _ = block_indices.shape
    device = block_indices.device

    idx = block_indices.long()
    valid = (idx >= 0) & (idx < N_cols_mask)
    # Send every padding/out-of-range entry to an extra "overflow" column.
    # This avoids boolean-mask indexing, which needs a device->host sync.
    idx = torch.where(valid, idx, N_cols_mask)

    mask = torch.zeros((B, H, M, N_cols_mask + 1), dtype=torch.bool, device=device)
    mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))

    # Drop the overflow column.
    return mask[..., :N_cols_mask].contiguous()
|
|
|
|
def create_indices_k_from_indices_q_varlen(
    block_indices,
    N_cols_mask
):
    """Invert a q-block -> k-block selection into k-block -> q-block lists.

    Returns ``(block_indices_k, block_indices_k_lens)``:
    - ``block_indices_k`` is (B, H, N_cols_mask, M). Each row holds the ascending
      ids of the q blocks that selected that k block, padded with M, where M is the
      number of q blocks.
    - ``block_indices_k_lens`` holds the count of non-padding entries per row.
    """
    mask_qk = create_mask_from_indices_varlen(block_indices, N_cols_mask)
    num_q_blocks = mask_qk.shape[2]
    mask_kq = mask_qk.transpose(2, 3)

    q_ids = torch.arange(num_q_blocks, device=block_indices.device).expand_as(mask_kq)
    # Unselected slots get the sentinel num_q_blocks, so sorting pushes them to the end.
    padded = torch.where(mask_kq, q_ids, num_q_blocks)
    block_indices_k = padded.sort(dim=-1).values

    block_indices_k_lens = (block_indices_k < num_q_blocks).sum(dim=-1)
    return block_indices_k, block_indices_k_lens
|
|
|
|
|
|
def mean_pooling_compression(
    x: torch.Tensor,
    block_size: int
) -> torch.Tensor:
    """Average consecutive groups of ``block_size`` tokens along dim 2.

    Args:
        x: 4-D tensor (B, H, S, D). The sequence is dim 2, the dim that
            ``F.pad(x, (0, 0, 0, p))`` pads.
        block_size: number of tokens per pooled block.

    Returns:
        Tensor (B, H, ceil(S / block_size), D) of per-block means. If S is not a
        multiple of ``block_size``, the last block averages only its real tokens.
        Previously the zero padding diluted that mean.
    """
    B, H, S = x.shape[:3]
    num_block = -(-S // block_size)  # integer ceil division
    pad = num_block * block_size - S
    if pad:
        x = F.pad(x, (0, 0, 0, pad))
    # reshape (not view) so non-contiguous inputs are accepted as well.
    x_cmp = x.reshape(B, H, num_block, block_size, -1).mean(dim=3)
    if pad:
        # The tail block was averaged over block_size slots, pad of them zeros;
        # rescale it so it averages only the (block_size - pad) real tokens.
        scale = torch.ones(num_block, 1, dtype=x_cmp.dtype, device=x_cmp.device)
        scale[-1] = block_size / (block_size - pad)
        x_cmp = x_cmp * scale
    return x_cmp
|
|
|
|
def cal_score(q, k):
    """Return raw (unscaled) attention logits ``q @ k^T`` over the last two dims."""
    return q @ k.transpose(-2, -1)
|
|
def cal_score_triton(q, k):
    """Compute ``q @ k^T`` with the ``_attn_fwd_gating`` Triton kernel.

    ``q`` is (B, H, s_q, D) and ``k`` is (B, H, s_k, D). The result is a new
    (B, H, s_q, s_k) tensor with ``q``'s dtype and device. Set
    TRITON_AUTOTUNE_ENBALE=1 (spelling as read here) to let the kernel autotune
    instead of using the preset config.
    """
    batch, heads, len_q, head_dim = q.shape
    len_k = k.shape[2]

    score = q.new_empty((batch, heads, len_q, len_k))

    use_autotune = os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1'
    kernel_config = {} if use_autotune else configs_gating_preset['default']

    def grid(meta):
        return (triton.cdiv(len_q, meta["BLOCK_M"]), batch * heads, 1)

    _attn_fwd_gating[grid](
        q, k, score,
        *q.stride(),
        *k.stride(),
        *score.stride(),
        heads, len_q, len_k,
        HEAD_DIM=head_dim,
        **kernel_config
    )
    return score
|
|
|
|
def get_select_indices_topk(q, k, sparsity):
    """Select the top-scoring k blocks per q block; returns ``(block_indices, block_indices_lens)``."""
    return get_select_indices_topk_from_score(cal_score(q, k), sparsity)
|
|
|
|
def get_select_indices_topk_from_score(score, sparsity):
    """Keep the ``(1 - sparsity)`` fraction of highest-scoring columns per row.

    Args:
        score: tensor (B, H, Sq, Sk) of block scores.
        sparsity: fraction of the Sk blocks to drop.

    Returns:
        ``(block_indices, block_indices_lens)``. ``block_indices`` is (B, H, Sq, n)
        with the top-n column ids in topk order. ``block_indices_lens`` is an int32
        (B, H, Sq) tensor filled with n.

    n is ``floor((1 - sparsity) * Sk)``, computed with a small tolerance so float
    rounding does not drop a block (e.g. ``(1 - 0.9) * 20`` evaluates to
    1.9999999999999996). n is clamped to [1, Sk] so every row keeps at least one block.
    """
    Sk = score.shape[-1]
    num_selected = min(Sk, max(1, math.floor((1 - sparsity) * Sk + 1e-6)))
    block_indices = torch.topk(score, num_selected).indices

    block_indices_lens = torch.full(
        block_indices.shape[:3],
        num_selected,
        dtype=torch.int32,
        device=block_indices.device
    )
    return block_indices, block_indices_lens
|
|
|
|
def get_select_indices_cdf(q, k, cdf_threshold):
    """Select k blocks per q block by cumulative softmax mass; scale is 1/sqrt(head_dim)."""
    sm_scale = 1 / q.shape[-1]**0.5
    return get_select_indices_cdf_from_score(cal_score(q, k), cdf_threshold, sm_scale)
|
|
|
|
def get_select_indices_cdf_from_score(score, cdf_threshold, sm_scale):
    """Select, per row, the highest-probability columns up to a cumulative-mass threshold.

    Args:
        score: tensor (B, H, Sq, Sk) of logits; softmax is taken over Sk after
            multiplying by ``sm_scale``.
        cdf_threshold: scalar, or a length-H sequence/tensor of per-head thresholds.
        sm_scale: multiplier applied to ``score`` before the softmax.

    Returns:
        ``(indices, num_selected)``. ``indices`` is (B, H, Sq, Sk): column ids sorted
        by descending probability. ``num_selected`` is (B, H, Sq): the number of
        leading entries whose cumulative probability is <= threshold, but never
        less than one (when Sk > 0).
    """
    weights = torch.softmax(score * sm_scale, dim=-1)

    B, H, Sq, Sk = weights.shape
    thresholds = torch.as_tensor(cdf_threshold, dtype=torch.get_default_dtype(), device=weights.device)
    # Broadcast a scalar or per-head threshold to one value per (b, h, q-row).
    thresholds = torch.broadcast_to(thresholds, (H,)).reshape(1, H, 1, 1).expand(B, -1, Sq, -1)

    weights_sorted = torch.sort(weights, dim=-1, descending=True)
    cdf = torch.cumsum(weights_sorted.values, dim=-1)
    num_selected = torch.searchsorted(cdf, thresholds, right=True)
    # right=True counts the blocks whose cumulative mass stays <= threshold. That is 0
    # when the single heaviest block already exceeds it, which would leave the row
    # with no blocks at all; keep at least one.
    num_selected = num_selected.clamp_min(min(1, Sk))

    return weights_sorted.indices, num_selected.squeeze(-1)
|
|
|
|
def get_select_indices_cdf_topk(q, k, sparsity, cdf_threshold):
    """Select k blocks per q block by cumulative softmax mass, never fewer than the top-k budget."""
    sm_scale = 1 / q.shape[-1]**0.5
    return get_select_indices_cdf_topk_from_score(cal_score(q, k), sparsity, cdf_threshold, sm_scale)
|
|
|
|
def get_select_indices_cdf_topk_from_score(score, sparsity, cdf_threshold, sm_scale):
    """CDF-threshold block selection with a top-k lower bound on the count.

    Args:
        score: tensor (B, H, Sq, Sk) of logits.
        sparsity: fraction of blocks to drop. Every row keeps at least
            ``floor((1 - sparsity) * Sk)`` blocks, computed with a float-rounding
            tolerance and clamped to [1, Sk].
        cdf_threshold: scalar, or a length-H sequence/tensor of per-head thresholds.
        sm_scale: multiplier applied to ``score`` before the softmax.

    Returns:
        ``(indices, num_selected)``. ``indices`` is (B, H, Sq, Sk), sorted by
        descending probability. ``num_selected`` is (B, H, Sq): the larger of the
        CDF count and the top-k count.
    """
    weights = torch.softmax(score * sm_scale, dim=-1)

    B, H, Sq, Sk = weights.shape
    thresholds = torch.as_tensor(cdf_threshold, dtype=torch.get_default_dtype(), device=weights.device)
    thresholds = torch.broadcast_to(thresholds, (H,)).reshape(1, H, 1, 1).expand(B, -1, Sq, -1)

    weights_sorted = torch.sort(weights, dim=-1, descending=True)
    cdf = torch.cumsum(weights_sorted.values, dim=-1)
    num_selected = torch.searchsorted(cdf, thresholds, right=True)

    # The tolerance avoids truncating e.g. (1 - 0.9) * 20 == 1.9999999999999996 down to 1.
    num_selected_topk = min(Sk, max(1, math.floor((1 - sparsity) * Sk + 1e-6)))
    num_selected = num_selected.clamp_min(num_selected_topk)

    return weights_sorted.indices, num_selected.squeeze(-1)
|
|
def get_select_indices(q, k, sparsity, cdf_threshold):
    """Dispatch block selection on compressed q/k by which criteria are set.

    - only ``sparsity``: top-k selection
    - only ``cdf_threshold``: cumulative-softmax-mass selection
    - both: CDF selection with a top-k lower bound

    Returns ``(block_indices, block_indices_lens)``.

    Raises:
        ValueError: if neither ``sparsity`` nor ``cdf_threshold`` is given.
    """
    if sparsity is None and cdf_threshold is None:
        raise ValueError("get_select_indices: at least one of `sparsity` or `cdf_threshold` must be set")
    if cdf_threshold is None:
        return get_select_indices_topk(q, k, sparsity)
    if sparsity is None:
        return get_select_indices_cdf(q, k, cdf_threshold)
    return get_select_indices_cdf_topk(q, k, sparsity, cdf_threshold)
|
|
def get_select_indices_from_score(score, sparsity, cdf_threshold, sm_scale=1.0):
    """Dispatch block selection on a precomputed score tensor.

    Args:
        score: tensor (B, H, Sq, Sk) of block scores.
        sparsity: fraction of blocks to drop (top-k criterion), or None.
        cdf_threshold: cumulative softmax-mass threshold, or None.
        sm_scale: multiplier applied to ``score`` before the softmax in the CDF
            modes. The default of 1.0 treats ``score`` as already-scaled logits.
            Before this parameter existed, the CDF modes failed with a TypeError
            because the required scale was never forwarded.

    Returns ``(block_indices, block_indices_lens)``.

    Raises:
        ValueError: if neither ``sparsity`` nor ``cdf_threshold`` is given.
    """
    if sparsity is None and cdf_threshold is None:
        raise ValueError("get_select_indices_from_score: at least one of `sparsity` or `cdf_threshold` must be set")
    if cdf_threshold is None:
        return get_select_indices_topk_from_score(score, sparsity)
    if sparsity is None:
        return get_select_indices_cdf_from_score(score, cdf_threshold, sm_scale)
    return get_select_indices_cdf_topk_from_score(score, sparsity, cdf_threshold, sm_scale)
|
|
def attn_fwd_bsa_varlen_triton(
    q,
    k,
    v,
    sm_scale,
    block_indices,
    block_indices_lens,
    chunk_size_q,
    chunk_size_k,
    sparsity
):
    """Launch the block-sparse varlen attention forward kernel.

    Returns ``(o, lse)``:
    - ``o`` is a new tensor like ``q``.
    - ``lse`` is the kernel's float32 per-row statistic (shape ``q.shape[:3]``)
      multiplied by ln(2), presumably converting a base-2 log-sum-exp to natural log.

    ``chunk_size_k > 128`` uses the generic kernel; otherwise the "align" variant is used.
    """
    n_batch, n_heads, seq_len, head_dim = q.shape

    if chunk_size_k > 128:
        fwd_kernel, presets = _attn_fwd_bsa_varlen, configs_fwd_bsa_varlen_preset
    else:
        fwd_kernel, presets = _attn_fwd_bsa_varlen_align, configs_fwd_bsa_varlen_align_preset
    preset_name = 'BLOCK_N_LG=64' if chunk_size_k == 64 else 'default'
    # An empty config leaves tile sizes to the kernel's autotuner (env var spelled as read here).
    autotune = os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1'
    kernel_config = {} if autotune else presets[preset_name]

    out = torch.empty_like(q)
    row_stat = torch.empty((n_batch, n_heads, seq_len), device=q.device, dtype=torch.float32)

    # The strides read below must describe the tensors actually passed to the kernel.
    block_indices = block_indices.contiguous()
    block_indices_lens = block_indices_lens.contiguous()

    def grid(meta):
        return (triton.cdiv(seq_len, meta["BLOCK_M"]), n_batch * n_heads, 1)

    fwd_kernel[grid](
        q, k, v, sm_scale, row_stat, out,
        block_indices,
        block_indices_lens,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *out.stride(),
        *block_indices.stride(),
        *block_indices_lens.stride(),
        n_heads, seq_len,
        head_dim,
        BLOCK_M=chunk_size_q,
        BLOCK_N_LG=chunk_size_k,
        SPARSITY=sparsity,
        **kernel_config
    )

    LN2 = 0.6931471824645996
    return out, row_stat * LN2
|
|
def attn_bwd_bsa_varlen_triton(
    do,
    q,
    k,
    v,
    o,
    dq,
    dk,
    dv,
    sm_scale,
    M,
    block_indices,
    block_indices_lens,
    chunk_size_q,
    chunk_size_k,
    sparsity
):
    """Backward pass of block-sparse varlen attention.

    Writes the gradients into the preallocated ``dq``, ``dk`` and ``dv`` in place
    and returns None.

    Args:
        M: the ``lse`` returned by ``attn_fwd_bsa_varlen_triton``. It is
            multiplied by 1/ln(2) here to undo the forward's ln(2) scaling.
        block_indices / block_indices_lens: per-q-block selected k blocks. They
            are inverted into per-k-block q lists for the dK/dV kernel.
    """
    RCP_LN2 = 1.4426950408889634  # 1 / ln(2)
    M = M * RCP_LN2

    # _attn_bwd_preprocess receives no strides, so o and do must be contiguous.
    do = do.contiguous()
    o = o.contiguous()

    BATCH, N_HEAD, N_CTX, HEAD_DIM = q.shape
    N_CTX_KV = k.shape[-2]

    # K pre-multiplied by sm_scale / ln(2). This is a NEW tensor whose layout can
    # differ from k's (e.g. non-dense k), so the kernels get arg_k's own strides.
    arg_k = k * (sm_scale * RCP_LN2)

    PRE_BLOCK = min(chunk_size_q, chunk_size_k, 128)
    assert N_CTX % PRE_BLOCK == 0
    delta = torch.empty_like(M)
    _attn_bwd_preprocess[(N_CTX // PRE_BLOCK, BATCH * N_HEAD)](
        o, do,
        delta,
        N_CTX,
        BLOCK_M=PRE_BLOCK,
        HEAD_DIM=HEAD_DIM
    )

    # k-block -> q-block lists for the dK/dV kernel.
    block_indices_k, block_indices_k_lens = create_indices_k_from_indices_q_varlen(
        block_indices=block_indices,
        N_cols_mask=N_CTX_KV // chunk_size_k
    )

    block_indices = block_indices.contiguous()
    block_indices_lens = block_indices_lens.contiguous()
    block_indices_k = block_indices_k.contiguous()
    block_indices_k_lens = block_indices_k_lens.contiguous()

    autotune = os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1'
    config_key = 'BLOCK_N_DQ_LG=64' if chunk_size_k == 64 else 'default'

    # dK / dV
    kernel_config = {} if autotune else configs_bwd_dkdv_bsa_varlen_preset[config_key]

    def grid_dkdv(args):
        return (triton.cdiv(arg_k.shape[2], args["BLOCK_N"]), 1, arg_k.shape[0] * arg_k.shape[1])

    _attn_bwd_dkdv_bsa_varlen_wrapper[grid_dkdv](
        q, arg_k, v, sm_scale,
        do,
        dk, dv,
        M,
        delta,
        block_indices_k,
        block_indices_k_lens,
        *q.stride(),
        *arg_k.stride(),
        *v.stride(),
        *dk.stride(),
        *dv.stride(),
        *do.stride(),
        *M.stride(),
        *delta.stride(),
        *block_indices_k.stride(),
        *block_indices_k_lens.stride(),
        N_HEAD, N_CTX,
        BLOCK_M=chunk_size_q,
        BLOCK_N_DQ_LG=chunk_size_k,
        HEAD_DIM=HEAD_DIM,
        SPARSITY=sparsity,
        **kernel_config
    )

    # dQ
    if chunk_size_k > 128:
        bwd_dq_func = _attn_bwd_dq_bsa_varlen_wrapper
        kernel_config = {} if autotune else configs_bwd_dq_bsa_varlen_preset[config_key]
    else:
        bwd_dq_func = _attn_bwd_dq_bsa_varlen_align_wrapper
        kernel_config = {} if autotune else configs_bwd_dq_bsa_varlen_align_preset[config_key]

    def grid_dq(args):
        return (triton.cdiv(q.shape[2], args["BLOCK_M"]), 1, q.shape[0] * q.shape[1])

    bwd_dq_func[grid_dq](
        q, arg_k, v,
        do,
        dq,
        M,
        delta,
        block_indices,
        block_indices_lens,
        *q.stride(),
        *arg_k.stride(),
        *v.stride(),
        *dq.stride(),
        *do.stride(),
        *M.stride(),
        *delta.stride(),
        *block_indices.stride(),
        *block_indices_lens.stride(),
        N_HEAD, N_CTX,
        BLOCK_M=chunk_size_q,
        BLOCK_N_DQ_LG=chunk_size_k,
        HEAD_DIM=HEAD_DIM,
        SPARSITY=sparsity,
        **kernel_config
    )
|
|
|
|
def make_block_indices_varlen_cp_list(block_indices, cp_size, num_blocks_k_full):
    """
    Split global k-block indices into per-context-parallel-rank local indices.

    Args:
        block_indices: [B, H, num_blocks_q_per_cp_rank, S] global k-block ids,
            presumably in [0, num_blocks_k_full).
        cp_size: number of context-parallel ranks; rank r owns the k blocks
            [r * per_rank, (r + 1) * per_rank), with per_rank = num_blocks_k_full // cp_size.
        num_blocks_k_full: total number of k blocks across all ranks.

    Return:
        A list with one ``[block_indices, block_indices_lens]`` pair per rank:
        - block_indices are rank-local (start from zero) and sorted ascending; the
          valid entries come first and are followed by out-of-range filler values
        - block_indices_lens counts the valid entries along the last dimension
        The input tensor is not modified.
    """
    per_rank = num_blocks_k_full // cp_size
    shards = []
    for rank in range(cp_size):
        local = block_indices - rank * per_rank
        # Blocks owned by earlier ranks become the sentinel per_rank; later ranks'
        # blocks are already >= per_rank, so both sort behind the valid entries.
        local = local.masked_fill(local < 0, per_rank)
        local = local.sort(dim=-1).values
        shards.append([local, (local < per_rank).sum(dim=-1)])
    return shards
|
|
|
|
def flash_attn_fwd_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge softmax stats of each step in Attention with context parallelism.

    Overwrites ``softmax_lse`` in place with log(exp(a) + exp(b)), computed as
    max + log1p(exp(min - max)). A NaN difference (both inputs -inf) is treated
    as 0. Returns None.
    """
    hi = torch.maximum(softmax_lse, softmax_lse_per_step)
    lo = torch.minimum(softmax_lse, softmax_lse_per_step)
    gap = torch.nan_to_num(lo - hi, nan=0.)
    softmax_lse.copy_(hi + torch.exp(gap).log1p())
|
|
|
|
def flash_attn_fwd_out_correction_init(
    out_init_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_init_step: torch.Tensor,
):
    """Merge partial outputs of the first step in Attention with context parallelism.

    Returns ``out_init_step * exp(softmax_lse_init_step - softmax_lse)`` as a new
    tensor. The weight is broadcast over the last dim of ``out_init_step``, and the
    result is cast back to ``out_init_step``'s dtype.
    """
    weight = torch.exp(softmax_lse_init_step - softmax_lse)[..., None]
    return (out_init_step * weight).to(out_init_step.dtype)
|
|
|
|
|
|
def flash_attn_fwd_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge partial outputs of each step in Attention with context parallelism.

    Accumulates ``out_per_step * exp(softmax_lse_per_step - softmax_lse)`` into
    ``out`` in place. The weight is broadcast over the last dim. Returns None.
    """
    weight = torch.exp(softmax_lse_per_step - softmax_lse).unsqueeze(-1)
    out.add_(out_per_step * weight)
|
|
|
|
def topk_sort(score, num_chunks_selected):
    """Return the column ids of the ``num_chunks_selected`` largest scores per row, in ascending id order."""
    top_ids = torch.topk(score, num_chunks_selected).indices
    return top_ids.sort(dim=-1).values
|
|
class _attention_bsa(torch.autograd.Function):
    """Autograd function for block-sparse attention.

    Forward steps:
    1. Mean-pool q and k into chunks.
    2. Pick k blocks per q block (top-k and/or CDF criteria).
    3. Run the sparse forward kernel.

    Backward runs the matching Triton kernels. Only q, k and v receive gradients.
    """

    @staticmethod
    def forward(ctx, q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, sm_scale, use_tma=False):
        head_dim = k.shape[-1]
        assert q.shape[-1] == head_dim and head_dim == v.shape[-1]
        assert head_dim in {16, 32, 64, 128, 256}

        # Block selection runs on chunk-mean-pooled q/k.
        block_indices, block_indices_lens = get_select_indices(
            mean_pooling_compression(q, chunk_size_q),
            mean_pooling_compression(k, chunk_size_k),
            sparsity,
            cdf_threshold,
        )

        o, lse = attn_fwd_bsa_varlen_triton(
            q, k, v,
            sm_scale, block_indices, block_indices_lens,
            chunk_size_q, chunk_size_k,
            sparsity
        )

        ctx.save_for_backward(q, k, v, o, lse, block_indices, block_indices_lens)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = head_dim
        ctx.chunk_size_q = chunk_size_q
        ctx.chunk_size_k = chunk_size_k
        ctx.use_tma = use_tma
        ctx.sparsity = sparsity
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse, block_indices, block_indices_lens = ctx.saved_tensors
        dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))

        attn_bwd_bsa_varlen_triton(
            do, q, k, v, o,
            dq, dk, dv,
            ctx.sm_scale, lse,
            block_indices, block_indices_lens,
            ctx.chunk_size_q, ctx.chunk_size_k,
            ctx.sparsity
        )

        # One slot per forward input: q, k, v get gradients; the six
        # configuration arguments (chunk sizes, sparsity, threshold, scale, use_tma) get None.
        return (dq, dk, dv) + (None,) * 6


flash_attn_bsa = _attention_bsa.apply
|
|
def rearrange_THW_to_3d_block(x, Nt, Nh, Nw, t, h, w, D):
    """Reorder a flattened (T, H, W) token sequence so each t*h*w 3D block is contiguous.

    ``x`` is (B, H, (Nt*t)*(Nh*h)*(Nw*w), F) in row-major T/H/W order. The output has
    the same shape, with tokens grouped block by block: the block grid (Nt, Nh, Nw)
    is outer and the (t, h, w) offsets inside a block are inner. The ``D`` argument
    is not used; the feature size comes from ``x``.
    """
    batch, heads, _, feat = x.shape
    grid = x.view(batch, heads, Nt, t, Nh, h, Nw, w, feat)
    # (Nt, t, Nh, h, Nw, w) -> (Nt, Nh, Nw, t, h, w)
    grid = grid.permute(0, 1, 2, 4, 6, 3, 5, 7, 8)
    n_tokens = Nt * Nh * Nw * t * h * w
    return grid.contiguous().view(batch, heads, n_tokens, feat)
|
|
def rearrange_3d_block_to_THW(x, Nt, Nh, Nw, t, h, w, D):
    """Inverse of ``rearrange_THW_to_3d_block``: restore row-major (T, H, W) token order.

    ``x`` is (B, H, Nt*Nh*Nw*t*h*w, F) with tokens grouped block by block. The
    ``D`` argument is not used; the feature size comes from ``x``.
    """
    batch, heads, _, feat = x.shape
    blocks = x.view(batch, heads, Nt, Nh, Nw, t, h, w, feat)
    # (Nt, Nh, Nw, t, h, w) -> (Nt, t, Nh, h, Nw, w)
    blocks = blocks.permute(0, 1, 2, 5, 3, 6, 4, 7, 8)
    n_tokens = Nt * t * Nh * h * Nw * w
    return blocks.contiguous().view(batch, heads, n_tokens, feat)
|
|
def flash_attn_bsa_3d(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    latent_shape_q,
    latent_shape_k,
    sparsity=0.875,
    cdf_threshold=None,
    chunk_3d_shape_q=(4, 4, 8),
    chunk_3d_shape_k=(4, 4, 8),
) -> torch.Tensor:
    """Block-sparse attention over video-like (T, H, W) token grids.

    Tokens of q and k/v are regrouped into 3D chunks of shape ``chunk_3d_shape_*``
    (t, h, w). Block-sparse attention runs over those chunks, and the output is
    restored to the original (T, H, W) token order.

    Args:
        q, k, v: tensors (B, heads, S, D), where S = T * H * W of the matching latent shape.
        latent_shape_q / latent_shape_k: (T, H, W) of the q and k/v token grids.
        sparsity: fraction of k chunks dropped per q chunk (top-k criterion), or None.
        cdf_threshold: cumulative softmax-mass threshold for selection, or None.
        chunk_3d_shape_q / chunk_3d_shape_k: (t, h, w) chunk sizes; each must
            divide the corresponding latent dimension. Tuples are used as defaults
            to avoid shared mutable default arguments; any 3-element sequence works.

    Returns:
        Tensor with q's shape, in (T, H, W) token order.
    """
    _, _, Sq, head_dim_q = q.shape
    _, _, Sk, head_dim_k = k.shape

    assert head_dim_q == head_dim_k
    head_dim = head_dim_q

    Tq, Hq, Wq = latent_shape_q
    Tk, Hk, Wk = latent_shape_k

    assert Tq * Hq * Wq == Sq
    assert Tk * Hk * Wk == Sk

    tq, hq, wq = chunk_3d_shape_q
    tk, hk, wk = chunk_3d_shape_k

    assert Tq % tq == 0 and Hq % hq == 0 and Wq % wq == 0
    assert Tk % tk == 0 and Hk % hk == 0 and Wk % wk == 0

    # Number of chunks along each axis.
    Ntq, Nhq, Nwq = Tq // tq, Hq // hq, Wq // wq
    Ntk, Nhk, Nwk = Tk // tk, Hk // hk, Wk // wk

    q = rearrange_THW_to_3d_block(q, Ntq, Nhq, Nwq, tq, hq, wq, q.shape[-1])
    k = rearrange_THW_to_3d_block(k, Ntk, Nhk, Nwk, tk, hk, wk, k.shape[-1])
    v = rearrange_THW_to_3d_block(v, Ntk, Nhk, Nwk, tk, hk, wk, v.shape[-1])

    # Tokens per chunk become the 1D block sizes of the sparse kernel.
    chunk_size_q = tq * hq * wq
    chunk_size_k = tk * hk * wk

    output = flash_attn_bsa(q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, 1 / head_dim**0.5)

    output = rearrange_3d_block_to_THW(output, Ntq, Nhq, Nwq, tq, hq, wq, output.shape[-1])
    return output
|
|