| import os
|
| import torch
|
| import torch.nn.functional as F
|
| import triton
|
| import triton.language as tl
|
| import math
|
|
|
| from .common import _attn_fwd_gating, _attn_bwd_preprocess, configs_gating_preset
|
| from .flash_attn_bsa_varlen_mask import (
|
| _attn_fwd_bsa_varlen, _attn_fwd_bsa_varlen_align, _attn_bwd_dkdv_bsa_varlen_wrapper, _attn_bwd_dq_bsa_varlen_wrapper, _attn_bwd_dq_bsa_varlen_align_wrapper,
|
| configs_fwd_bsa_varlen_preset, configs_fwd_bsa_varlen_align_preset, configs_bwd_dkdv_bsa_varlen_preset, configs_bwd_dq_bsa_varlen_preset, configs_bwd_dq_bsa_varlen_align_preset
|
| )
|
|
|
| from .communicate import p2p_communicate
|
|
|
# Set TorchDynamo's `cache_size_limit` to 32 as a module-import side effect.
# NOTE(review): presumably raised so that code recompiled for many input
# shapes is not dropped back to eager too early — confirm the intended value.
torch._dynamo.config.cache_size_limit = 32
|
|
|
def is_cuda():
    """Return True when Triton's active driver reports the "cuda" backend."""
    current_target = triton.runtime.driver.active.get_current_target()
    return current_target.backend == "cuda"
|
|
|
def supports_tma():
    """Return True on a CUDA backend whose major compute capability is >= 9."""
    if not is_cuda():
        return False
    major_capability = torch.cuda.get_device_capability()[0]
    return major_capability >= 9
|
|
|
# True when the installed triton.language exposes `nv_tma_desc_type`.
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

# Announce at import time which descriptor flavour the helpers below will use.
print(
    "TMA benchmarks will be running with experimental grid constant TMA descriptor."
    if HAS_TMA_DESC
    else "TMA benchmarks will be running without grid constant TMA descriptor."
)
|
|
|
|
|
|
|
class TmaAutoTuneHelper:
    """Keep named TMA descriptor buffers and fill them via the Triton driver.

    The module-level ``HAS_TMA_DESC`` flag selects the storage: when it is
    true, descriptors live in CPU int8 tensors (``self.descriptors``);
    otherwise they live in CUDA int8 tensors (``self.cuda_descriptors``) that
    are filled through a pinned host staging buffer.
    """

    class KernelParamWrapper:
        """Wrap a descriptor tensor and expose its data pointer."""

        def __init__(self, desc):
            self.desc = desc

        def tma_desc_cpu_ptr(self):
            """Return ``data_ptr()`` of the wrapped descriptor tensor."""
            return self.desc.data_ptr()

    # Number of int8 elements (i.e. bytes) allocated per descriptor.
    TMA_SIZE = 128

    def __init__(self):
        driver_utils = triton.runtime.driver.active.utils
        self.fill_1d_tma_descriptor_inner = driver_utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver_utils.fill_2d_tma_descriptor
        if HAS_TMA_DESC:
            self.descriptors = {}
        else:
            self.cuda_descriptors = {}

    def init_tma_descriptor(self, name):
        """Allocate an uninitialised descriptor buffer and register it as ``name``."""
        if HAS_TMA_DESC:
            registry, device = self.descriptors, "cpu"
        else:
            registry, device = self.cuda_descriptors, "cuda"
        registry[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device=device, dtype=torch.int8)

    def _fill(self, fill_fn, name, *layout_args):
        """Run ``fill_fn(*layout_args, dst_ptr)`` for the descriptor ``name``."""
        if HAS_TMA_DESC:
            host_desc = self.descriptors[name]
            assert host_desc.data_ptr() % 64 == 0
            fill_fn(*layout_args, host_desc.data_ptr())
            return
        device_desc = self.cuda_descriptors[name]
        # Fill a pinned host buffer first, then copy it to the device tensor.
        staging = torch.empty_like(device_desc, device="cpu", pin_memory=True)
        fill_fn(*layout_args, staging.data_ptr())
        device_desc.copy_(staging, non_blocking=True)

    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
        """Fill descriptor ``name`` with the driver's 1D descriptor routine."""
        self._fill(self.fill_1d_tma_descriptor_inner, name, ptr, dim, block_dim, element_size)

    def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
        """Fill descriptor ``name`` with the driver's 2D descriptor routine."""
        self._fill(
            self.fill_2d_tma_descriptor_inner,
            name, ptr, dim1, dim0, block_dim1, block_dim0, element_size,
        )

    def get_tma_descriptor_kernel_param(self, name):
        """Return the object to pass to a kernel for descriptor ``name``.

        With ``HAS_TMA_DESC`` this is a ``KernelParamWrapper`` around the CPU
        tensor; otherwise it is the CUDA tensor itself.
        """
        if HAS_TMA_DESC:
            host_desc = self.descriptors[name]
            assert host_desc is not None
            return self.KernelParamWrapper(host_desc)
        device_desc = self.cuda_descriptors[name]
        assert device_desc is not None
        return device_desc
