| | |
| | |
| |
|
| | from typing import Optional, Tuple |
| |
|
| | import torch |
| | import triton |
| | import triton.language as tl |
| |
|
| | from fla.ops.common.utils import prepare_chunk_indices |
| | from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard |
| |
|
| |
|
# Forward mean-pooling kernel.
#
# Launched on the grid (cdiv(D, BD), NT, B * H) built in mean_pooling_fwd: each
# program averages the rows of one chunk (up to BT tokens) for a BD-wide slice of
# the features of a single head and writes one output vector.
#
# T and NT are runtime (non-constexpr) arguments and are excluded from integer
# specialization: both are recomputed per sequence when USE_OFFSETS is set, and
# making them constexpr would force a recompilation for every distinct sequence
# length / chunk count (which also defeated the previous do_not_specialize=['T']).
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T', 'NT'])
def mean_pooling_fwd_kernel(
    x,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    NT,
    USE_OFFSETS: tl.constexpr
):
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable length: grid axis 1 enumerates chunks across all sequences, and
        # `indices` stores one (sequence index, chunk index in sequence) pair per chunk.
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        # Fixed length: output rows are laid out batch-major over the chunks.
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
    # NOTE(review): offsets here are int32 arithmetic; extremely large B*T*H*D may
    # overflow -- confirm expected input sizes.

    # Rows of x are H*D elements apart, matching a contiguous [.., T, H, D] layout.
    p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
    # [BT, BD]; pad out-of-range rows/columns with zeros explicitly so a trailing
    # partial chunk does not sum undefined values.
    b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    # [BD]; divide by the real number of rows in this chunk (the last one may be short).
    b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
