# vidfom's picture
# Upload folder using huggingface_hub
# 31112ad verified
import os
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import math
from .common import _attn_fwd_gating, _attn_bwd_preprocess, configs_gating_preset
from .flash_attn_bsa_varlen_mask import (
_attn_fwd_bsa_varlen, _attn_fwd_bsa_varlen_align, _attn_bwd_dkdv_bsa_varlen_wrapper, _attn_bwd_dq_bsa_varlen_wrapper, _attn_bwd_dq_bsa_varlen_align_wrapper,
configs_fwd_bsa_varlen_preset, configs_fwd_bsa_varlen_align_preset, configs_bwd_dkdv_bsa_varlen_preset, configs_bwd_dq_bsa_varlen_preset, configs_bwd_dq_bsa_varlen_align_preset
)
from .communicate import p2p_communicate
# Set TorchDynamo's compile-cache size limit to 32. This is a process-wide
# setting applied as a side effect of importing this module.
torch._dynamo.config.cache_size_limit = 32
def is_cuda():
    """Return True when Triton's active driver targets the CUDA backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda"
def supports_tma():
    """Return True on a CUDA backend whose device compute capability major is >= 9."""
    if not is_cuda():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
# True when this Triton build exposes ``tl.nv_tma_desc_type``, i.e. the
# experimental grid-constant TMA descriptor path is available.
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
# Import-time notice reporting which descriptor mode will be used.
print(
    "TMA benchmarks will be running with experimental grid constant TMA descriptor."
    if HAS_TMA_DESC
    else "TMA benchmarks will be running without grid constant TMA descriptor."
)
# TmaAutoTuneHelper used in htyu's PR #5622
class TmaAutoTuneHelper:
    """Own and fill TMA descriptor buffers for autotuned Triton kernels.

    Storage depends on the module-level ``HAS_TMA_DESC`` flag:

    * ``HAS_TMA_DESC`` true: descriptors are CPU ``int8`` tensors kept in
      ``self.descriptors`` and filled in place. Kernels receive them wrapped in
      :class:`KernelParamWrapper`.
    * otherwise: descriptors are CUDA ``int8`` tensors kept in
      ``self.cuda_descriptors``. Each fill writes into a pinned CPU staging
      buffer, which is then copied to the device with ``non_blocking=True``.

    Call :meth:`init_tma_descriptor` outside the grid lambda and the
    ``fill_*`` methods inside it.
    """

    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
    class KernelParamWrapper:
        """Expose a descriptor tensor through ``tma_desc_cpu_ptr()``."""

        def __init__(self, desc):
            self.desc = desc

        def tma_desc_cpu_ptr(self):
            """Return the address of the wrapped descriptor tensor's data."""
            return self.desc.data_ptr()

    # Number of bytes allocated per descriptor buffer.
    TMA_SIZE = 128

    def __init__(self):
        driver_utils = triton.runtime.driver.active.utils
        self.fill_1d_tma_descriptor_inner = driver_utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver_utils.fill_2d_tma_descriptor
        if HAS_TMA_DESC:
            self.descriptors = {}
        else:
            self.cuda_descriptors = {}

    # Call this method outside of the lambda function for grid size
    def init_tma_descriptor(self, name):
        """Allocate an uninitialized descriptor buffer registered under ``name``."""
        if HAS_TMA_DESC:
            store, device = self.descriptors, "cpu"
        else:
            store, device = self.cuda_descriptors, "cuda"
        store[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device=device, dtype=torch.int8)

    def _write_descriptor(self, name, fill_inner, *fill_args):
        """Call ``fill_inner(*fill_args, dst_ptr)`` so it writes descriptor ``name``."""
        if HAS_TMA_DESC:
            host_desc = self.descriptors[name]
            assert host_desc.data_ptr() % 64 == 0
            fill_inner(*fill_args, host_desc.data_ptr())
            return
        device_desc = self.cuda_descriptors[name]
        staging = torch.empty_like(device_desc, device="cpu", pin_memory=True)
        fill_inner(*fill_args, staging.data_ptr())
        device_desc.copy_(staging, non_blocking=True)

    # Call this method inside the lambda function for grid size
    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
        """Fill descriptor ``name`` via the driver's ``fill_1d_tma_descriptor``."""
        self._write_descriptor(
            name, self.fill_1d_tma_descriptor_inner,
            ptr, dim, block_dim, element_size,
        )

    # Call this method inside the lambda function for grid size
    def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
        """Fill descriptor ``name`` via the driver's ``fill_2d_tma_descriptor``."""
        self._write_descriptor(
            name, self.fill_2d_tma_descriptor_inner,
            ptr, dim1, dim0, block_dim1, block_dim0, element_size,
        )

    def get_tma_descriptor_kernel_param(self, name):
        """Return the kernel argument for descriptor ``name``.

        With ``HAS_TMA_DESC`` set, this is a :class:`KernelParamWrapper` around the
        CPU buffer. Otherwise it is the CUDA descriptor tensor itself.
        """
        if not HAS_TMA_DESC:
            assert self.cuda_descriptors[name] is not None
            return self.cuda_descriptors[name]
        assert self.descriptors[name] is not None
        return self.KernelParamWrapper(self.descriptors[name])