|
|
|
|
|
@triton.jit
def create_mask_from_indices_kernel(
    block_indices,
    block_mask,
    stride_bz, stride_bh, stride_bm, stride_bs,
    stride_mz, stride_mh, stride_mm, stride_mn,
    H,
):
    # One program per (batch*head, row, selected-slot) entry of block_indices:
    # it reads the selected column index and writes 1 at that column of the
    # same (batch, head, row) in block_mask.
    pid_zh = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_s = tl.program_id(2)

    # All offsets are computed in 64-bit integers.
    z = (pid_zh // H).to(tl.int64)
    h = (pid_zh % H).to(tl.int64)
    m = pid_m.to(tl.int64)
    s = pid_s.to(tl.int64)

    src_off = z * stride_bz + h * stride_bh + m * stride_bm + s * stride_bs
    col = tl.load(block_indices + src_off)

    dst_off = z * stride_mz + h * stride_mh + m * stride_mm + col.to(tl.int64) * stride_mn

    one = 1
    tl.store(block_mask + dst_off, one.to(block_mask.dtype.element_ty))
|
|
|
def create_mask_from_indices_triton(
    block_indices,
    N_cols
):
    """Build a boolean block mask from selected column indices with the Triton kernel.

    Args:
        block_indices: 4-D tensor (unpacked as B, H, N_rows, S) of column indices.
        N_cols: size of the last dimension of the returned mask.

    Returns:
        A ``torch.bool`` tensor of shape (B, H, N_rows, N_cols), zero-initialised
        and then written by ``create_mask_from_indices_kernel``.
    """
    n_batch, n_head, n_rows, n_slots = block_indices.shape
    block_mask = torch.zeros(
        (n_batch, n_head, n_rows, N_cols),
        dtype=torch.bool,
        device=block_indices.device,
    )
    # One program per entry of block_indices.
    launch_grid = (n_batch * n_head, n_rows, n_slots)
    create_mask_from_indices_kernel[launch_grid](
        block_indices,
        block_mask,
        *block_indices.stride(),
        *block_mask.stride(),
        n_head,
    )
    return block_mask
|
|
|
|
|
def create_mask_from_indices_varlen(block_indices, N_cols_mask):
    """Scatter selected column indices into a boolean mask, ignoring out-of-range ones.

    Args:
        block_indices: 4-D integer tensor (B, H, M, S) of column indices.
        N_cols_mask: number of mask columns; entries ``>= N_cols_mask`` are skipped.

    Returns:
        ``torch.bool`` tensor of shape (B, H, M, N_cols_mask) that is True at
        every (b, h, m, block_indices[b, h, m, s]) whose index passes the check.
    """
    n_batch, n_head, n_rows, _ = block_indices.shape
    mask = torch.zeros(
        (n_batch, n_head, n_rows, N_cols_mask),
        dtype=torch.bool,
        device=block_indices.device,
    )

    in_range = block_indices < N_cols_mask

    # Coordinates of the kept entries, in the same row-major order that
    # boolean-mask indexing uses for block_indices[in_range].
    batch_ids, head_ids, row_ids, _ = in_range.nonzero(as_tuple=True)
    mask[batch_ids, head_ids, row_ids, block_indices[in_range]] = True

    return mask
|
|
|
|
|
def create_indices_k_from_indices_q_varlen(
    block_indices,
    N_cols_mask
):
    """Invert per-query block selections into per-key lists of query blocks.

    Args:
        block_indices: (B, H, M, S) key-block indices selected by each query block.
        N_cols_mask: number of key blocks; larger indices are ignored.

    Returns:
        Tuple ``(block_indices_k, block_indices_k_lens)``:
        ``block_indices_k`` is (B, H, N_cols_mask, M), each row holding the
        ascending query-block ids that selected that key block, padded with
        the sentinel value ``M``; ``block_indices_k_lens`` is (B, H, N_cols_mask)
        with the count of non-sentinel entries per row.
    """
    mask_qk = create_mask_from_indices_varlen(block_indices, N_cols_mask)
    _, _, num_q_blocks, _ = mask_qk.shape

    # View the mask from the key side: (B, H, N_k, M).
    mask_kq = mask_qk.transpose(2, 3)
    q_ids = torch.arange(num_q_blocks, device=block_indices.device)
    q_ids = q_ids.view(1, 1, 1, -1).expand_as(mask_kq)

    # Unselected slots get the sentinel M so that sorting pushes them to the end.
    padded_ids = torch.where(mask_kq, q_ids, num_q_blocks)
    sorted_ids = torch.sort(padded_ids, dim=-1).values

    valid_counts = (sorted_ids < num_q_blocks).sum(dim=-1)

    return sorted_ids, valid_counts
|
|
|
|
|
|
|
def mean_pooling_compression(
    x: torch.Tensor,
    block_size: int
) -> torch.Tensor:
    """Mean-pool ``x`` along its sequence dimension in blocks of ``block_size``.

    Args:
        x: tensor laid out as (B, H, S, D); the padding below is applied to
            the second-to-last dimension, so S must be dimension 2.
        block_size: number of consecutive sequence positions per block.

    Returns:
        Tensor of shape (B, H, ceil(S / block_size), D) holding per-block means.
        When S is not a multiple of ``block_size`` the last block is averaged
        over its real elements only, not over the zero padding.
    """
    B, H, S = x.shape[:3]
    # Integer ceil division avoids the float round-trip of math.ceil(S / block_size).
    num_block = (S + block_size - 1) // block_size
    remainder = S % block_size
    if remainder != 0:
        x = F.pad(x, (0, 0, 0, num_block * block_size - S))
    x_cmp = x.view(B, H, num_block, block_size, -1).mean(dim=3)
    if remainder != 0:
        # The padded block was divided by block_size although it only holds
        # `remainder` real rows; rescale it to the true mean.
        scale = torch.ones(num_block, dtype=x_cmp.dtype, device=x_cmp.device)
        scale[-1] = block_size / remainder
        x_cmp = x_cmp * scale.view(1, 1, -1, 1)
    return x_cmp
|
|
|
|
|
def cal_score(q, k):
    """Return the unscaled dot-product scores ``q @ k^T`` over the last two dims."""
    return torch.matmul(q, k.transpose(-1, -2))
|
|
|
def cal_score_triton(q, k):
    """Compute the (B, H, s_q, s_k) score tensor with the ``_attn_fwd_gating`` kernel.

    Args:
        q: 4-D tensor unpacked as (B, H, s_q, D).
        k: tensor whose dimension 2 is the key length s_k.

    Returns:
        Tensor of shape (B, H, s_q, s_k) with q's dtype and device, filled by
        the kernel.
    """
    n_batch, n_head, len_q, head_dim = q.shape
    len_k = k.shape[2]

    score = torch.empty(n_batch, n_head, len_q, len_k, device=q.device, dtype=q.dtype)

    # NOTE(review): the variable name is spelled "ENBALE" throughout this file;
    # kept as-is because other modules may read the same name — confirm.
    autotune = os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1'
    kernel_config = {} if autotune else configs_gating_preset['default']

    def grid(meta):
        return (triton.cdiv(len_q, meta["BLOCK_M"]), n_batch * n_head, 1)

    _attn_fwd_gating[grid](
        q, k, score,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        score.stride(0), score.stride(1), score.stride(2), score.stride(3),
        n_head, len_q, len_k,
        HEAD_DIM=head_dim,
        **kernel_config
    )
    return score
|
|
|
|
|
def get_select_indices_topk(q, k, sparsity):
    """Score ``q`` against ``k`` and keep a fixed top-k share of key blocks.

    Returns the ``(block_indices, block_indices_lens)`` pair produced by
    ``get_select_indices_topk_from_score`` on ``cal_score(q, k)``.
    """
    selection = get_select_indices_topk_from_score(cal_score(q, k), sparsity)
    block_indices, block_indices_lens = selection
    return block_indices, block_indices_lens
|
|
|
|
|
def get_select_indices_topk_from_score(score, sparsity):
    """Select the top ``(1 - sparsity)`` fraction of entries along the last dim.

    Args:
        score: tensor with at least 3 leading dims, e.g. (B, H, Sq, Sk).
        sparsity: fraction of the last dimension to drop.

    Returns:
        Tuple ``(block_indices, block_indices_lens)``: the ``torch.topk``
        indices of shape (..., num_selected), and an int32 tensor shaped like
        the first three dims of ``block_indices`` filled with ``num_selected``.
    """
    # A small epsilon keeps float round-off from truncating an exact count:
    # (1 - 0.9) * 10 evaluates to 0.9999999999999998, which int() turns into 0.
    num_selected = int((1 - sparsity) * score.shape[-1] + 1e-6)
    block_indices = torch.topk(score, num_selected)[1]

    block_indices_lens = torch.full(
        tuple(block_indices.shape[:3]),
        num_selected,
        dtype=torch.int32,
        device=block_indices.device,
    )

    return block_indices, block_indices_lens
|
|
|
|
|
def get_select_indices_cdf(q, k, cdf_threshold):
    """Score ``q`` against ``k`` and select key blocks by cumulative softmax mass.

    The scores are scaled by ``1 / sqrt(head_dim)`` (head_dim is q's last
    dimension) inside ``get_select_indices_cdf_from_score``.
    """
    sm_scale = 1 / q.shape[-1]**0.5
    selection = get_select_indices_cdf_from_score(cal_score(q, k), cdf_threshold, sm_scale)
    block_indices, block_indices_lens = selection
    return block_indices, block_indices_lens
|
|
|
|
|
def get_select_indices_cdf_from_score(score, cdf_threshold, sm_scale):
    """Rank keys by softmax weight and count how many fit under a CDF threshold.

    Args:
        score: 4-D tensor (B, H, Sq, Sk) of raw attention scores.
        cdf_threshold: scalar cumulative-probability threshold.
        sm_scale: multiplier applied to ``score`` before the softmax.

    Returns:
        Tuple ``(indices, num_selected)``: ``indices`` is (B, H, Sq, Sk), the
        key positions sorted by descending softmax weight; ``num_selected`` is
        (B, H, Sq), the number of leading entries whose cumulative weight is
        ``<= cdf_threshold``.
    """
    weights = torch.softmax(score * sm_scale, dim=-1)
    B, H, Sq, _ = weights.shape

    weights_sorted = torch.sort(weights, dim=-1, descending=True)
    cdf = torch.cumsum(weights_sorted.values, dim=-1)

    # torch.searchsorted requires the values and the boundaries to share a
    # dtype, so build the threshold from `cdf` instead of the default dtype.
    threshold = cdf.new_full((B, H, Sq, 1), cdf_threshold)
    # NOTE(review): this can be 0 for a row whose largest weight alone exceeds
    # the threshold — confirm the kernels tolerate rows with no selected block.
    num_selected = torch.searchsorted(cdf, threshold, right=True)

    return weights_sorted.indices, num_selected.squeeze(-1)
|
|
|
|
|
def get_select_indices_cdf_topk(q, k, sparsity, cdf_threshold):
    """Score ``q`` against ``k`` and select by CDF threshold with a top-k floor.

    The scores are scaled by ``1 / sqrt(head_dim)`` (head_dim is q's last
    dimension) inside ``get_select_indices_cdf_topk_from_score``.
    """
    sm_scale = 1 / q.shape[-1]**0.5
    selection = get_select_indices_cdf_topk_from_score(
        cal_score(q, k), sparsity, cdf_threshold, sm_scale
    )
    block_indices, block_indices_lens = selection
    return block_indices, block_indices_lens
|
|
|
|
|
def get_select_indices_cdf_topk_from_score(score, sparsity, cdf_threshold, sm_scale):
    """CDF-threshold selection with a lower bound taken from ``sparsity``.

    Args:
        score: 4-D tensor (B, H, Sq, Sk) of raw attention scores.
        sparsity: fraction of keys to drop; ``(1 - sparsity) * Sk`` (floored)
            is the minimum number of keys kept per row.
        cdf_threshold: scalar cumulative-probability threshold.
        sm_scale: multiplier applied to ``score`` before the softmax.

    Returns:
        Tuple ``(indices, num_selected)``: ``indices`` is (B, H, Sq, Sk), the
        key positions sorted by descending softmax weight; ``num_selected`` is
        (B, H, Sq), the larger of the CDF count and the top-k floor.
    """
    weights = torch.softmax(score * sm_scale, dim=-1)
    B, H, Sq, Sk = weights.shape

    weights_sorted = torch.sort(weights, dim=-1, descending=True)
    cdf = torch.cumsum(weights_sorted.values, dim=-1)

    # torch.searchsorted requires the values and the boundaries to share a
    # dtype, so build the threshold from `cdf` instead of the default dtype.
    threshold = cdf.new_full((B, H, Sq, 1), cdf_threshold)
    num_selected = torch.searchsorted(cdf, threshold, right=True)

    # The epsilon keeps float round-off from truncating an exact count
    # ((1 - 0.9) * 10 evaluates to 0.9999999999999998).
    num_selected_topk = int((1 - sparsity) * Sk + 1e-6)
    num_selected = num_selected.clamp(min=num_selected_topk)

    return weights_sorted.indices, num_selected.squeeze(-1)
|
|
|
def get_select_indices(q, k, sparsity, cdf_threshold):
    """Dispatch to the top-k, CDF, or combined selector based on which knobs are set.

    Raises:
        ValueError: if both ``sparsity`` and ``cdf_threshold`` are None.
    """
    use_topk = sparsity is not None
    use_cdf = cdf_threshold is not None

    if not (use_topk or use_cdf):
        raise ValueError

    if use_topk and use_cdf:
        selection = get_select_indices_cdf_topk(q, k, sparsity, cdf_threshold)
    elif use_topk:
        selection = get_select_indices_topk(q, k, sparsity)
    else:
        selection = get_select_indices_cdf(q, k, cdf_threshold)

    block_indices, block_indices_lens = selection
    return block_indices, block_indices_lens
|
|
|
def get_select_indices_from_score(score, sparsity, cdf_threshold, sm_scale=None):
    """Dispatch to the score-based top-k, CDF, or combined selector.

    Args:
        score: precomputed score tensor passed on to the selected helper.
        sparsity: top-k sparsity, or None to disable the top-k criterion.
        cdf_threshold: CDF threshold, or None to disable the CDF criterion.
        sm_scale: softmax scale forwarded to the CDF-based helpers. It is
            required whenever ``cdf_threshold`` is set, because those helpers
            take it as a mandatory positional argument.

    Returns:
        The ``(block_indices, block_indices_lens)`` pair of the chosen helper.

    Raises:
        ValueError: if both ``sparsity`` and ``cdf_threshold`` are None, or if
            ``cdf_threshold`` is set without ``sm_scale``.
    """
    use_topk = sparsity is not None
    use_cdf = cdf_threshold is not None

    if not use_topk and not use_cdf:
        raise ValueError("at least one of `sparsity` and `cdf_threshold` must be set")

    if not use_cdf:
        block_indices, block_indices_lens = get_select_indices_topk_from_score(score, sparsity)
        return block_indices, block_indices_lens

    if sm_scale is None:
        raise ValueError("`sm_scale` is required when `cdf_threshold` is set")

    if use_topk:
        block_indices, block_indices_lens = get_select_indices_cdf_topk_from_score(
            score, sparsity, cdf_threshold, sm_scale
        )
    else:
        block_indices, block_indices_lens = get_select_indices_cdf_from_score(
            score, cdf_threshold, sm_scale
        )
    return block_indices, block_indices_lens
|
|
|
def attn_fwd_bsa_varlen_triton(
    q,
    k,
    v,
    sm_scale,
    block_indices,
    block_indices_lens,
    chunk_size_q,
    chunk_size_k,
    sparsity
):
    """Launch the block-sparse variable-length attention forward kernel.

    Args:
        q, k, v: 4-D tensors; q is unpacked as (B, H, Seq, D).
        sm_scale: softmax scale passed to the kernel.
        block_indices: 4-D selected key-block indices (made contiguous here).
        block_indices_lens: 3-D valid-length tensor (made contiguous here).
        chunk_size_q: query block size, passed as ``BLOCK_M``.
        chunk_size_k: key block size, passed as ``BLOCK_N_LG``; values above
            128 use ``_attn_fwd_bsa_varlen``, others the ``_align`` kernel.
        sparsity: passed to the kernel as ``SPARSITY``.

    Returns:
        Tuple ``(o, lse)``: ``o`` has q's shape and dtype; ``lse`` is the
        float32 (B, H, Seq) kernel statistic multiplied by ln 2.
    """
    _, num_heads, seq_len, head_dim = q.shape

    out = torch.empty_like(q)
    row_stats = torch.empty(
        (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
    )

    def grid(meta):
        return (triton.cdiv(q.shape[2], meta["BLOCK_M"]), q.shape[0] * q.shape[1], 1)

    # Pick the kernel and its preset table by key-chunk size.
    if chunk_size_k > 128:
        fwd_func, presets = _attn_fwd_bsa_varlen, configs_fwd_bsa_varlen_preset
    else:
        fwd_func, presets = _attn_fwd_bsa_varlen_align, configs_fwd_bsa_varlen_align_preset

    preset_key = 'BLOCK_N_LG=64' if chunk_size_k == 64 else 'default'
    # NOTE(review): the variable name is spelled "ENBALE" throughout this file;
    # kept as-is because other modules may read the same name — confirm.
    autotune = os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1'
    kernel_config = {} if autotune else presets[preset_key]

    block_indices = block_indices.contiguous()
    block_indices_lens = block_indices_lens.contiguous()

    fwd_func[grid](
        q, k, v, sm_scale, row_stats, out,
        block_indices,
        block_indices_lens,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        block_indices.stride(0), block_indices.stride(1), block_indices.stride(2), block_indices.stride(3),
        block_indices_lens.stride(0), block_indices_lens.stride(1), block_indices_lens.stride(2),
        num_heads, seq_len,
        head_dim,
        BLOCK_M=chunk_size_q,
        BLOCK_N_LG=chunk_size_k,
        SPARSITY=sparsity,
        **kernel_config
    )

    # ln(2) constant; presumably converts a base-2 statistic written by the
    # kernel into a natural-log LSE — the backward pass multiplies by 1/ln(2).
    LN2 = 0.6931471824645996
    lse = row_stats * LN2

    return out, lse
|
|
|
def attn_bwd_bsa_varlen_triton(
    do,
    q,
    k,
    v,
    o,
    dq,
    dk,
    dv,
    sm_scale,
    M,
    block_indices,
    block_indices_lens,
    chunk_size_q,
    chunk_size_k,
    sparsity
):
    """Launch the block-sparse variable-length attention backward kernels.

    Runs, in order: the ``_attn_bwd_preprocess`` kernel (fills ``delta``), the
    dK/dV kernel driven by per-key-block query lists, and the dQ kernel driven
    by the per-query-block key lists.

    Args:
        do: gradient of the output (made contiguous here).
        q, k, v, o: forward inputs and output; q is unpacked as
            (BATCH, N_HEAD, N_CTX, HEAD_DIM).
        dq, dk, dv: pre-allocated tensors handed to the kernels as gradient
            buffers.
        sm_scale: softmax scale.
        M: the ``lse`` returned by the forward wrapper; rescaled by 1/ln(2)
            before being handed to the kernels.
        block_indices, block_indices_lens: forward block selection.
        chunk_size_q, chunk_size_k: query / key block sizes.
        sparsity: passed to the kernels as ``SPARSITY``.

    Returns:
        None.

    Raises:
        AssertionError: if N_CTX is not a multiple of the preprocess block size.
    """
    RCP_LN2 = 1.4426950408889634  # 1 / ln(2)
    M = M * RCP_LN2

    do = do.contiguous()

    BATCH, N_HEAD, N_CTX, HEAD_DIM = q.shape
    N_CTX_KV = k.shape[-2]

    # Pre-scaled copy of k. It is a NEW tensor, so the kernels must be given
    # its own strides: for a non-dense `k` (e.g. a slice of a fused QKV
    # buffer) the result of the multiplication does not share k's strides.
    arg_k = k * (sm_scale * RCP_LN2)

    # Preprocess block: the smaller chunk size, capped at 128.
    PRE_BLOCK = min(chunk_size_q, chunk_size_k, 128)

    assert N_CTX % PRE_BLOCK == 0
    pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
    delta = torch.empty_like(M)
    _attn_bwd_preprocess[pre_grid](
        o, do,
        delta,
        N_CTX,
        BLOCK_M=PRE_BLOCK,
        HEAD_DIM=HEAD_DIM
    )

    # Per-key-block lists of the query blocks that selected them.
    block_indices_k, block_indices_k_lens = create_indices_k_from_indices_q_varlen(
        block_indices=block_indices,
        N_cols_mask=N_CTX_KV // chunk_size_k
    )

    block_indices = block_indices.contiguous()
    block_indices_lens = block_indices_lens.contiguous()
    block_indices_k = block_indices_k.contiguous()
    block_indices_k_lens = block_indices_k_lens.contiguous()

    # NOTE(review): the variable name is spelled "ENBALE" throughout this file;
    # kept as-is because other modules may read the same name — confirm.
    autotune = os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1'
    config_key = 'BLOCK_N_DQ_LG=64' if chunk_size_k == 64 else 'default'

    # ---- dK / dV -----------------------------------------------------------
    kernel_config = {} if autotune else configs_bwd_dkdv_bsa_varlen_preset[config_key]

    grid_dkdv = lambda args: (triton.cdiv(arg_k.shape[2], args["BLOCK_N"]), 1, arg_k.shape[0] * arg_k.shape[1])
    _attn_bwd_dkdv_bsa_varlen_wrapper[grid_dkdv](
        q, arg_k, v, sm_scale,
        do,
        dk, dv,
        M,
        delta,
        block_indices_k,
        block_indices_k_lens,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        arg_k.stride(0), arg_k.stride(1), arg_k.stride(2), arg_k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
        dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
        do.stride(0), do.stride(1), do.stride(2), do.stride(3),
        M.stride(0), M.stride(1), M.stride(2),
        delta.stride(0), delta.stride(1), delta.stride(2),
        block_indices_k.stride(0), block_indices_k.stride(1), block_indices_k.stride(2), block_indices_k.stride(3),
        block_indices_k_lens.stride(0), block_indices_k_lens.stride(1), block_indices_k_lens.stride(2),
        N_HEAD, N_CTX,
        BLOCK_M=chunk_size_q,
        BLOCK_N_DQ_LG=chunk_size_k,
        HEAD_DIM=HEAD_DIM,
        SPARSITY=sparsity,
        **kernel_config
    )

    # ---- dQ ----------------------------------------------------------------
    if chunk_size_k > 128:
        bwd_dq_func = _attn_bwd_dq_bsa_varlen_wrapper
        dq_presets = configs_bwd_dq_bsa_varlen_preset
    else:
        bwd_dq_func = _attn_bwd_dq_bsa_varlen_align_wrapper
        dq_presets = configs_bwd_dq_bsa_varlen_align_preset
    kernel_config = {} if autotune else dq_presets[config_key]

    grid_dq = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), 1, q.shape[0] * q.shape[1])
    bwd_dq_func[grid_dq](
        q, arg_k, v,
        do,
        dq,
        M,
        delta,
        block_indices,
        block_indices_lens,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        arg_k.stride(0), arg_k.stride(1), arg_k.stride(2), arg_k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
        do.stride(0), do.stride(1), do.stride(2), do.stride(3),
        M.stride(0), M.stride(1), M.stride(2),
        delta.stride(0), delta.stride(1), delta.stride(2),
        block_indices.stride(0), block_indices.stride(1), block_indices.stride(2), block_indices.stride(3),
        block_indices_lens.stride(0), block_indices_lens.stride(1), block_indices_lens.stride(2),
        N_HEAD, N_CTX,
        BLOCK_M=chunk_size_q,
        BLOCK_N_DQ_LG=chunk_size_k,
        HEAD_DIM=HEAD_DIM,
        SPARSITY=sparsity,
        **kernel_config
    )
