base_IIXIV / fla /ops /comba /utils.py
mainline777's picture
Duplicate from silx-ai/Quasar-Preview
41865df
Raw
History Blame Contribute Delete
5 kB
import torch
import triton
import triton.language as tl
from fla.ops.utils.index import prepare_chunk_indices
from fla.utils import autotune_cache_kwargs
@triton.heuristics({
'HAS_SCALE': lambda args: args['scale'] is not None,
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8]
],
key=['B', 'H', 'BT', 'IS_VARLEN'],
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def chunk_comba_cumsum_scalar_fwd_kernel(
g,
g0,
g1,
scale,
cu_seqlens,
chunk_indices,
T,
B: tl.constexpr,
H: tl.constexpr,
BT: tl.constexpr,
HAS_SCALE: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
p_g0 = tl.make_block_ptr(g0 + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
p_g1 = tl.make_block_ptr(g1 + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
# [BT]
b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
if HAS_SCALE:
b_g = b_g * scale
b_g1 = tl.cumsum(b_g, axis=0)
b_g0 = b_g1 - b_g
tl.store(p_g0, b_g0.to(p_g0.dtype.element_ty), boundary_check=(0,))
tl.store(p_g1, b_g1.to(p_g1.dtype.element_ty), boundary_check=(0,))
def chunk_comba_cumsum_scalar_fwd(
g: torch.Tensor,
chunk_size: int,
cu_seqlens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = torch.float,
chunk_indices: torch.LongTensor | None = None,
scale: float | None = None,
) -> torch.Tensor:
B, T, H = g.shape
assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
BT = chunk_size
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g0, g1 = torch.empty_like(g, dtype=output_dtype or g.dtype), torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (NT, B * H)
chunk_comba_cumsum_scalar_fwd_kernel[grid](
g,
g0,
g1,
scale,
cu_seqlens,
chunk_indices,
T=T,
B=B,
H=H,
BT=BT,
)
return g0, g1
@triton.heuristics({
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8]
],
key=['B', 'H', 'BT', 'IS_VARLEN'],
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def chunk_comba_cumsum_scalar_bwd_kernel(
dg0,
dgr,
cu_seqlens,
chunk_indices,
T,
B: tl.constexpr,
H: tl.constexpr,
BT: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
p_dg0 = tl.make_block_ptr(dg0 + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
p_dgr = tl.make_block_ptr(dgr + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
# [BT]
"""
b_dg: 1,2,3,4
b_dg0: 0,1,2,3
b_temp: 0,1,3,6
b_dz: 6
b_dgr: 6,5,3,0
"""
b_dg0 = tl.load(p_dg0, boundary_check=(0,)).to(tl.float32)
b_temp = tl.cumsum(b_dg0, axis=0)
b_dz = tl.sum(b_dg0, axis=0)
b_dgr = -b_temp + b_dz[None]
tl.store(p_dgr, b_dgr.to(p_dgr.dtype.element_ty), boundary_check=(0,))
def chunk_comba_cumsum_scalar_bwd(
dg0: torch.Tensor,
chunk_size: int,
cu_seqlens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = torch.float,
chunk_indices: torch.LongTensor | None = None,
) -> torch.Tensor:
B, T, H = dg0.shape
assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
BT = chunk_size
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
dg = torch.empty_like(dg0, dtype=output_dtype or dg0.dtype)
grid = (NT, B * H)
chunk_comba_cumsum_scalar_bwd_kernel[grid](
dg0,
dg,
cu_seqlens,
chunk_indices,
T=T,
B=B,
H=H,
BT=BT,
)
return dg