# Triton kernel: scatter per-row column indices into a dense mask.
# Each program reads one entry b_i = block_indices[z, h, m, s] and stores 1
# (cast to block_mask's element type) at block_mask[z, h, m, b_i]. Mask
# entries that no index points at are left untouched by this kernel.
# NOTE(review): the store has no bounds mask, so every index is assumed to be
# a valid column of block_mask. Confirm callers never pass padding/sentinel ids.
@triton.jit
def create_mask_from_indices_kernel(
    block_indices,  # pointer to the 4-D index tensor
    block_mask,  # pointer to the 4-D output mask
    stride_bz, stride_bh, stride_bm, stride_bs,  # element strides of block_indices
    stride_mz, stride_mh, stride_mm, stride_mn,  # element strides of block_mask
    H,  # number of heads; program_id(0) is decoded as z * H + h
):
    # Program axes: 0 -> flattened (batch, head), 1 -> row m, 2 -> slot s of the row's index list.
    i_zh, i_m, i_s = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_z, i_h = i_zh // H, i_zh % H
    # Offsets are computed in int64 so large tensors cannot overflow 32-bit index math.
    off_b = i_z.to(tl.int64) * stride_bz + i_h.to(tl.int64) * stride_bh + i_m.to(tl.int64) * stride_bm + i_s.to(tl.int64) * stride_bs
    b_i = tl.load(block_indices + off_b)
    # The loaded index selects which column of mask row (z, h, m) is set.
    off_m = i_z.to(tl.int64) * stride_mz + i_h.to(tl.int64) * stride_mh + i_m.to(tl.int64) * stride_mm + b_i.to(tl.int64) * stride_mn
    b_m = 1
    tl.store(block_mask + off_m, b_m.to(block_mask.dtype.element_ty))
def create_mask_from_indices_triton(
    block_indices,
    N_cols
):
    """Build a dense bool mask from per-row column indices with a Triton kernel.

    Args:
        block_indices: 4-D index tensor ``[B, H, N_rows, S]``. Every entry names
            a column of the output mask to set (no bounds checking is done).
        N_cols: number of columns of the output mask.

    Returns:
        ``torch.bool`` tensor ``[B, H, N_rows, N_cols]``. It is ``True`` at every
        indexed position and ``False`` elsewhere.
    """
    batch, heads, n_rows, n_slots = block_indices.shape
    block_mask = torch.zeros(
        (batch, heads, n_rows, N_cols), dtype=torch.bool, device=block_indices.device
    )
    # One program per (batch*head, row, slot) triple.
    launch_grid = (batch * heads, n_rows, n_slots)
    create_mask_from_indices_kernel[launch_grid](
        block_indices,
        block_mask,
        *block_indices.stride(),
        *block_mask.stride(),
        heads,
    )
    return block_mask
def create_mask_from_indices_varlen(block_indices, N_cols_mask):
    """Build a dense bool mask from per-row column indices, skipping invalid ids.

    Args:
        block_indices: 4-D index tensor ``[B, H, M, S]``. Entries ``>= N_cols_mask``
            are treated as padding and ignored.
        N_cols_mask: number of columns of the output mask.

    Returns:
        ``torch.bool`` tensor ``[B, H, M, N_cols_mask]``. It is ``True`` at
        ``[b, h, m, block_indices[b, h, m, s]]`` for every valid entry.
    """
    mask = torch.zeros(
        (*block_indices.shape[:3], N_cols_mask), dtype=torch.bool, device=block_indices.device
    )
    keep = block_indices < N_cols_mask
    # nonzero() and boolean indexing both walk `keep` in row-major order, so the
    # (b, h, m) coordinates line up with the kept column ids.
    b, h, m, _ = keep.nonzero(as_tuple=True)
    mask[b, h, m, block_indices[keep]] = True
    return mask
def create_indices_k_from_indices_q_varlen(
    block_indices,
    N_cols_mask  # width of the intermediate bool mask; it cannot be inferred from block_indices, which may contain invalid (>= N_cols_mask) entries
):
    """Transpose a query-block -> key-block selection into key-block -> query-block.

    Args:
        block_indices: ``[B, H, M, S]`` key-block ids selected by each query block.
            Ids ``>= N_cols_mask`` are ignored.
        N_cols_mask: number of key blocks.

    Returns:
        (block_indices_k, block_indices_k_lens):
        - ``block_indices_k``: ``[B, H, N_cols_mask, M]``. For each key block it
          lists the query blocks that selected it, in ascending order, padded
          with the value ``M``.
        - ``block_indices_k_lens``: ``[B, H, N_cols_mask]``, the number of
          non-padding entries per key block.
    """
    mask_qk = create_mask_from_indices_varlen(block_indices, N_cols_mask)
    n_q_blocks = mask_qk.shape[2]
    mask_kq = mask_qk.transpose(2, 3)
    q_ids = torch.arange(n_q_blocks, device=block_indices.device).expand_as(mask_kq)
    # Unselected slots get the sentinel n_q_blocks so an ascending sort moves them to the tail.
    padded = q_ids.masked_fill(~mask_kq, n_q_blocks)
    sorted_ids = padded.sort(dim=-1).values
    counts = (sorted_ids < n_q_blocks).sum(dim=-1)
    return sorted_ids, counts