|
|
|
|
|
def make_block_indices_varlen_cp_list(block_indices, cp_size, num_blocks_k_full):
    """
    Split global key-block indices into one rank-local index set per cp rank.

    Args:
        block_indices: [B, H, num_blocks_q_per_cp_rank, num_blocks_k_full]
        cp_size: number of context-parallel ranks
        num_blocks_k_full: total number of key blocks across all ranks

    Return:
        a list with cp_size entries of [block_indices, block_indices_lens]
            - indices are shifted so each rank's first key block is zero,
              then sorted ascending along the last dimension
            - indices that fell below the rank's range are replaced by the
              sentinel num_blocks_k_full // cp_size
            - block_indices_lens counts the entries smaller than that sentinel
    """
    blocks_per_rank = num_blocks_k_full // cp_size
    per_rank = []
    for rank in range(cp_size):
        # Re-base onto this rank; the input tensor itself is left untouched.
        shifted = block_indices - rank * blocks_per_rank
        shifted = shifted.masked_fill(shifted < 0, blocks_per_rank)

        local_indices = torch.sort(shifted, dim=-1).values
        local_lens = (local_indices < blocks_per_rank).sum(dim=-1)

        per_rank.append([local_indices, local_lens])

    return per_rank
|
|
|
|
|
def flash_attn_fwd_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Fold one step's log-sum-exp into the running one, in place.

    ``softmax_lse`` is overwritten with ``log(exp(softmax_lse) +
    exp(softmax_lse_per_step))``, computed as ``max + log1p(exp(min - max))``.
    """
    larger = torch.maximum(softmax_lse, softmax_lse_per_step)
    smaller = torch.minimum(softmax_lse, softmax_lse_per_step)
    # NaN gaps (e.g. -inf minus -inf) are treated as a zero difference.
    gap = (smaller - larger).nan_to_num(nan=0.)
    softmax_lse.copy_(larger + torch.log1p(torch.exp(gap)))
|
|
|
|
|
def flash_attn_fwd_out_correction_init(
    out_init_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_init_step: torch.Tensor,
):
    """Rescale the first step's partial output to the merged log-sum-exp.

    Returns ``out_init_step * exp(softmax_lse_init_step - softmax_lse)`` (the
    factor is broadcast over the last dimension), cast back to the dtype of
    ``out_init_step``.
    """
    rescale = torch.exp(softmax_lse_init_step - softmax_lse).unsqueeze(-1)
    return (out_init_step * rescale).to(out_init_step.dtype)
|
|
|
|
|
|
|
def flash_attn_fwd_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Accumulate one step's rescaled partial output into ``out``, in place.

    Adds ``out_per_step * exp(softmax_lse_per_step - softmax_lse)`` to ``out``
    (the factor is broadcast over the last dimension).
    """
    rescale = torch.exp(softmax_lse_per_step - softmax_lse).unsqueeze(-1)
    out.add_(out_per_step * rescale)
