|
|
import os |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import triton |
|
|
import triton.language as tl |
|
|
import math |
|
|
|
|
|
from .common import _attn_fwd_gating, _attn_bwd_preprocess, configs_gating_preset |
|
|
from .flash_attn_bsa_varlen_mask import ( |
|
|
_attn_fwd_bsa_varlen, _attn_fwd_bsa_varlen_align, _attn_bwd_dkdv_bsa_varlen_wrapper, _attn_bwd_dq_bsa_varlen_wrapper, _attn_bwd_dq_bsa_varlen_align_wrapper, |
|
|
configs_fwd_bsa_varlen_preset, configs_fwd_bsa_varlen_align_preset, configs_bwd_dkdv_bsa_varlen_preset, configs_bwd_dq_bsa_varlen_preset, configs_bwd_dq_bsa_varlen_align_preset |
|
|
) |
|
|
|
|
|
from .communicate import p2p_communicate |
|
|
|
|
|
# Module-import side effect: sets torch._dynamo.config.cache_size_limit to 32 for the
# whole process, not just for code in this file.
# NOTE(review): presumably chosen so repeated recompiles for varying shapes/configs do not
# hit the limit — confirm the intended value.
torch._dynamo.config.cache_size_limit = 32
|
|
|
|
|
def is_cuda():
    """Return True when Triton's active driver reports the ``"cuda"`` backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda"
|
|
|
|
|
def supports_tma():
    """Return True on a CUDA backend whose current device has compute capability major >= 9."""
    if not is_cuda():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 9
|
|
|
|
|
# True when this Triton build exposes tl.nv_tma_desc_type (grid-constant TMA descriptors).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

# Import-time notice of which TMA descriptor path will be used.
print(
    "TMA benchmarks will be running with experimental grid constant TMA descriptor."
    if HAS_TMA_DESC
    else "TMA benchmarks will be running without grid constant TMA descriptor."
)
|
|
|
|
|
|
|
|
|
|
|
class TmaAutoTuneHelper:
    """Own and fill TMA descriptor buffers for Triton kernels.

    With ``HAS_TMA_DESC`` the descriptors live in host (CPU) int8 tensors and are handed to
    kernels wrapped in ``KernelParamWrapper``. Without it, each descriptor is written into a
    pinned host staging buffer and copied (non-blocking) into a CUDA int8 tensor that is
    passed to kernels directly.
    """

    class KernelParamWrapper:
        """Expose a host descriptor tensor through ``tma_desc_cpu_ptr()``."""

        def __init__(self, desc):
            self.desc = desc

        def tma_desc_cpu_ptr(self):
            return self.desc.data_ptr()

    # Bytes reserved per descriptor.
    TMA_SIZE = 128

    def __init__(self):
        driver_utils = triton.runtime.driver.active.utils
        self.fill_1d_tma_descriptor_inner = driver_utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver_utils.fill_2d_tma_descriptor
        if HAS_TMA_DESC:
            self.descriptors = {}
        else:
            self.cuda_descriptors = {}

    def init_tma_descriptor(self, name):
        """Allocate the ``TMA_SIZE``-byte buffer for descriptor ``name``."""
        if HAS_TMA_DESC:
            store, device = self.descriptors, "cpu"
        else:
            store, device = self.cuda_descriptors, "cuda"
        store[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device=device, dtype=torch.int8)

    def _fill_descriptor(self, name, fill_fn, *fill_args):
        """Call ``fill_fn(*fill_args, dst_ptr)`` so that descriptor ``name`` ends up filled."""
        if HAS_TMA_DESC:
            host_desc = self.descriptors[name]
            assert host_desc.data_ptr() % 64 == 0
            fill_fn(*fill_args, host_desc.data_ptr())
            return
        device_desc = self.cuda_descriptors[name]
        # Fill on the host, then ship the bytes to the device buffer.
        staging = torch.empty_like(device_desc, device="cpu", pin_memory=True)
        fill_fn(*fill_args, staging.data_ptr())
        device_desc.copy_(staging, non_blocking=True)

    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
        """Fill descriptor ``name`` for a 1-D tensor at ``ptr``."""
        self._fill_descriptor(
            name, self.fill_1d_tma_descriptor_inner,
            ptr, dim, block_dim, element_size,
        )

    def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
        """Fill descriptor ``name`` for a 2-D tensor at ``ptr``."""
        self._fill_descriptor(
            name, self.fill_2d_tma_descriptor_inner,
            ptr, dim1, dim0, block_dim1, block_dim0, element_size,
        )

    def get_tma_descriptor_kernel_param(self, name):
        """Return the object to pass to a kernel for descriptor ``name``."""
        if not HAS_TMA_DESC:
            assert self.cuda_descriptors[name] is not None
            return self.cuda_descriptors[name]
        assert self.descriptors[name] is not None
        return self.KernelParamWrapper(self.descriptors[name])
|
|
|
|
|
|
|
|
# Triton kernel: scatter selected column indices into a dense block mask.
# create_mask_from_indices_triton launches it with grid (B * H, N_rows, S). Each program
# handles one (fused batch*head, row, selection slot): it loads the column index stored in
# block_indices at that position and stores 1 (cast to block_mask's element type) at
# block_mask[z, h, row, column]. The loaded column index is not bounds-checked here.
@triton.jit
def create_mask_from_indices_kernel(
    block_indices,
    block_mask,
    stride_bz, stride_bh, stride_bm, stride_bs,
    stride_mz, stride_mh, stride_mm, stride_mn,
    H,
):
    # Program ids: fused batch*head, mask row, and slot within that row's selection.
    i_zh, i_m, i_s = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_z, i_h = i_zh // H, i_zh % H

    # Offsets are computed in int64 (presumably to avoid 32-bit overflow on large tensors).
    off_b = i_z.to(tl.int64) * stride_bz + i_h.to(tl.int64) * stride_bh + i_m.to(tl.int64) * stride_bm + i_s.to(tl.int64) * stride_bs

    # Selected column for this (z, h, row, slot).
    b_i = tl.load(block_indices + off_b)

    off_m = i_z.to(tl.int64) * stride_mz + i_h.to(tl.int64) * stride_mh + i_m.to(tl.int64) * stride_mm + b_i.to(tl.int64) * stride_mn

    b_m = 1
    tl.store(block_mask + off_m, b_m.to(block_mask.dtype.element_ty))
|
|
|
|
|
def create_mask_from_indices_triton(
    block_indices,
    N_cols
):
    """Build a dense boolean mask ``[B, H, N_rows, N_cols]`` from selected column indices.

    Args:
        block_indices: ``[B, H, N_rows, S]`` integer tensor of selected columns per row.
        N_cols: width of the output mask.

    Returns:
        Boolean mask that is True at every (row, selected column); computed by
        ``create_mask_from_indices_kernel`` with one program per selection slot.
    """
    batch, heads, n_rows, n_slots = block_indices.shape
    mask = torch.zeros((batch, heads, n_rows, N_cols), dtype=torch.bool, device=block_indices.device)
    launch_grid = (batch * heads, n_rows, n_slots)
    create_mask_from_indices_kernel[launch_grid](
        block_indices,
        mask,
        *block_indices.stride(),
        *mask.stride(),
        heads,
    )
    return mask
|
|
|
|
|
|
|
|
def create_mask_from_indices_varlen(block_indices, N_cols_mask):
    """Build a dense boolean mask from selected column indices, ignoring out-of-range entries.

    Args:
        block_indices: ``[B, H, M, S]`` integer tensor of selected columns per row. Entries
            ``>= N_cols_mask`` act as padding and are skipped.
        N_cols_mask: width of the output mask.

    Returns:
        ``[B, H, M, N_cols_mask]`` bool tensor, True at each (row, selected column).
    """
    batch, heads, rows, _ = block_indices.shape
    mask = torch.zeros((batch, heads, rows, N_cols_mask), dtype=torch.bool, device=block_indices.device)

    keep = block_indices < N_cols_mask
    # Coordinates of every kept slot; the column comes from the index value itself.
    b_ids, h_ids, m_ids, _ = keep.nonzero(as_tuple=True)
    mask[b_ids, h_ids, m_ids, block_indices[keep]] = True
    return mask
|
|
|
|
|
|
|
|
def create_indices_k_from_indices_q_varlen(
    block_indices,
    N_cols_mask
):
    """Invert a query-block -> key-block selection into key-block -> query-block lists.

    Args:
        block_indices: ``[B, H, M, S]`` selected key-block ids per query block; entries
            ``>= N_cols_mask`` are ignored.
        N_cols_mask: number of key blocks.

    Returns:
        ``(indices_k, lens_k)``. ``indices_k`` is ``[B, H, N_cols_mask, M]``: for each key
        block, the ids of the query blocks that selected it in ascending order, padded with
        ``M``. ``lens_k`` is ``[B, H, N_cols_mask]``, the number of real entries per row.
    """
    mask_qk = create_mask_from_indices_varlen(block_indices, N_cols_mask)
    num_q_blocks = mask_qk.shape[2]
    mask_kq = mask_qk.transpose(2, 3)

    q_ids = torch.arange(num_q_blocks, device=block_indices.device).expand_as(mask_kq)
    # Unselected slots get the sentinel num_q_blocks so that sorting pushes them to the end.
    padded = torch.where(mask_kq, q_ids, num_q_blocks)
    indices_k = padded.sort(dim=-1).values
    lens_k = (indices_k < num_q_blocks).sum(dim=-1)
    return indices_k, lens_k
|
|
|
|
|
|
|
|
|
|
|
def mean_pooling_compression(
    x: torch.Tensor,
    block_size: int
) -> torch.Tensor:
    """Average non-overlapping groups of ``block_size`` consecutive rows along dim 2.

    Args:
        x: tensor whose first three dims are ``(B, H, S)``; any trailing dims are
            flattened into one feature dim (for the usual ``[B, H, S, D]`` input this is D).
        block_size: number of rows pooled into one block.

    Returns:
        ``[B, H, ceil(S / block_size), F]`` block means. A trailing partial block is
        averaged over its real rows only, not diluted by padding.
    """
    B, H, S = x.shape[:3]
    feat = math.prod(x.shape[3:])
    num_block = (S + block_size - 1) // block_size
    n_full = S // block_size

    # reshape (not view) so non-contiguous inputs such as transposed q/k also work.
    x = x.reshape(B, H, S, feat)
    pooled = x[:, :, :n_full * block_size].reshape(B, H, n_full, block_size, feat).mean(dim=3)
    if n_full == num_block:
        return pooled

    # Mean of the leftover S % block_size rows.
    tail = x[:, :, n_full * block_size:].mean(dim=2, keepdim=True)
    return torch.cat([pooled, tail], dim=2)
|
|
|
|
|
|
|
|
def cal_score(q, k):
    """Return raw (unscaled) dot-product logits ``q @ k^T`` over the last two dims."""
    return q @ k.transpose(-2, -1)
|
|
|
|
|
def cal_score_triton(q, k):
    """Compute ``q @ k^T`` logits with the Triton ``_attn_fwd_gating`` kernel.

    Args:
        q: ``[B, H, s_q, D]``.
        k: tensor whose dim 2 is ``s_k``.

    Returns:
        ``[B, H, s_q, s_k]`` tensor with q's dtype/device.
    """
    batch, heads, len_q, head_dim = q.shape
    len_k = k.shape[2]
    score = torch.empty(batch, heads, len_q, len_k, device=q.device, dtype=q.dtype)

    if os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1':
        kernel_config = {}
    else:
        kernel_config = configs_gating_preset['default']

    def launch_grid(meta):
        return (triton.cdiv(len_q, meta["BLOCK_M"]), batch * heads, 1)

    _attn_fwd_gating[launch_grid](
        q, k, score,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        score.stride(0), score.stride(1), score.stride(2), score.stride(3),
        heads, len_q, len_k,
        HEAD_DIM=head_dim,
        **kernel_config
    )
    return score
|
|
|
|
|
|
|
|
def get_select_indices_topk(q, k, sparsity):
    """Score compressed q against compressed k and keep a fixed ``1 - sparsity`` fraction of key blocks.

    Returns the ``(block_indices, block_indices_lens)`` pair from
    ``get_select_indices_topk_from_score``.
    """
    return get_select_indices_topk_from_score(cal_score(q, k), sparsity)
|
|
|
|
|
|
|
|
def get_select_indices_topk_from_score(score, sparsity):
    """Keep the top ``(1 - sparsity)`` fraction of key blocks for every query row.

    Args:
        score: logits whose last dim indexes key blocks; first three dims ``(B, H, Sq)``.
        sparsity: fraction of key blocks to drop.

    Returns:
        ``(block_indices, block_indices_lens)``. ``block_indices`` holds the selected
        key-block ids in descending score order. ``block_indices_lens`` is an int32
        ``[B, H, Sq]`` tensor filled with the selection count.
    """
    # The epsilon stops float error in (1 - sparsity) from truncating an exact integer:
    # (1 - 0.9) * 10 == 0.9999999999999998 would otherwise select 0 blocks instead of 1.
    num_selected = int((1 - sparsity) * score.shape[-1] + 1e-6)
    block_indices = torch.topk(score, num_selected)[1]

    block_indices_lens = torch.full(
        (block_indices.shape[0], block_indices.shape[1], block_indices.shape[2]),
        num_selected,
        dtype=torch.int32,
        device=block_indices.device
    )

    return block_indices, block_indices_lens
|
|
|
|
|
|
|
|
def get_select_indices_cdf(q, k, cdf_threshold):
    """Select key blocks by cumulative softmax mass, using softmax scale ``1 / sqrt(head_dim)``.

    Returns the ``(block_indices, block_indices_lens)`` pair from
    ``get_select_indices_cdf_from_score``.
    """
    scale = 1 / q.shape[-1] ** 0.5
    return get_select_indices_cdf_from_score(cal_score(q, k), cdf_threshold, scale)
|
|
|
|
|
|
|
|
def get_select_indices_cdf_from_score(score, cdf_threshold, sm_scale):
    """Select key blocks whose cumulative softmax weight stays within ``cdf_threshold``.

    Args:
        score: block-level logits, unpacked as ``[B, H, Sq, Sk]``.
        cdf_threshold: scalar cumulative-probability threshold, shared by every head.
        sm_scale: multiplier applied to ``score`` before the softmax.

    Returns:
        ``(indices, num_selected)``.
        ``indices``: ``[B, H, Sq, Sk]`` key-block ids sorted by decreasing weight.
        ``num_selected``: ``[B, H, Sq]``, the number of leading sorted entries whose
        cumulative weight is ``<= cdf_threshold``.

    NOTE(review): the count can be 0 when the top block alone exceeds the threshold.
    Confirm the kernels handle a query block with no selected key blocks.
    """
    probs = (score * sm_scale).softmax(dim=-1)
    batch, heads, n_q, _ = probs.shape

    ordered = probs.sort(dim=-1, descending=True)
    cumulative = ordered.values.cumsum(dim=-1)

    thresholds = torch.full((heads,), cdf_threshold, device=probs.device)
    thresholds = thresholds.view(1, heads, 1, 1).expand(batch, -1, n_q, -1)
    counts = torch.searchsorted(cumulative, thresholds, right=True)
    return ordered.indices, counts.squeeze(-1)
|
|
|
|
|
|
|
|
def get_select_indices_cdf_topk(q, k, sparsity, cdf_threshold):
    """Use CDF-based key-block selection with a top-k minimum and softmax scale ``1 / sqrt(head_dim)``.

    Returns the ``(block_indices, block_indices_lens)`` pair from
    ``get_select_indices_cdf_topk_from_score``.
    """
    scale = 1 / q.shape[-1] ** 0.5
    return get_select_indices_cdf_topk_from_score(cal_score(q, k), sparsity, cdf_threshold, scale)
|
|
|
|
|
|
|
|
def get_select_indices_cdf_topk_from_score(score, sparsity, cdf_threshold, sm_scale):
    """CDF-based key-block selection with a top-k lower bound on the count.

    Args:
        score: block-level logits, unpacked as ``[B, H, Sq, Sk]``.
        sparsity: the per-row count is at least ``int((1 - sparsity) * Sk)``.
        cdf_threshold: scalar cumulative-probability threshold, shared by every head.
        sm_scale: multiplier applied to ``score`` before the softmax.

    Returns:
        ``(indices, num_selected)``.
        ``indices``: ``[B, H, Sq, Sk]`` key-block ids sorted by decreasing weight.
        ``num_selected``: ``[B, H, Sq]``, the larger of the CDF count (entries with
        cumulative weight ``<= cdf_threshold``) and the top-k count.
    """
    weights = torch.softmax(score * sm_scale, dim=-1)

    B, H, Sq, Sk = weights.shape
    cdf_threshold = torch.full((H,), cdf_threshold, device=weights.device).view(1, H, 1, 1).expand(B, -1, Sq, -1)
    weights_sorted = torch.sort(weights, dim=-1, descending=True)
    cdf = torch.cumsum(weights_sorted.values, dim=-1)
    num_selected = torch.searchsorted(cdf, cdf_threshold, right=True)

    # Enforce the top-k floor. The epsilon stops float error in (1 - sparsity) from
    # truncating an exact integer, e.g. (1 - 0.9) * 10 == 0.9999999999999998.
    num_selected_topk = int((1 - sparsity) * score.shape[-1] + 1e-6)
    num_selected[num_selected < num_selected_topk] = num_selected_topk

    return weights_sorted.indices, num_selected.squeeze(-1)
|
|
|
|
|
def get_select_indices(q, k, sparsity, cdf_threshold):
    """Select key blocks for every query block from compressed q / k.

    Dispatches on which criteria are given:
      * ``sparsity`` only: keep a fixed ``1 - sparsity`` fraction (top-k).
      * ``cdf_threshold`` only: keep the sorted prefix whose cumulative softmax weight
        is ``<= cdf_threshold``.
      * both: CDF selection with the top-k count as a lower bound.

    Returns:
        ``(block_indices, block_indices_lens)``.

    Raises:
        ValueError: if neither ``sparsity`` nor ``cdf_threshold`` is given.
    """
    if sparsity is None and cdf_threshold is None:
        raise ValueError("get_select_indices requires at least one of `sparsity` or `cdf_threshold`.")
    if cdf_threshold is None:
        return get_select_indices_topk(q, k, sparsity)
    if sparsity is None:
        return get_select_indices_cdf(q, k, cdf_threshold)
    return get_select_indices_cdf_topk(q, k, sparsity, cdf_threshold)
|
|
|
|
|
def get_select_indices_from_score(score, sparsity, cdf_threshold, sm_scale=None):
    """Select key blocks from precomputed block-level logits.

    Args:
        score: logits whose last dim indexes key blocks.
        sparsity: top-k drop fraction, or None.
        cdf_threshold: cumulative softmax threshold, or None.
        sm_scale: softmax scale applied to ``score``. Required whenever ``cdf_threshold``
            is set, because the CDF selectors softmax the scores. The top-k path ignores it.

    Returns:
        ``(block_indices, block_indices_lens)``.

    Raises:
        ValueError: if neither criterion is given, or if ``cdf_threshold`` is set without
            ``sm_scale``.
    """
    if sparsity is None and cdf_threshold is None:
        raise ValueError("get_select_indices_from_score requires at least one of `sparsity` or `cdf_threshold`.")
    if cdf_threshold is None:
        return get_select_indices_topk_from_score(score, sparsity)
    # Previously these branches omitted sm_scale and always raised TypeError.
    if sm_scale is None:
        raise ValueError("`sm_scale` is required when `cdf_threshold` is set.")
    if sparsity is None:
        return get_select_indices_cdf_from_score(score, cdf_threshold, sm_scale)
    return get_select_indices_cdf_topk_from_score(score, sparsity, cdf_threshold, sm_scale)
|
|
|
|
|
def attn_fwd_bsa_varlen_triton(
    q,
    k,
    v,
    sm_scale,
    block_indices,
    block_indices_lens,
    chunk_size_q,
    chunk_size_k,
    sparsity
):
    """Launch the block-sparse attention forward kernel.

    Args:
        q, k, v: attention inputs; q is unpacked as ``[B, H, S, D]``.
        sm_scale: softmax scale forwarded to the kernel.
        block_indices, block_indices_lens: selected key blocks per query block and the
            number of valid entries per row.
        chunk_size_q, chunk_size_k: query / key block sizes (``BLOCK_M`` / ``BLOCK_N_LG``).
        sparsity: forwarded to the kernel as ``SPARSITY``.

    Returns:
        ``(o, lse)``. ``o`` has the same shape as ``q``. ``lse`` is the float32
        ``[B, H, S]`` per-row statistic written by the kernel, multiplied by ln(2)
        (presumably converting a base-2 log-sum-exp to natural log).
    """
    batch, heads, seq_len, head_dim = q.shape

    use_autotune = os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1'
    preset_key = 'BLOCK_N_LG=64' if chunk_size_k == 64 else 'default'
    # Key blocks larger than 128 use the generic kernel; smaller ones use the aligned variant.
    if chunk_size_k > 128:
        kernel, presets = _attn_fwd_bsa_varlen, configs_fwd_bsa_varlen_preset
    else:
        kernel, presets = _attn_fwd_bsa_varlen_align, configs_fwd_bsa_varlen_align_preset
    kernel_config = {} if use_autotune else presets[preset_key]

    out = torch.empty_like(q)
    row_stats = torch.empty((batch, heads, seq_len), device=q.device, dtype=torch.float32)
    block_indices = block_indices.contiguous()
    block_indices_lens = block_indices_lens.contiguous()

    def launch_grid(meta):
        return (triton.cdiv(seq_len, meta["BLOCK_M"]), batch * heads, 1)

    kernel[launch_grid](
        q, k, v, sm_scale, row_stats, out,
        block_indices,
        block_indices_lens,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        block_indices.stride(0), block_indices.stride(1), block_indices.stride(2), block_indices.stride(3),
        block_indices_lens.stride(0), block_indices_lens.stride(1), block_indices_lens.stride(2),
        heads, seq_len,
        head_dim,
        BLOCK_M=chunk_size_q,
        BLOCK_N_LG=chunk_size_k,
        SPARSITY=sparsity,
        **kernel_config
    )

    LN2 = 0.6931471824645996
    return out, row_stats * LN2
|
|
|
|
|
def attn_bwd_bsa_varlen_triton(
    do,
    q,
    k,
    v,
    o,
    dq,
    dk,
    dv,
    sm_scale,
    M,
    block_indices,
    block_indices_lens,
    chunk_size_q,
    chunk_size_k,
    sparsity
):
    """Run the block-sparse attention backward kernels, writing into ``dq``, ``dk``, ``dv``.

    Args:
        do: gradient w.r.t. the attention output (made contiguous here).
        q, k, v: forward inputs; q is unpacked as ``[B, H, S, D]`` and k's dim -2 is the
            key length.
        o: forward output.
        dq, dk, dv: pre-allocated gradient buffers, filled in place by the kernels.
        sm_scale: softmax scale used in the forward pass.
        M: LSE as returned by ``attn_fwd_bsa_varlen_triton`` (multiplied by 1/ln(2) here).
        block_indices: selected key-block ids per query block, ``[B, H, Nq_blocks, W]``.
        block_indices_lens: number of valid leading entries in each row of
            ``block_indices``. Entries past it are ignored.
        chunk_size_q, chunk_size_k: query / key block sizes.
        sparsity: forwarded to the kernels as ``SPARSITY``.
    """
    RCP_LN2 = 1.4426950408889634
    M = M * RCP_LN2

    do = do.contiguous()

    BATCH, N_HEAD, N_CTX, HEAD_DIM = q.shape
    N_CTX_KV = k.shape[-2]

    # Fold sm_scale / ln(2) into K once, before the kernels run.
    arg_k = k * (sm_scale * RCP_LN2)

    PRE_BLOCK = min(chunk_size_q, chunk_size_k, 128)
    assert N_CTX % PRE_BLOCK == 0
    pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
    delta = torch.empty_like(M)
    _attn_bwd_preprocess[pre_grid](
        o, do,
        delta,
        N_CTX,
        BLOCK_M=PRE_BLOCK,
        HEAD_DIM=HEAD_DIM
    )

    # Only the first block_indices_lens entries of each row are selected; the CDF
    # selectors return the full argsort. Push the unused tail out of range so that the
    # key->query lists used for dK/dV cover exactly the pairs used by forward and dQ.
    num_blocks_k = N_CTX_KV // chunk_size_k
    slot = torch.arange(block_indices.shape[-1], device=block_indices.device)
    unused = slot >= block_indices_lens.unsqueeze(-1)
    block_indices_for_k = block_indices.masked_fill(unused, num_blocks_k)

    block_indices_k, block_indices_k_lens = create_indices_k_from_indices_q_varlen(
        block_indices=block_indices_for_k,
        N_cols_mask=num_blocks_k
    )

    block_indices = block_indices.contiguous()
    block_indices_lens = block_indices_lens.contiguous()
    block_indices_k = block_indices_k.contiguous()
    block_indices_k_lens = block_indices_k_lens.contiguous()

    use_autotune = os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1'
    config_key = 'BLOCK_N_DQ_LG=64' if chunk_size_k == 64 else 'default'

    # dK / dV. The K strides passed are those of arg_k, the tensor actually handed to the
    # kernel; they differ from k's when k is not dense.
    kernel_config = {} if use_autotune else configs_bwd_dkdv_bsa_varlen_preset[config_key]
    grid_dkdv = lambda args: (triton.cdiv(arg_k.shape[2], args["BLOCK_N"]), 1, arg_k.shape[0] * arg_k.shape[1])
    _attn_bwd_dkdv_bsa_varlen_wrapper[grid_dkdv](
        q, arg_k, v, sm_scale,
        do,
        dk, dv,
        M,
        delta,
        block_indices_k,
        block_indices_k_lens,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        arg_k.stride(0), arg_k.stride(1), arg_k.stride(2), arg_k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
        dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
        do.stride(0), do.stride(1), do.stride(2), do.stride(3),
        M.stride(0), M.stride(1), M.stride(2),
        delta.stride(0), delta.stride(1), delta.stride(2),
        block_indices_k.stride(0), block_indices_k.stride(1), block_indices_k.stride(2), block_indices_k.stride(3),
        block_indices_k_lens.stride(0), block_indices_k_lens.stride(1), block_indices_k_lens.stride(2),
        N_HEAD, N_CTX,
        BLOCK_M=chunk_size_q,
        BLOCK_N_DQ_LG=chunk_size_k,
        HEAD_DIM=HEAD_DIM,
        SPARSITY=sparsity,
        **kernel_config
    )

    # dQ: key blocks larger than 128 use the generic kernel, smaller ones the aligned variant.
    if chunk_size_k > 128:
        bwd_dq_func = _attn_bwd_dq_bsa_varlen_wrapper
        kernel_config = {} if use_autotune else configs_bwd_dq_bsa_varlen_preset[config_key]
    else:
        bwd_dq_func = _attn_bwd_dq_bsa_varlen_align_wrapper
        kernel_config = {} if use_autotune else configs_bwd_dq_bsa_varlen_align_preset[config_key]

    grid_dq = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), 1, q.shape[0] * q.shape[1])
    bwd_dq_func[grid_dq](
        q, arg_k, v,
        do,
        dq,
        M,
        delta,
        block_indices,
        block_indices_lens,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        arg_k.stride(0), arg_k.stride(1), arg_k.stride(2), arg_k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
        do.stride(0), do.stride(1), do.stride(2), do.stride(3),
        M.stride(0), M.stride(1), M.stride(2),
        delta.stride(0), delta.stride(1), delta.stride(2),
        block_indices.stride(0), block_indices.stride(1), block_indices.stride(2), block_indices.stride(3),
        block_indices_lens.stride(0), block_indices_lens.stride(1), block_indices_lens.stride(2),
        N_HEAD, N_CTX,
        BLOCK_M=chunk_size_q,
        BLOCK_N_DQ_LG=chunk_size_k,
        HEAD_DIM=HEAD_DIM,
        SPARSITY=sparsity,
        **kernel_config
    )
|
|
|
|
|
|
|
|
def make_block_indices_varlen_cp_list(block_indices, cp_size, num_blocks_k_full):
    """Split global key-block selections into per-context-parallel-rank local selections.

    Key blocks are partitioned evenly: rank ``r`` owns global ids
    ``[r * per_rank, (r + 1) * per_rank)`` with ``per_rank = num_blocks_k_full // cp_size``.

    Args:
        block_indices: ``[B, H, num_blocks_q_per_cp_rank, W]`` selected global key-block ids.
        cp_size: number of context-parallel ranks.
        num_blocks_k_full: total number of key blocks across all ranks.

    Returns:
        A list with one ``[block_indices, block_indices_lens]`` pair per rank. Each pair holds:
          - ids shifted to start at zero for that rank and sorted ascending, so that valid
            local ids come first;
          - ``block_indices_lens``: the count of entries ``< per_rank``, i.e. the valid
            leading entries.
    """
    per_rank = num_blocks_k_full // cp_size

    def _localize(rank):
        shifted = block_indices - rank * per_rank
        # Blocks owned by earlier ranks go negative; map them to the out-of-range id per_rank.
        shifted = shifted.masked_fill(shifted < 0, per_rank)
        shifted = torch.sort(shifted, dim=-1).values
        return [shifted, (shifted < per_rank).sum(dim=-1)]

    return [_localize(rank) for rank in range(cp_size)]
|
|
|
|
|
|
|
|
def flash_attn_fwd_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge the running LSE with one step's LSE in place (context-parallel attention).

    Computes ``log(exp(a) + exp(b))`` stably as ``max + log1p(exp(min - max))`` and writes
    it into ``softmax_lse``. When both inputs are -inf, the NaN difference is treated as 0,
    so the result stays -inf.
    """
    hi = torch.max(softmax_lse, softmax_lse_per_step)
    lo = torch.min(softmax_lse, softmax_lse_per_step)
    gap = (lo - hi).nan_to_num(nan=0.)
    softmax_lse.copy_(hi + torch.log1p(torch.exp(gap)))