def mean_pooling_compression(
    x: torch.Tensor,
    block_size: int
) -> torch.Tensor:
    """Average ``x`` over consecutive blocks of ``block_size`` positions along dim 2.

    Args:
        x: tensor of shape ``[B, H, S, D]``.
        block_size: number of positions averaged into one output row.

    Returns:
        Tensor ``[B, H, ceil(S / block_size), D]``. If ``S`` is not a multiple of
        ``block_size``, the last row is the mean of the ``S % block_size`` real
        positions of the tail block. The zero padding used for reshaping is not
        counted.
    """
    B, H, S = x.shape[:3]
    num_block = math.ceil(S / block_size)
    pad = num_block * block_size - S
    if pad:
        x = F.pad(x, (0, 0, 0, pad))
    x_cmp = x.view(B, H, num_block, block_size, -1).mean(dim=3)
    if pad:
        # mean() divided the tail block by block_size although only
        # (block_size - pad) of its rows are real; rescale to the true mean.
        x_cmp[:, :, -1] *= block_size / (block_size - pad)
    return x_cmp
def cal_score(q, k):
    """Return raw (unscaled) attention logits ``q @ k^T`` over the last two dims.

    ``q`` is ``[b, h, s_q, d]`` and ``k`` is ``[b, h, s_k, d]``; the result is
    ``[b, h, s_q, s_k]``.
    """
    return q @ k.transpose(-2, -1)
def cal_score_triton(q, k):
    """Compute a ``[B, H, s_q, s_k]`` score tensor with the ``_attn_fwd_gating`` kernel.

    Presumably this is the Triton counterpart of ``cal_score`` (raw ``q @ k^T``);
    confirm against the kernel. The output has ``q.dtype``. Kernel meta-parameters
    come from ``configs_gating_preset['default']``. Setting the environment
    variable ``TRITON_AUTOTUNE_ENBALE=1`` (name spelled as in the code) leaves
    them to the autotuner instead.
    """
    batch, heads, len_q, head_dim = q.shape
    len_k = k.shape[2]
    score = torch.empty(batch, heads, len_q, len_k, device=q.device, dtype=q.dtype)
    use_autotune = os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1'
    kernel_config = {} if use_autotune else configs_gating_preset['default']

    def grid(meta):
        return (triton.cdiv(len_q, meta["BLOCK_M"]), batch * heads, 1)

    _attn_fwd_gating[grid](
        q, k, score,
        *q.stride(),
        *k.stride(),
        *score.stride(),
        heads, len_q, len_k,
        HEAD_DIM=head_dim,
        **kernel_config
    )
    return score
def get_select_indices_topk(q, k, sparsity):
    """Select a fixed ``(1 - sparsity)`` fraction of key blocks per query block.

    Scores are raw ``q @ k^T`` logits. See ``get_select_indices_topk_from_score``
    for the selection rule and the returned ``(indices, lens)`` pair.
    """
    return get_select_indices_topk_from_score(cal_score(q, k), sparsity)
def get_select_indices_topk_from_score(score, sparsity):
    """Pick a fixed number of highest-scoring key blocks per query row.

    Args:
        score: ``[B, H, Sq, Sk]`` block scores.
        sparsity: fraction of key blocks to drop. ``floor((1 - sparsity) * Sk)``
            blocks are kept.

    Returns:
        (block_indices, block_indices_lens):
        - ``block_indices``: ``[B, H, Sq, k]``, the indices of the top-k scores in
          descending score order (``torch.topk`` default).
        - ``block_indices_lens``: ``int32`` ``[B, H, Sq]`` filled with ``k``.
    """
    # The epsilon absorbs float error in (1 - sparsity) * Sk. For example,
    # (1 - 0.9) * 10 == 0.9999999999999998, which int() would truncate to 0.
    num_selected = int((1 - sparsity) * score.shape[-1] + 1e-6)
    block_indices = torch.topk(score, num_selected)[1]
    block_indices_lens = torch.full(
        block_indices.shape[:3],
        num_selected,
        dtype=torch.int32,
        device=block_indices.device
    )
    return block_indices, block_indices_lens
def get_select_indices_cdf(q, k, cdf_threshold):
    """CDF-based key-block selection from raw ``q @ k^T`` scores.

    The softmax scale is ``1 / sqrt(head_dim)`` with ``head_dim = q.shape[-1]``.
    See ``get_select_indices_cdf_from_score`` for the returned pair.
    """
    raw_score = cal_score(q, k)
    scale = 1 / q.shape[-1] ** 0.5
    return get_select_indices_cdf_from_score(raw_score, cdf_threshold, scale)