| |
|
| |
|
# Backward mean-pooling kernel.
#
# Launched on the grid (cdiv(D, BD), NT, B * H) built in mean_pooling_bwd: each
# program reads the output gradient of one chunk (a BD-wide slice for one head)
# and writes it, divided by the chunk's row count, to every row of that chunk.
#
# T and NT are runtime (non-constexpr) arguments and are excluded from integer
# specialization, for the same reason as in the forward kernel: they vary per
# sequence and would otherwise trigger a recompilation for every new length.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T', 'NT'])
def mean_pooling_bwd_kernel(
    do,
    dx,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    NT,
    USE_OFFSETS: tl.constexpr
):
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable length: `indices` stores one (sequence index, chunk index in
        # sequence) pair per chunk enumerated by grid axis 1.
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        # Fixed length: gradient rows are laid out batch-major over the chunks.
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # dx rows are H*D elements apart, matching a contiguous [.., T, H, D] layout.
    p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
    # [BD]
    b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32)
    # [BT, BD]: the [BD] gradient broadcast against a [BT, 1] column holding the
    # chunk's real row count (the last chunk may be short).
    b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None]
    # Rows past the chunk's end are masked off by boundary_check.
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
| |
|
| |
|
def mean_pooling_fwd(
    x: torch.Tensor,
    chunk_size: int,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    """Run the forward mean-pooling kernel over ``x``.

    Args:
        x: input unpacked as ``(B, T, H, D)``.
        chunk_size: number of consecutive tokens averaged into each output row.
        offsets: optional cumulative sequence boundaries for variable-length input.
        indices: per-chunk table consumed by the kernel; must be provided whenever
            ``offsets`` is given (its length is the total chunk count).

    Returns:
        A new tensor of shape ``(B, NT, H, D)``, where ``NT`` is
        ``ceil(T / chunk_size)`` without offsets and ``len(indices)`` with them.
    """
    batch, seq_len, num_heads, head_dim = x.shape
    if offsets is None:
        num_chunks = triton.cdiv(seq_len, chunk_size)
    else:
        num_chunks = len(indices)

    out = x.new_empty(batch, num_chunks, num_heads, head_dim)

    def grid(meta):
        # (feature blocks, chunks, batch * heads); BD is chosen by the autotuner.
        return (triton.cdiv(head_dim, meta['BD']), num_chunks, batch * num_heads)

    mean_pooling_fwd_kernel[grid](
        x, out, offsets, indices,
        T=seq_len, H=num_heads, D=head_dim, BT=chunk_size, NT=num_chunks,
    )
    return out
| |
|
| |
|
def mean_pooling_bwd(
    do: torch.Tensor,
    batch_size: int,
    seq_len: int,
    chunk_size: int,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    """Run the backward mean-pooling kernel to rebuild the input gradient.

    Args:
        do: gradient of the pooled output; its last two dims are taken as ``(H, D)``.
        batch_size: batch size ``B`` of the original input.
        seq_len: sequence length ``T`` of the original input.
        chunk_size: number of tokens that were averaged per chunk in the forward pass.
        offsets: optional cumulative sequence boundaries for variable-length input.
        indices: per-chunk table consumed by the kernel; must be provided whenever
            ``offsets`` is given.

    Returns:
        Gradient w.r.t. the input, shape ``(B, T, H, D)``.
    """
    num_heads, head_dim = do.shape[-2:]
    if offsets is None:
        num_chunks = triton.cdiv(seq_len, chunk_size)
    else:
        num_chunks = len(indices)

    # Uninitialized storage: the kernel is expected to write every chunk
    # (assuming `indices` enumerates all chunks in the variable-length case).
    dx = do.new_empty(batch_size, seq_len, num_heads, head_dim)

    def grid(meta):
        # (feature blocks, chunks, batch * heads); BD is chosen by the autotuner.
        return (triton.cdiv(head_dim, meta['BD']), num_chunks, batch_size * num_heads)

    mean_pooling_bwd_kernel[grid](
        do, dx, offsets, indices,
        T=seq_len, H=num_heads, D=head_dim, BT=chunk_size, NT=num_chunks,
    )
    return dx
| |
|
| |
|
class MeanPoolingFunction(torch.autograd.Function):
    """Autograd bridge for chunked mean pooling.

    The forward pass averages each run of ``chunk_size`` consecutive tokens; the
    backward pass hands each chunk's gradient, divided by the chunk length, back
    to every token in that chunk. ``chunk_size`` and ``offsets`` receive no
    gradient.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        x: torch.Tensor,
        chunk_size: int,
        offsets: Optional[torch.LongTensor] = None
    ) -> torch.Tensor:
        # For variable-length input, build the per-chunk table the kernels read as
        # (sequence index, chunk index) pairs; not needed for fixed-length input.
        indices = None
        if offsets is not None:
            indices = prepare_chunk_indices(offsets, chunk_size)

        pooled = mean_pooling_fwd(x, chunk_size, offsets, indices)

        # Everything backward needs is plain metadata or non-differentiable tensors.
        ctx.batch_size, ctx.seq_len = x.shape[0], x.shape[1]
        ctx.chunk_size = chunk_size
        ctx.offsets, ctx.indices = offsets, indices
        return pooled

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx, do
    ) -> Tuple[torch.Tensor, None, None]:
        grad_x = mean_pooling_bwd(
            do,
            batch_size=ctx.batch_size,
            seq_len=ctx.seq_len,
            chunk_size=ctx.chunk_size,
            offsets=ctx.offsets,
            indices=ctx.indices,
        )
        # One entry per forward input: x, chunk_size, offsets.
        return grad_x, None, None
| |
|
| |
|
def mean_pooling(
    x: torch.Tensor,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    """Average ``x`` over non-overlapping chunks of ``chunk_size`` tokens.

    Args:
        x: input of shape ``[B, T, H, D]``, or ``[B, H, T, D]`` when ``head_first``.
        chunk_size: number of consecutive tokens averaged into each output row. A
            trailing partial chunk is averaged over its actual length.
        cu_seqlens: optional cumulative sequence lengths for variable-length input.
            When given, all sequences must be packed into a batch of size 1.
        head_first: whether ``x`` (and the returned tensor) put the head dim before
            the time dim.

    Returns:
        Pooled tensor of shape ``[B, NT, H, D]`` (``[B, H, NT, D]`` if
        ``head_first``). ``NT`` is ``ceil(T / chunk_size)`` without ``cu_seqlens``,
        or the total number of chunks over all sequences with it.

    Raises:
        ValueError: if ``cu_seqlens`` is given and the batch size is not 1.
    """
    if head_first:
        x = x.transpose(1, 2)
    if cu_seqlens is not None and x.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`. "
            "Please flatten variable-length inputs before processing."
        )
    o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens)
    if head_first:
        o = o.transpose(1, 2)
    return o
| |
|