|
|
|
|
|
def topk_sort(score, num_chunks_selected):
    """Return the indices of the top-k entries along the last dim, sorted ascending."""
    top_indices = torch.topk(score, num_chunks_selected).indices
    return torch.sort(top_indices, dim=-1).values
|
|
|
class _attention_bsa(torch.autograd.Function):
    """Autograd wrapper around the block-sparse attention forward/backward launchers."""

    @staticmethod
    def forward(ctx, q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, sm_scale, use_tma=False):
        """Select key blocks from mean-pooled q/k, then run sparse attention.

        Returns the attention output ``o``; the tensors and settings needed by
        ``backward`` are stashed on ``ctx``.
        """
        head_dim_q, head_dim_k, head_dim_v = q.shape[-1], k.shape[-1], v.shape[-1]
        assert head_dim_q == head_dim_k and head_dim_k == head_dim_v
        assert head_dim_k in {16, 32, 64, 128, 256}

        # Block selection works on one mean-pooled vector per chunk.
        q_pooled = mean_pooling_compression(q, chunk_size_q)
        k_pooled = mean_pooling_compression(k, chunk_size_k)
        block_indices, block_indices_lens = get_select_indices(
            q_pooled, k_pooled, sparsity, cdf_threshold
        )

        o, lse = attn_fwd_bsa_varlen_triton(
            q, k, v,
            sm_scale, block_indices, block_indices_lens,
            chunk_size_q, chunk_size_k,
            sparsity
        )

        ctx.save_for_backward(q, k, v, o, lse, block_indices, block_indices_lens)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = head_dim_k
        ctx.chunk_size_q = chunk_size_q
        ctx.chunk_size_k = chunk_size_k
        ctx.use_tma = use_tma
        ctx.sparsity = sparsity

        return o

    @staticmethod
    def backward(ctx, do):
        """Return gradients for q, k, v and None for the six non-tensor inputs."""
        q, k, v, o, lse, block_indices, block_indices_lens = ctx.saved_tensors

        # Gradient buffers handed to the backward launcher.
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)

        attn_bwd_bsa_varlen_triton(
            do,
            q, k, v, o,
            dq, dk, dv,
            ctx.sm_scale,
            lse,
            block_indices,
            block_indices_lens,
            ctx.chunk_size_q,
            ctx.chunk_size_k,
            ctx.sparsity
        )

        return dq, dk, dv, None, None, None, None, None, None