def get_select_indices_cdf_from_score(score, cdf_threshold, sm_scale):
    """Select key blocks by cumulative softmax probability.

    Args:
        score: ``[B, H, Sq, Sk]`` block scores.
        cdf_threshold: cumulative-probability threshold (applied to every head).
        sm_scale: multiplier applied to ``score`` before the softmax.

    Returns:
        (indices, counts):
        - ``indices``: ``[B, H, Sq, Sk]``, all key blocks sorted by descending weight.
        - ``counts``: ``[B, H, Sq]``, the number of leading entries whose cumulative
          weight is ``<= cdf_threshold`` (``searchsorted(..., right=True)``). It can
          be 0 when the largest weight alone exceeds the threshold.
    """
    probs = torch.softmax(score * sm_scale, dim=-1)
    batch, heads, len_q, _ = probs.shape
    threshold = torch.full((heads,), cdf_threshold, device=probs.device)
    threshold = threshold.view(1, heads, 1, 1).expand(batch, -1, len_q, -1)
    sorted_probs, order = torch.sort(probs, dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    counts = torch.searchsorted(cumulative, threshold, right=True)
    return order, counts.squeeze(-1)
def get_select_indices_cdf_topk(q, k, sparsity, cdf_threshold):
    """Combined selection from raw ``q @ k^T`` scores.

    For each query block this keeps max(CDF count, top-k count) key blocks.
    The softmax scale is ``1 / sqrt(head_dim)`` with ``head_dim = q.shape[-1]``.
    """
    raw_score = cal_score(q, k)
    scale = 1 / q.shape[-1] ** 0.5
    return get_select_indices_cdf_topk_from_score(raw_score, sparsity, cdf_threshold, scale)
def get_select_indices_cdf_topk_from_score(score, sparsity, cdf_threshold, sm_scale):
    """Select key blocks by cumulative softmax probability, with a top-k floor.

    Args:
        score: ``[B, H, Sq, Sk]`` block scores.
        sparsity: the floor keeps at least ``floor((1 - sparsity) * Sk)`` blocks.
        cdf_threshold: cumulative-probability threshold (applied to every head).
        sm_scale: multiplier applied to ``score`` before the softmax.

    Returns:
        (indices, counts):
        - ``indices``: ``[B, H, Sq, Sk]``, all key blocks sorted by descending weight.
        - ``counts``: ``[B, H, Sq]``, ``max(cdf count, top-k count)``. The cdf count
          is the number of leading entries whose cumulative weight is
          ``<= cdf_threshold``.
    """
    weights = torch.softmax(score * sm_scale, dim=-1)
    B, H, Sq, Sk = weights.shape
    cdf_threshold = torch.full((H,), cdf_threshold, device=weights.device).view(1, H, 1, 1).expand(B, -1, Sq, -1)
    weights_sorted = torch.sort(weights, dim=-1, descending=True)
    cdf = torch.cumsum(weights_sorted.values, dim=-1)
    num_selected = torch.searchsorted(cdf, cdf_threshold, right=True)
    # max(cdf, topk). The epsilon absorbs float error in (1 - sparsity) * Sk,
    # e.g. (1 - 0.9) * 10 == 0.9999999999999998, which int() would truncate to 0.
    num_selected_topk = int((1 - sparsity) * Sk + 1e-6)
    num_selected = num_selected.clamp(min=num_selected_topk)
    return weights_sorted.indices, num_selected.squeeze(-1)
def get_select_indices(q, k, sparsity, cdf_threshold):
    """Dispatch to top-k, CDF, or combined key-block selection.

    * only ``sparsity`` set -> ``get_select_indices_topk``
    * only ``cdf_threshold`` set -> ``get_select_indices_cdf``
    * both set -> ``get_select_indices_cdf_topk``

    Returns:
        (block_indices, block_indices_lens) from the chosen strategy.

    Raises:
        ValueError: if both ``sparsity`` and ``cdf_threshold`` are ``None``.
    """
    if sparsity is None and cdf_threshold is None:
        raise ValueError(
            "get_select_indices: at least one of `sparsity` or `cdf_threshold` must be set"
        )
    if cdf_threshold is None:
        return get_select_indices_topk(q, k, sparsity)
    if sparsity is None:
        return get_select_indices_cdf(q, k, cdf_threshold)
    return get_select_indices_cdf_topk(q, k, sparsity, cdf_threshold)
def get_select_indices_from_score(score, sparsity, cdf_threshold, sm_scale=1.0):
    """Dispatch to top-k, CDF, or combined selection on a precomputed score tensor.

    Args:
        score: ``[B, H, Sq, Sk]`` block scores.
        sparsity: fraction of key blocks to drop (top-k modes), or ``None``.
        cdf_threshold: cumulative-probability threshold (CDF modes), or ``None``.
        sm_scale: softmax scale the CDF modes apply to ``score``. Pass
            ``1 / sqrt(head_dim)`` to match ``get_select_indices``. The default
            ``1.0`` treats ``score`` as already scaled. Top-k-only mode ignores it.
            Previously the CDF branches omitted this required argument and
            always raised ``TypeError``.

    Returns:
        (block_indices, block_indices_lens) from the chosen strategy.

    Raises:
        ValueError: if both ``sparsity`` and ``cdf_threshold`` are ``None``.
    """
    if sparsity is None and cdf_threshold is None:
        raise ValueError(
            "get_select_indices_from_score: at least one of `sparsity` or `cdf_threshold` must be set"
        )
    if cdf_threshold is None:
        return get_select_indices_topk_from_score(score, sparsity)
    if sparsity is None:
        return get_select_indices_cdf_from_score(score, cdf_threshold, sm_scale)
    return get_select_indices_cdf_topk_from_score(score, sparsity, cdf_threshold, sm_scale)
def attn_fwd_bsa_varlen_triton(
    q,
    k,
    v,
    sm_scale,
    block_indices,
    block_indices_lens,
    chunk_size_q,
    chunk_size_k,
    sparsity
):
    """Run the block-sparse attention forward Triton kernel.

    Args:
        q: ``[B, H, Seq, D]`` queries. Its sequence length drives the launch grid.
        k, v: key / value tensors, passed with their 4-D strides.
        sm_scale: softmax scale forwarded to the kernel.
        block_indices: ``[B, H, M_COMPRESS, S]`` selected key-block ids per query block.
        block_indices_lens: ``[B, H, M_COMPRESS]``, the number of valid leading
            entries in each ``block_indices`` row.
        chunk_size_q, chunk_size_k: query / key block sizes (``BLOCK_M`` / ``BLOCK_N_LG``).
        sparsity: forwarded as ``SPARSITY``.

    The ``_attn_fwd_bsa_varlen`` kernel is used when ``chunk_size_k > 128``,
    otherwise the ``_align`` variant. Preset configs are used unless
    ``TRITON_AUTOTUNE_ENBALE=1`` (name spelled as in the code).

    Returns:
        (o, lse): the output, shaped like ``q``, and the float32 ``[B, H, Seq]``
        log-sum-exp in natural-log units.
    """
    batch, heads, seq_len, head_dim = q.shape
    o = torch.empty_like(q)
    # The kernel writes the per-row log-sum-exp in base-2 units.
    lse_base2 = torch.empty(q.shape[:3], device=q.device, dtype=torch.float32)

    def grid(meta):
        return (triton.cdiv(seq_len, meta["BLOCK_M"]), batch * heads, 1)

    config_key = 'BLOCK_N_LG=64' if chunk_size_k == 64 else 'default'
    use_autotune = os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1'
    if chunk_size_k > 128:
        kernel, presets = _attn_fwd_bsa_varlen, configs_fwd_bsa_varlen_preset
    else:
        kernel, presets = _attn_fwd_bsa_varlen_align, configs_fwd_bsa_varlen_align_preset
    kernel_config = {} if use_autotune else presets[config_key]

    block_indices = block_indices.contiguous()
    block_indices_lens = block_indices_lens.contiguous()
    kernel[grid](
        q, k, v, sm_scale, lse_base2, o,
        block_indices,
        block_indices_lens,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *o.stride(),
        *block_indices.stride(),
        *block_indices_lens.stride(),
        heads, seq_len,
        head_dim,
        BLOCK_M=chunk_size_q,
        BLOCK_N_LG=chunk_size_k,
        SPARSITY=sparsity,
        **kernel_config
    )
    # Convert log2-based LSE back to natural-log units (float32 value of ln 2).
    LN2 = 0.6931471824645996
    lse = lse_base2 * LN2
    return o, lse
def attn_bwd_bsa_varlen_triton(
    do,
    q,
    k,
    v,
    o,
    dq,
    dk,
    dv,
    sm_scale,
    M,
    block_indices,
    block_indices_lens,
    chunk_size_q,
    chunk_size_k,
    sparsity
):
    """Block-sparse attention backward pass. Writes gradients into dq, dk, dv in place.

    Args:
        do: gradient of the attention output.
        q: ``[B, H, N_CTX, D]`` queries. k / v use their own sequence length ``k.shape[-2]``.
        o: forward output.
        dq, dk, dv: preallocated gradient buffers (filled by the kernels).
        sm_scale: softmax scale used in the forward pass.
        M: forward log-sum-exp in natural-log units, ``[B, H, N_CTX]``.
        block_indices: ``[B, H, M_COMPRESS, S]`` selected key-block ids per query block.
        block_indices_lens: ``[B, H, M_COMPRESS]``, the number of valid leading
            entries of each ``block_indices`` row. Entries past it are NOT selected.
        chunk_size_q, chunk_size_k: query / key block sizes.
        sparsity: forwarded to the kernels as ``SPARSITY``.

    Returns:
        None.
    """
    RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2) # reciprocal
    M = M * RCP_LN2  # ln -> log2
    # _attn_bwd_preprocess is launched without stride arguments, so it presumably
    # assumes a contiguous layout. Make both of its inputs contiguous, not just do.
    do = do.contiguous()
    o = o.contiguous()
    BATCH, N_HEAD, N_CTX, HEAD_DIM = q.shape
    N_CTX_KV = k.shape[-2]
    num_blocks_k = N_CTX_KV // chunk_size_k
    # K is pre-scaled by sm_scale / ln(2) before being handed to the kernels.
    arg_k = k * (sm_scale * RCP_LN2)

    PRE_BLOCK = min(chunk_size_q, chunk_size_k, 128)
    assert N_CTX % PRE_BLOCK == 0
    pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
    delta = torch.empty_like(M)
    _attn_bwd_preprocess[pre_grid](
        o, do,
        delta,
        N_CTX,
        BLOCK_M=PRE_BLOCK,
        HEAD_DIM=HEAD_DIM
    )

    # Entries of block_indices past block_indices_lens are not selected. The CDF
    # modes return every key block sorted by weight, so these entries hold real
    # ids. Replace them with the out-of-range id num_blocks_k so the q->k
    # transposition ignores them; otherwise dK/dV would accumulate contributions
    # from blocks the forward pass never attended to.
    slot = torch.arange(block_indices.shape[-1], device=block_indices.device)
    beyond_len = slot >= block_indices_lens.unsqueeze(-1)
    selected_indices = block_indices.masked_fill(beyond_len, num_blocks_k)
    block_indices_k, block_indices_k_lens = create_indices_k_from_indices_q_varlen(
        block_indices=selected_indices,
        N_cols_mask=num_blocks_k
    )
    block_indices = block_indices.contiguous()
    block_indices_lens = block_indices_lens.contiguous()
    block_indices_k = block_indices_k.contiguous()
    block_indices_k_lens = block_indices_k_lens.contiguous()

    use_autotune = os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1'
    config_key = 'BLOCK_N_DQ_LG=64' if chunk_size_k == 64 else 'default'

    # ---------------------- dK / dV ----------------------
    kernel_config = {} if use_autotune else configs_bwd_dkdv_bsa_varlen_preset[config_key]

    def grid_dkdv(args):
        return (triton.cdiv(arg_k.shape[2], args["BLOCK_N"]), 1, arg_k.shape[0] * arg_k.shape[1])

    # Strides of arg_k (not k) are passed: arg_k is the tensor the kernel reads, and
    # `k * scalar` only preserves k's strides when k is dense and non-overlapping.
    _attn_bwd_dkdv_bsa_varlen_wrapper[grid_dkdv](
        q, arg_k, v, sm_scale,  # softmax scale
        do,
        dk, dv,
        M,  # lse (log2)
        delta,
        block_indices_k,
        block_indices_k_lens,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        arg_k.stride(0), arg_k.stride(1), arg_k.stride(2), arg_k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
        dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
        do.stride(0), do.stride(1), do.stride(2), do.stride(3),
        M.stride(0), M.stride(1), M.stride(2),
        delta.stride(0), delta.stride(1), delta.stride(2),
        block_indices_k.stride(0), block_indices_k.stride(1), block_indices_k.stride(2), block_indices_k.stride(3),
        block_indices_k_lens.stride(0), block_indices_k_lens.stride(1), block_indices_k_lens.stride(2),
        N_HEAD, N_CTX,
        BLOCK_M=chunk_size_q,
        BLOCK_N_DQ_LG=chunk_size_k,
        HEAD_DIM=HEAD_DIM,
        SPARSITY=sparsity,
        **kernel_config
    )

    # ---------------------- dQ ----------------------
    if chunk_size_k > 128:
        bwd_dq_func = _attn_bwd_dq_bsa_varlen_wrapper
        dq_presets = configs_bwd_dq_bsa_varlen_preset
    else:
        bwd_dq_func = _attn_bwd_dq_bsa_varlen_align_wrapper
        dq_presets = configs_bwd_dq_bsa_varlen_align_preset
    kernel_config = {} if use_autotune else dq_presets[config_key]

    def grid_dq(args):
        return (triton.cdiv(q.shape[2], args["BLOCK_M"]), 1, q.shape[0] * q.shape[1])

    bwd_dq_func[grid_dq](
        q, arg_k, v,
        do,
        dq,
        M,  # lse (log2)
        delta,
        block_indices,
        block_indices_lens,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        arg_k.stride(0), arg_k.stride(1), arg_k.stride(2), arg_k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
        do.stride(0), do.stride(1), do.stride(2), do.stride(3),
        M.stride(0), M.stride(1), M.stride(2),
        delta.stride(0), delta.stride(1), delta.stride(2),
        block_indices.stride(0), block_indices.stride(1), block_indices.stride(2), block_indices.stride(3),
        block_indices_lens.stride(0), block_indices_lens.stride(1), block_indices_lens.stride(2),
        N_HEAD, N_CTX,
        BLOCK_M=chunk_size_q,
        BLOCK_N_DQ_LG=chunk_size_k,
        HEAD_DIM=HEAD_DIM,
        SPARSITY=sparsity,
        **kernel_config
    )