|
|
|
|
|
|
|
|
def flash_attn_fwd_out_correction_init(
    out_init_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_init_step: torch.Tensor,
):
    """Rescale the first step's partial output by ``exp(lse_step - lse)`` (context-parallel attention).

    ``softmax_lse`` tensors lack the output's last (feature) dim, which is broadcast over.
    The result is cast back to ``out_init_step``'s dtype.
    """
    weight = (softmax_lse_init_step - softmax_lse).exp()[..., None]
    return (out_init_step * weight).to(out_init_step.dtype)
|
|
|
|
|
|
|
|
|
|
|
def flash_attn_fwd_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Accumulate one step's partial output into ``out`` in place (context-parallel attention).

    Adds ``out_per_step * exp(lse_step - lse)`` to ``out``, broadcasting the LSE weight over
    the last (feature) dim. Returns None.
    """
    weight = (softmax_lse_per_step - softmax_lse).exp()[..., None]
    out.add_(out_per_step * weight)
|
|
|
|
|
|
|
|
def topk_sort(score, num_chunks_selected):
    """Return the indices of the ``num_chunks_selected`` largest scores along the last dim, in ascending index order."""
    top = torch.topk(score, num_chunks_selected).indices
    return top.sort(dim=-1).values
|
|
|
|
|
class _attention_bsa(torch.autograd.Function):
    """Autograd function for block-sparse attention with data-dependent block selection.

    Forward mean-pools q and k into per-block summaries and selects key blocks for each
    query block from those summaries. It then runs the Triton varlen block-sparse forward
    kernel. Backward reuses the saved selection and LSE.
    """

    @staticmethod
    def forward(ctx, q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, sm_scale, use_tma=False):
        head_dim = k.shape[-1]
        assert q.shape[-1] == head_dim and head_dim == v.shape[-1]
        assert head_dim in {16, 32, 64, 128, 256}

        # Block-level summaries drive the sparse selection.
        q_blocks = mean_pooling_compression(q, chunk_size_q)
        k_blocks = mean_pooling_compression(k, chunk_size_k)
        block_indices, block_indices_lens = get_select_indices(q_blocks, k_blocks, sparsity, cdf_threshold)

        o, lse = attn_fwd_bsa_varlen_triton(
            q, k, v,
            sm_scale, block_indices, block_indices_lens,
            chunk_size_q, chunk_size_k,
            sparsity
        )

        ctx.save_for_backward(q, k, v, o, lse, block_indices, block_indices_lens)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = head_dim
        ctx.chunk_size_q = chunk_size_q
        ctx.chunk_size_k = chunk_size_k
        ctx.use_tma = use_tma
        ctx.sparsity = sparsity
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse, block_indices, block_indices_lens = ctx.saved_tensors
        dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))

        attn_bwd_bsa_varlen_triton(
            do, q, k, v, o,
            dq, dk, dv,
            ctx.sm_scale, lse,
            block_indices, block_indices_lens,
            ctx.chunk_size_q, ctx.chunk_size_k,
            ctx.sparsity
        )

        # One gradient slot per forward input; only q, k, v receive gradients.
        return dq, dk, dv, None, None, None, None, None, None
|
|
|
|
|
# Functional entry point:
#   flash_attn_bsa(q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, sm_scale, use_tma=False)
# Arguments are positional, as required by autograd.Function.apply. Returns the attention output.
flash_attn_bsa = _attention_bsa.apply
|
|
|
|
|
def rearrange_THW_to_3d_block(x, Nt, Nh, Nw, t, h, w, D):
    """Reorder tokens from raster (T, H, W) order into contiguous 3-D blocks.

    ``x`` is ``[B, H, Nt*t * Nh*h * Nw*w, D]`` in T-major raster order. The output lists the
    ``Nt*Nh*Nw`` blocks one after another, each holding its ``t*h*w`` tokens. The ``D``
    argument is not used; the feature size is read from ``x``.
    """
    batch, heads, _, dim = x.shape
    grid = x.view(batch, heads, Nt, t, Nh, h, Nw, w, dim)
    # Move the three block-grid axes ahead of the three intra-block axes.
    blocked = grid.permute(0, 1, 2, 4, 6, 3, 5, 7, 8).contiguous()
    return blocked.view(batch, heads, Nt * Nh * Nw * t * h * w, dim)
|
|
|
|
|
def rearrange_3d_block_to_THW(x, Nt, Nh, Nw, t, h, w, D):
    """Inverse of ``rearrange_THW_to_3d_block``: restore raster (T, H, W) token order.

    The ``D`` argument is not used; the feature size is read from ``x``.
    """
    batch, heads, _, dim = x.shape
    blocked = x.view(batch, heads, Nt, Nh, Nw, t, h, w, dim)
    # Interleave each block-grid axis with its intra-block axis: (Nt, t, Nh, h, Nw, w).
    raster = blocked.permute(0, 1, 2, 5, 3, 6, 4, 7, 8).contiguous()
    return raster.view(batch, heads, Nt * t * Nh * h * Nw * w, dim)
|
|
|
|
|
def flash_attn_bsa_3d(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    latent_shape_q,
    latent_shape_k,
    sparsity=0.875,
    cdf_threshold=None,
    chunk_3d_shape_q=(4, 4, 8),
    chunk_3d_shape_k=(4, 4, 8),
) -> torch.Tensor:
    """Block-sparse attention over 3-D (T, H, W) latents, using 3-D token blocks.

    Args:
        q: ``[B, H, Sq, D]`` with ``Sq == Tq * Hq * Wq`` in raster order.
        k, v: ``[B, H, Sk, D]`` with ``Sk == Tk * Hk * Wk``.
        latent_shape_q, latent_shape_k: ``(T, H, W)`` of the query / key latents.
        sparsity, cdf_threshold: block-selection criteria forwarded to ``flash_attn_bsa``.
        chunk_3d_shape_q, chunk_3d_shape_k: ``(t, h, w)`` block shapes. Each must divide the
            matching latent shape. Defaults are tuples, so they are never shared mutable state.

    Returns:
        Attention output in q's raster token order, shape ``[B, H, Sq, D]``.
    """
    _, _, Sq, head_dim_q = q.shape
    _, _, Sk, head_dim_k = k.shape

    assert head_dim_q == head_dim_k
    head_dim = head_dim_q

    Tq, Hq, Wq = latent_shape_q
    Tk, Hk, Wk = latent_shape_k

    assert Tq * Hq * Wq == Sq
    assert Tk * Hk * Wk == Sk

    tq, hq, wq = chunk_3d_shape_q
    tk, hk, wk = chunk_3d_shape_k

    assert Tq % tq == 0 and Hq % hq == 0 and Wq % wq == 0
    assert Tk % tk == 0 and Hk % hk == 0 and Wk % wk == 0

    # Number of blocks along each latent axis.
    Ntq, Nhq, Nwq = Tq // tq, Hq // hq, Wq // wq
    Ntk, Nhk, Nwk = Tk // tk, Hk // hk, Wk // wk

    # Make each 3-D block a contiguous run of tokens so it maps to one 1-D chunk.
    q = rearrange_THW_to_3d_block(q, Ntq, Nhq, Nwq, tq, hq, wq, q.shape[-1])
    k = rearrange_THW_to_3d_block(k, Ntk, Nhk, Nwk, tk, hk, wk, k.shape[-1])
    v = rearrange_THW_to_3d_block(v, Ntk, Nhk, Nwk, tk, hk, wk, v.shape[-1])

    chunk_size_q = tq * hq * wq
    chunk_size_k = tk * hk * wk

    output = flash_attn_bsa(q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, 1 / head_dim**0.5)

    # Back to raster (T, H, W) order of the queries.
    output = rearrange_3d_block_to_THW(output, Ntq, Nhq, Nwq, tq, hq, wq, output.shape[-1])
    return output
|
|
|