# Functional entry point for the autograd function above.
flash_attn_bsa = _attention_bsa.apply
|
|
|
def rearrange_THW_to_3d_block(x, Nt, Nh, Nw, t, h, w, D):
    """Reorder a (T, H, W)-raster sequence so each (t, h, w) block is contiguous.

    ``x`` is (B, heads, Nt*t * Nh*h * Nw*w, D). The ``D`` argument is
    overwritten by ``x.shape[-1]``. Returns a contiguous tensor of the same
    shape whose sequence axis is ordered as (Nt, Nh, Nw, t, h, w).
    """
    batch, heads, _, D = x.shape
    grid = x.view(batch, heads, Nt, t, Nh, h, Nw, w, D)
    # Move the three block-count axes in front of the three in-block axes.
    blocked = grid.permute(0, 1, 2, 4, 6, 3, 5, 7, 8).contiguous()
    return blocked.view(batch, heads, Nt * Nh * Nw * t * h * w, D)
|
|
|
def rearrange_3d_block_to_THW(x, Nt, Nh, Nw, t, h, w, D):
    """Undo the block ordering: (Nt, Nh, Nw, t, h, w) back to (T, H, W) raster order.

    ``x`` is (B, heads, Nt*Nh*Nw * t*h*w, D). The ``D`` argument is
    overwritten by ``x.shape[-1]``. Returns a contiguous tensor of the same
    shape whose sequence axis is ordered as (Nt, t, Nh, h, Nw, w).
    """
    batch, heads, _, D = x.shape
    blocked = x.view(batch, heads, Nt, Nh, Nw, t, h, w, D)
    # Interleave each block-count axis with its in-block axis again.
    raster = blocked.permute(0, 1, 2, 5, 3, 6, 4, 7, 8).contiguous()
    return raster.view(batch, heads, Nt * t * Nh * h * Nw * w, D)