def make_block_indices_varlen_cp_list(block_indices, cp_size, num_blocks_k_full):
    """
    Split a global key-block selection into per-context-parallel-rank selections.

    Rank ``r`` owns key blocks ``[r * n, (r + 1) * n)`` with ``n = num_blocks_k_full // cp_size``.

    Args:
        block_indices: [B, H, num_blocks_q_per_cp_rank, num_blocks_k_full]
    Return:
        a list of ``cp_size`` pairs ``[block_indices, block_indices_lens]`` for k from each cp_rank
        - each block_indices is rank-local (starts from zero) and sorted ascending, with ids
          outside the rank moved to the tail
        - block_indices_lens indicates the valid number of elements in the last dimension of block_indices
    NOTE(review): no per-row valid length is consulted, so every in-range id in the input counts as selected.
    """
    per_rank = num_blocks_k_full // cp_size
    result = []
    for rank in range(cp_size):
        local = block_indices - rank * per_rank
        # Ids owned by earlier ranks become negative. Map them to the sentinel
        # `per_rank` so the sort puts them, together with ids of later ranks
        # (already >= per_rank), behind the valid rank-local ids.
        local = local.masked_fill(local < 0, per_rank)
        local = torch.sort(local, dim=-1).values
        result.append([local, (local < per_rank).sum(dim=-1)])
    return result
