| |
| |
|
|
| from typing import Optional |
|
|
| import torch |
| import triton |
| import triton.language as tl |
|
|
| from fla.utils import check_shared_mem, input_guard |
|
|
# Candidate `BS` tile widths swept by the autotuner of
# `chunk_local_cumsum_vector_kernel`. The larger pair is selected when
# `check_shared_mem()` is truthy.
# NOTE(review): presumably that helper reports whether the device has enough
# shared memory for the bigger tiles -- confirm against `fla.utils`.
if check_shared_mem():
    BS_LIST = [32, 64]
else:
    BS_LIST = [16, 32]
|
|
|
|
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    """
    Cumulative sum restricted to one chunk of `BT` time steps, one value per step.

    Launch grid: axis 0 selects the chunk, axis 1 selects `i_b * H + i_h`.

    Args:
        s: Pointer to the input values.
        o: Pointer to the output values (same addressing as `s`).
        offsets: Pointer to sequence boundaries, or None. When given, entry
            `i_n` / `i_n + 1` is read as the start / end of sequence `i_n`.
        indices: Pointer to consecutive `(sequence index, chunk index)` pairs,
            one pair per program along grid axis 0. Only read when `offsets`
            is given.
        T: Sequence length; replaced by `eos - bos` when `offsets` is given.
        H: Number of heads.
        BT: Chunk length along time.
        HEAD_FIRST: If true, step `t` of head `i_h` is read at
            `bos * H + i_h * T + t`; otherwise at `(bos + t) * H + i_h`.
        USE_OFFSETS: Set by the heuristic to `offsets is not None`.
        REVERSE: If true, each position receives the sum of itself and all
            later positions of the chunk instead of all earlier ones.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Grid axis 0 enumerates rows of `indices`; translate the row into
        # (sequence, chunk-within-sequence) and look up the sequence bounds.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))

    # Lanes past the end of the sequence must read as exactly zero: the REVERSE
    # branch adds the total of the whole tile to every in-range output, and the
    # default padding of a boundary-checked block load is documented as undefined.
    b_s = tl.load(p_s, boundary_check=(0,), padding_option="zero").to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # suffix_sum[i] = total - prefix_sum[i] + x[i]
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
|
|
|
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BS': BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['S', 'BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_vector_kernel(
    s,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    """
    Cumulative sum restricted to one chunk of `BT` time steps, `S` features per step.

    Launch grid: axis 0 selects the `BS`-wide feature tile, axis 1 the chunk,
    axis 2 selects `i_b * H + i_h`. The sum along time is computed as a
    product of a `BT x BT` 0/1 triangular matrix with the `BT x BS` input tile.

    Args:
        s: Pointer to the input values.
        o: Pointer to the output values (same addressing as `s`).
        offsets: Pointer to sequence boundaries, or None. When given, entry
            `i_n` / `i_n + 1` is read as the start / end of sequence `i_n`.
        indices: Pointer to consecutive `(sequence index, chunk index)` pairs,
            one pair per program along grid axis 1. Only read when `offsets`
            is given.
        T: Sequence length; replaced by `eos - bos` when `offsets` is given.
        H: Number of heads.
        S: Feature dimension (innermost, stride 1).
        BT: Chunk length along time.
        BS: Feature tile width (chosen by the autotuner).
        HEAD_FIRST: If true, row `t` of head `i_h` starts at
            `(bos * H + i_h * T + t) * S`; otherwise at `((bos + t) * H + i_h) * S`.
        USE_OFFSETS: Set by the heuristic to `offsets is not None`.
        REVERSE: If true, each row receives the sum of itself and all later
            rows of the chunk instead of all earlier ones.
    """
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Grid axis 1 enumerates rows of `indices`; translate the row into
        # (sequence, chunk-within-sequence) and look up the sequence bounds.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    if REVERSE:
        # upper-triangular ones (diagonal included): row i sums rows j >= i
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        # lower-triangular ones (diagonal included): row i sums rows j <= i
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    else:
        p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))

    # The matmul mixes every row of the tile into the outputs (in REVERSE mode
    # the out-of-range trailing rows feed in-range results), so out-of-range
    # elements must read as exactly zero; the default padding of a
    # boundary-checked block load is documented as undefined.
    b_s = tl.load(p_s, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=4),
    ],
    key=[]
)
@triton.jit(do_not_specialize=['T'])
def chunk_global_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    """
    Cumulative sum over a whole sequence, one value per time step.

    One program handles one `(i_b, i_h)` pair and walks the sequence in tiles
    of `BT` steps, carrying the running total of the tiles already visited.

    Args:
        s: Pointer to the input values.
        o: Pointer to the output values (same addressing as `s`).
        offsets: Pointer to sequence boundaries, or None. When given, entry
            `i_b` / `i_b + 1` is read as the start / end of sequence `i_b`.
        T: Sequence length; replaced by `eos - bos` inside the kernel.
        H: Number of heads.
        BT: Tile length along time (chosen by the autotuner).
        HEAD_FIRST: If true, step `t` of head `i_h` is read at
            `bos * H + i_h * T + t`; otherwise at `(bos + t) * H + i_h`.
        USE_OFFSETS: Set by the heuristic to `offsets is not None`.
        REVERSE: If true, tiles are visited last-to-first and each position
            receives the sum of itself and all later positions.
    """
    i_bh = tl.program_id(0)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    # running total of all tiles processed so far
    b_z = tl.zeros([], dtype=tl.float32)
    NT = tl.cdiv(T, BT)
    for i_c in range(NT):
        i_t = NT-1-i_c if REVERSE else i_c
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        else:
            p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        # Lanes past the end of the sequence must read as exactly zero: in
        # REVERSE mode the partial last tile is visited first and its total
        # is carried into every other tile. The default padding of a
        # boundary-checked block load is documented as undefined.
        b_s = tl.load(p_s, boundary_check=(0,), padding_option="zero").to(tl.float32)
        b_o = tl.cumsum(b_s, axis=0)
        b_ss = tl.sum(b_s, 0)
        if REVERSE:
            # suffix_sum[i] = total - prefix_sum[i] + x[i]
            b_o = -b_o + b_ss + b_s
        b_o += b_z
        # NOTE(review): `i_c >= 0` holds on every iteration of `range(NT)`;
        # the guard is kept as written -- presumably a compiler workaround,
        # confirm before removing.
        if i_c >= 0:
            b_z += b_ss
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
|
|
|
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for BT in [16, 32, 64]
        for num_warps in [2, 4, 8]
    ],
    key=['S']
)
@triton.jit(do_not_specialize=['T'])
def chunk_global_cumsum_vector_kernel(
    s,
    z,
    offsets,
    T,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    """
    Cumulative sum over a whole sequence, `S` features per time step.

    Launch grid: axis 0 selects the `BS`-wide feature tile, axis 1 selects
    `i_b * H + i_h`. Each program walks the sequence in tiles of `BT` steps,
    computes the in-tile sums as a product with a `BT x BT` 0/1 triangular
    matrix, and adds the per-feature running total of the tiles already visited.

    Args:
        s: Pointer to the input values.
        z: Pointer to the output values (same addressing as `s`).
        offsets: Pointer to sequence boundaries, or None. When given, entry
            `i_b` / `i_b + 1` is read as the start / end of sequence `i_b`.
        T: Sequence length; replaced by `eos - bos` inside the kernel.
        H: Number of heads.
        S: Feature dimension (innermost, stride 1).
        BT: Tile length along time (chosen by the autotuner).
        BS: Feature tile width (supplied by the caller).
        HEAD_FIRST: If true, row `t` of head `i_h` starts at
            `(bos * H + i_h * T + t) * S`; otherwise at `((bos + t) * H + i_h) * S`.
        USE_OFFSETS: Set by the heuristic to `offsets is not None`.
        REVERSE: If true, tiles are visited last-to-first and each row
            receives the sum of itself and all later rows.
    """
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    o_i = tl.arange(0, BT)
    if REVERSE:
        # upper-triangular ones (diagonal included): row i sums rows j >= i
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        # lower-triangular ones (diagonal included): row i sums rows j <= i
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    # per-feature running total of all tiles processed so far
    b_z = tl.zeros([BS], dtype=tl.float32)
    NT = tl.cdiv(T, BT)
    for i_c in range(NT):
        i_t = NT-1-i_c if REVERSE else i_c
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        else:
            p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))

        # Out-of-range elements must read as exactly zero: they enter both the
        # triangular matmul and the carried total (in REVERSE mode the partial
        # last tile is visited first). The default padding of a
        # boundary-checked block load is documented as undefined.
        b_s = tl.load(p_s, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
        b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
        tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
        # NOTE(review): `i_c >= 0` holds on every iteration of `range(NT)`;
        # the guard is kept as written -- presumably a compiler workaround,
        # confirm before removing.
        if i_c >= 0:
            b_z += tl.sum(b_s, 0)
|
|
|
|
def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """
    Chunk-wise cumulative sum of a 3-D tensor along its time axis.

    Args:
        g: Input of shape `[B, H, T]` if `head_first` else `[B, T, H]`.
        chunk_size: Chunk length along time; must be a power of 2.
        reverse: If true, accumulate from the end of each chunk backwards.
        offsets: Optional sequence boundaries for variable-length input.
        indices: Chunk table for variable-length input; the kernel reads it
            as consecutive `(sequence index, chunk index)` pairs and launches
            one program row per entry. Required whenever `offsets` is given.
        head_first: Selects which of the two layouts above `g` uses.
        output_dtype: Dtype of the result; `None` keeps `g.dtype`.

    Returns:
        A new tensor shaped like `g` holding the chunk-local cumulative sums.

    Raises:
        AssertionError: If `chunk_size` is not a power of 2.
        ValueError: If `offsets` is given without `indices`.
    """
    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped under `python -O`; the exception type is unchanged.
    if chunk_size != 2**(chunk_size.bit_length()-1):
        raise AssertionError("chunk_size must be a power of 2")
    if offsets is not None and indices is None:
        raise ValueError("`indices` must be provided together with `offsets`")

    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    BT = chunk_size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    out = torch.empty_like(g, dtype=output_dtype or g.dtype)
    # With `offsets`, the kernel takes the sequence from `indices` and uses
    # only `i_bh % H` from grid axis 1, so that axis must NOT be scaled by the
    # number of sequences -- doing so would recompute every (chunk, head)
    # once per sequence.
    grid = (NT, B * H)
    chunk_local_cumsum_scalar_kernel[grid](
        g,
        out,
        offsets,
        indices,
        T=T,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return out
|
|
|
|
def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """
    Chunk-wise cumulative sum of a 4-D tensor along its time axis.

    Args:
        g: Input of shape `[B, H, T, S]` if `head_first` else `[B, T, H, S]`.
        chunk_size: Chunk length along time; must be a power of 2.
            NOTE(review): the kernel feeds a `chunk_size x chunk_size` matrix
            to `tl.dot`, which may impose a minimum tile size -- confirm
            before using very small chunks.
        reverse: If true, accumulate from the end of each chunk backwards.
        offsets: Optional sequence boundaries for variable-length input.
        indices: Chunk table for variable-length input; the kernel reads it
            as consecutive `(sequence index, chunk index)` pairs and launches
            one program row per entry. Required whenever `offsets` is given.
        head_first: Selects which of the two layouts above `g` uses.
        output_dtype: Dtype of the result; `None` keeps `g.dtype`.

    Returns:
        A new tensor shaped like `g` holding the chunk-local cumulative sums.

    Raises:
        AssertionError: If `chunk_size` is not a power of 2.
        ValueError: If `offsets` is given without `indices`.
    """
    # Validate before `chunk_size` is used for anything. Raised explicitly
    # (rather than via `assert`) so the check is not stripped under
    # `python -O`; the exception type is unchanged.
    if chunk_size != 2**(chunk_size.bit_length()-1):
        raise AssertionError("chunk_size must be a power of 2")
    if offsets is not None and indices is None:
        raise ValueError("`indices` must be provided together with `offsets`")

    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    BT = chunk_size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    out = torch.empty_like(g, dtype=output_dtype or g.dtype)

    def grid(meta):
        # `BS` is picked by the autotuner, so the number of feature tiles is
        # only known once a config has been selected.
        return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)

    chunk_local_cumsum_vector_kernel[grid](
        g,
        out,
        offsets,
        indices,
        T=T,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return out
|
|
|
|
@input_guard
def chunk_global_cumsum_scalar(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """
    Cumulative sum of a 3-D tensor over the full time axis.

    Args:
        s: Input of shape `[B, H, T]` if `head_first` else `[B, T, H]`.
        dtype: Result dtype used only when `output_dtype` is `None`;
            falls back to `s.dtype` when it is `None` as well.
        reverse: If true, accumulate from the end of the sequence backwards.
        offsets: Optional sequence boundaries for variable-length input;
            the number of sequences is then `len(offsets) - 1`.
        head_first: Selects which of the two layouts above `s` uses.
        output_dtype: Preferred dtype of the result.

    Returns:
        A new tensor shaped like `s` holding the cumulative sums.
    """
    if head_first:
        B, H, T = s.shape
    else:
        B, T, H = s.shape
    # One program per (sequence, head); with `offsets` the sequence count
    # comes from the boundary table instead of the batch dimension.
    num_seqs = B if offsets is None else len(offsets) - 1
    result_dtype = output_dtype or dtype or s.dtype
    z = torch.empty_like(s, dtype=result_dtype)
    launch_grid = (num_seqs * H,)
    chunk_global_cumsum_scalar_kernel[launch_grid](
        s,
        z,
        offsets,
        T=T,
        H=H,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return z
|
|
|
|
@input_guard
def chunk_global_cumsum_vector(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """
    Cumulative sum of a 4-D tensor over the full time axis.

    Args:
        s: Input of shape `[B, H, T, S]` if `head_first` else `[B, T, H, S]`.
        dtype: Result dtype used only when `output_dtype` is `None`;
            falls back to `s.dtype` when it is `None` as well.
        reverse: If true, accumulate from the end of the sequence backwards.
        offsets: Optional sequence boundaries for variable-length input;
            the number of sequences is then `len(offsets) - 1`.
        head_first: Selects which of the two layouts above `s` uses.
        output_dtype: Preferred dtype of the result.

    Returns:
        A new tensor shaped like `s` holding the cumulative sums.
    """
    if head_first:
        B, H, T, S = s.shape
    else:
        B, T, H, S = s.shape
    # Feature tile width: next power of two of S, capped at 32.
    BS = min(32, triton.next_power_of_2(S))
    # With `offsets` the sequence count comes from the boundary table
    # instead of the batch dimension.
    num_seqs = B if offsets is None else len(offsets) - 1
    result_dtype = output_dtype or dtype or s.dtype
    z = torch.empty_like(s, dtype=result_dtype)
    launch_grid = (triton.cdiv(S, BS), num_seqs * H)
    chunk_global_cumsum_vector_kernel[launch_grid](
        s,
        z,
        offsets,
        T=T,
        H=H,
        S=S,
        BS=BS,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return z
|
|
|
|
@input_guard
def chunk_global_cumsum(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """
    Cumulative sum over the full time axis, dispatching on the rank of `s`.

    3-D inputs go to `chunk_global_cumsum_scalar`, 4-D inputs to
    `chunk_global_cumsum_vector`; all arguments are forwarded unchanged.

    Raises:
        AssertionError: If `offsets` is given and the leading dimension of
            `s` is not 1.
        ValueError: If `s` is neither 3-D nor 4-D.
    """
    if offsets is not None:
        assert s.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"

    rank = len(s.shape)
    if rank == 3:
        impl = chunk_global_cumsum_scalar
    elif rank == 4:
        impl = chunk_global_cumsum_vector
    else:
        raise ValueError(f"Unsupported input shape {s.shape}. "
                         f"which should be [B, H, T]/[B, H, T, D] if `head_first=True` "
                         f"or [B, T, H]/[B, T, H, D] otherwise")
    return impl(s, dtype, reverse, offsets, head_first, output_dtype)
|
|
|
|
@input_guard
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """
    Chunk-wise cumulative sum along the time axis, dispatching on the rank of `g`.

    3-D inputs go to `chunk_local_cumsum_scalar`, 4-D inputs to
    `chunk_local_cumsum_vector`; all arguments are forwarded unchanged.

    Args:
        g: `[B, H, T]` / `[B, H, T, D]` if `head_first`, otherwise
            `[B, T, H]` / `[B, T, H, D]`.
        chunk_size: Chunk length along time.
        reverse: If true, accumulate from the end of each chunk backwards.
        offsets: Optional sequence boundaries for variable-length input.
        indices: Chunk table accompanying `offsets`.
        head_first: Selects which of the layouts above `g` uses.
        output_dtype: Dtype of the result.

    Raises:
        AssertionError: If `offsets` is given and the leading dimension of
            `g` is not 1.
        ValueError: If `g` is neither 3-D nor 4-D.
    """
    if offsets is not None:
        assert g.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"
    if len(g.shape) == 3:
        return chunk_local_cumsum_scalar(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
    elif len(g.shape) == 4:
        return chunk_local_cumsum_vector(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
    else:
        # The message lists the layouts the two branches above actually accept
        # (the same wording as `chunk_global_cumsum`).
        raise ValueError(f"Unsupported input shape {g.shape}. "
                         f"which should be [B, H, T]/[B, H, T, D] if `head_first=True` "
                         f"or [B, T, H]/[B, T, H, D] otherwise")
|
|