|
|
|
def flash_attn_bsa_3d(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    latent_shape_q,
    latent_shape_k,

    sparsity=0.875,
    cdf_threshold=None,
    chunk_3d_shape_q=[4, 4, 8],
    chunk_3d_shape_k=[4, 4, 8],
) -> torch.Tensor:
    """Block-sparse attention over sequences that are flattened (T, H, W) grids.

    Args:
        q, k, v: 4-D tensors (B, heads, S, head_dim); q and k must share head_dim.
        latent_shape_q: (T, H, W) grid of the query sequence; T*H*W == Sq.
        latent_shape_k: (T, H, W) grid of the key/value sequence; T*H*W == Sk.
        sparsity: forwarded to ``flash_attn_bsa``.
        cdf_threshold: forwarded to ``flash_attn_bsa``.
        chunk_3d_shape_q: (t, h, w) block shape for queries; must divide the grid.
        chunk_3d_shape_k: (t, h, w) block shape for keys/values; must divide the grid.

    Returns:
        Attention output in the original (T, H, W) raster order of the queries.

    Raises:
        AssertionError: if any of the shape requirements above is violated.
    """
    _, _, seq_q, dim_q = q.shape
    _, _, seq_k, dim_k = k.shape

    assert dim_q == dim_k
    head_dim = dim_q

    Tq, Hq, Wq = latent_shape_q
    Tk, Hk, Wk = latent_shape_k

    assert Tq * Hq * Wq == seq_q
    assert Tk * Hk * Wk == seq_k

    tq, hq, wq = chunk_3d_shape_q
    tk, hk, wk = chunk_3d_shape_k

    assert Tq % tq == 0 and Hq % hq == 0 and Wq % wq == 0
    assert Tk % tk == 0 and Hk % hk == 0 and Wk % wk == 0

    # Number of blocks along each grid axis.
    blocks_q = (Tq // tq, Hq // hq, Wq // wq)
    blocks_k = (Tk // tk, Hk // hk, Wk // wk)

    # Make every 3-D block a contiguous run of the sequence axis.
    q = rearrange_THW_to_3d_block(q, *blocks_q, tq, hq, wq, q.shape[-1])
    k = rearrange_THW_to_3d_block(k, *blocks_k, tk, hk, wk, k.shape[-1])
    v = rearrange_THW_to_3d_block(v, *blocks_k, tk, hk, wk, v.shape[-1])

    chunk_size_q = tq * hq * wq
    chunk_size_k = tk * hk * wk

    output = flash_attn_bsa(
        q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, 1 / head_dim**0.5
    )

    # Restore the original raster order of the query positions.
    return rearrange_3d_block_to_THW(output, *blocks_q, tq, hq, wq, output.shape[-1])
|
|
|