def flash_attn_fwd_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge softmax stats of each step in Attention with context parallelism.

    Overwrites ``softmax_lse`` in place with
    ``log(exp(softmax_lse) + exp(softmax_lse_per_step))``. It is evaluated as
    ``max + log1p(exp(min - max))`` for numerical stability. Returns ``None``.
    """
    hi = torch.maximum(softmax_lse, softmax_lse_per_step)
    lo = torch.minimum(softmax_lse, softmax_lse_per_step)
    # (-inf) - (-inf) is NaN. Zero it so a position where neither step selected
    # any block keeps lse == -inf (hi + log1p(1) is still -inf).
    gap = (lo - hi).nan_to_num(nan=0.)
    softmax_lse.copy_(hi + torch.log1p(torch.exp(gap)))
def flash_attn_fwd_out_correction_init(
    out_init_step: torch.Tensor,  # b h s d
    softmax_lse: torch.Tensor,  # b h s
    softmax_lse_init_step: torch.Tensor,
):
    """Merge partial outputs of the first step in Attention with context parallelism.

    Returns ``out_init_step * exp(softmax_lse_init_step - softmax_lse)``, with the
    factor broadcast over the last dim, cast back to ``out_init_step.dtype``.
    NOTE(review): where both LSEs are -inf the factor is exp(NaN) = NaN. Confirm
    callers never reach that case.
    """
    weight = torch.exp(softmax_lse_init_step - softmax_lse)[..., None]
    return (out_init_step * weight).to(out_init_step.dtype)
def flash_attn_fwd_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge partial outputs of each step in Attention with context parallelism.

    Adds ``out_per_step * exp(softmax_lse_per_step - softmax_lse)`` into ``out``
    in place, with the factor broadcast over the last dim. Returns ``None``.
    NOTE(review): where both LSEs are -inf the factor is exp(NaN) = NaN. Confirm
    callers never reach that case.
    """
    weight = torch.exp(softmax_lse_per_step - softmax_lse)[..., None]
    out.add_(out_per_step * weight)
def topk_sort(score, num_chunks_selected):
    """Return the indices of the ``num_chunks_selected`` largest entries along the
    last dim of ``score``, sorted in ascending index order."""
    _, top_idx = torch.topk(score, num_chunks_selected)
    return top_idx.sort(dim=-1).values
class _attention_bsa(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, sm_scale, use_tma=False):
# shape constraints
HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
# when v is in float8_e5m2 it is transposed.
HEAD_DIM_V = v.shape[-1]
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
assert HEAD_DIM_K in {16, 32, 64, 128, 256}
# ---------------------- gating ----------------------
q_cmp = mean_pooling_compression(q, chunk_size_q)
k_cmp = mean_pooling_compression(k, chunk_size_k)
block_indices, block_indices_lens = get_select_indices(q_cmp, k_cmp, sparsity, cdf_threshold)
# ---------------------- bsa ----------------------
o, lse = attn_fwd_bsa_varlen_triton(
q, k, v,
sm_scale, block_indices, block_indices_lens,
chunk_size_q, chunk_size_k,
sparsity
)
ctx.save_for_backward(q, k, v, o, lse, block_indices, block_indices_lens)
ctx.sm_scale = sm_scale
ctx.HEAD_DIM = HEAD_DIM_K
ctx.chunk_size_q = chunk_size_q
ctx.chunk_size_k = chunk_size_k
ctx.use_tma = use_tma
ctx.sparsity = sparsity
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse, block_indices, block_indices_lens = ctx.saved_tensors
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
attn_bwd_bsa_varlen_triton(
do,
q,
k,
v,
o,
dq,
dk,
dv,
ctx.sm_scale,
lse,
block_indices,
block_indices_lens,
ctx.chunk_size_q,
ctx.chunk_size_k,
ctx.sparsity
)
return dq, dk, dv, None, None, None, None, None, None
# Functional entry point for block-sparse attention. Arguments mirror
# _attention_bsa.forward after ctx:
# (q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, sm_scale, use_tma=False) -> o
flash_attn_bsa = _attention_bsa.apply
def rearrange_THW_to_3d_block(x, Nt, Nh, Nw, t, h, w, D):
    """Reorder a raster-ordered (T, H, W) token axis into contiguous 3-D tiles.

    ``x`` is ``[B, heads, T*H*W, D]`` with ``T = Nt*t``, ``H = Nh*h``, ``W = Nw*w``
    and tokens in raster (t, h, w) order. The result is ``[B, heads, T*H*W, D]``
    where the ``Nt*Nh*Nw`` tiles follow one another (tile order: t, then h,
    then w). Each tile holds its ``t*h*w`` tokens in raster order. The ``D``
    argument is ignored; the feature size is read from ``x``.
    """
    batch, heads, _, feat = x.shape
    tiles = x.view(batch, heads, Nt, t, Nh, h, Nw, w, feat)
    # (Nt, t, Nh, h, Nw, w) -> (Nt, Nh, Nw, t, h, w)
    tiles = tiles.permute(0, 1, 2, 4, 6, 3, 5, 7, 8)
    return tiles.contiguous().view(batch, heads, -1, feat)
def rearrange_3d_block_to_THW(x, Nt, Nh, Nw, t, h, w, D):
    """Inverse of the tile layout: restore raster (T, H, W) token order.

    ``x`` is ``[B, heads, N, D]`` laid out as ``Nt*Nh*Nw`` consecutive tiles of
    ``t*h*w`` raster-ordered tokens. The result is ``[B, heads, N, D]`` in raster
    order over ``(Nt*t, Nh*h, Nw*w)``. The ``D`` argument is ignored; the feature
    size is read from ``x``.
    """
    batch, heads, _, feat = x.shape
    tiles = x.view(batch, heads, Nt, Nh, Nw, t, h, w, feat)
    # (Nt, Nh, Nw, t, h, w) -> (Nt, t, Nh, h, Nw, w)
    tiles = tiles.permute(0, 1, 2, 5, 3, 6, 4, 7, 8)
    return tiles.contiguous().view(batch, heads, -1, feat)
def flash_attn_bsa_3d(
    q: torch.Tensor,  # [B, H, Sq, D]
    k: torch.Tensor,  # [B, H, Skv, D]
    v: torch.Tensor,  # [B, H, Skv, D]
    latent_shape_q,
    latent_shape_k,
    # bsa_params
    sparsity=0.875,
    cdf_threshold=None,
    chunk_3d_shape_q=(4, 4, 8),
    chunk_3d_shape_k=(4, 4, 8),
) -> torch.Tensor:
    """Block-sparse attention over video-like latents using 3-D (t, h, w) blocks.

    Tokens are regrouped so that each ``chunk_3d_shape`` tile is contiguous.
    Block-sparse attention then runs with blocks of ``t*h*w`` tokens, and the
    output is restored to raster (T, H, W) order.

    Args:
        q, k, v: ``[B, H, S, D]`` tensors whose sequence axis is a raster-ordered
            (T, H, W) grid.
        latent_shape_q, latent_shape_k: ``(T, H, W)`` of the query / key grids.
            Their products must equal ``Sq`` / ``Skv``.
        sparsity, cdf_threshold: block-selection parameters forwarded to ``flash_attn_bsa``.
        chunk_3d_shape_q, chunk_3d_shape_k: ``(t, h, w)`` tile sizes. They must
            divide the latent shape. Defaults are tuples (immutable) so the
            shared default cannot be mutated between calls.

    Returns:
        Attention output ``[B, H, Sq, D]`` in raster (T, H, W) order.
    """
    _, _, Sq, head_dim_q = q.shape
    _, _, Sk, head_dim_k = k.shape
    assert head_dim_q == head_dim_k
    head_dim = head_dim_q
    Tq, Hq, Wq = latent_shape_q
    Tk, Hk, Wk = latent_shape_k
    assert Tq * Hq * Wq == Sq
    assert Tk * Hk * Wk == Sk
    tq, hq, wq = chunk_3d_shape_q
    tk, hk, wk = chunk_3d_shape_k
    assert Tq % tq == 0 and Hq % hq == 0 and Wq % wq == 0
    assert Tk % tk == 0 and Hk % hk == 0 and Wk % wk == 0
    # Number of tiles along each latent axis.
    Ntq, Nhq, Nwq = Tq // tq, Hq // hq, Wq // wq
    Ntk, Nhk, Nwk = Tk // tk, Hk // hk, Wk // wk
    q = rearrange_THW_to_3d_block(q, Ntq, Nhq, Nwq, tq, hq, wq, q.shape[-1])
    k = rearrange_THW_to_3d_block(k, Ntk, Nhk, Nwk, tk, hk, wk, k.shape[-1])
    v = rearrange_THW_to_3d_block(v, Ntk, Nhk, Nwk, tk, hk, wk, v.shape[-1])
    chunk_size_q = tq * hq * wq
    chunk_size_k = tk * hk * wk
    output = flash_attn_bsa(q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, 1 / head_dim**0.5)
    output = rearrange_3d_block_to_THW(output, Ntq, Nhq, Nwq, tq, hq, wq, output.shape[-1])
    return output