| |
| |
|
|
| from typing import Optional |
|
|
| import torch |
| import triton |
| import triton.language as tl |
|
|
| from fla.ops.common.utils import prepare_chunk_indices |
|
|
|
|
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'BT', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    A,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    # One program handles one (chunk, batch*head) pair and writes a single
    # [BT, BT] tile of A: sum over K-blocks of (k * beta[:, None]) @ k^T,
    # with everything on or above the diagonal zeroed before the store.
    #
    # k / beta / A : base pointers of the key, beta and output tensors.
    # offsets      : sequence boundaries (only read when USE_OFFSETS).
    # indices      : flat (sequence id, chunk id) pairs, one pair per program
    #                along grid axis 0 (only read when USE_OFFSETS).
    # T            : sequence length; replaced by the per-sequence length when
    #                USE_OFFSETS is set.
    # H, K, BT, BK : heads, key dim, chunk size and K-block size.
    pid_chunk = tl.program_id(0)
    pid_bh = tl.program_id(1)
    batch_id = pid_bh // H
    head_id = pid_bh % H

    if USE_OFFSETS:
        # Translate the flat program id into (sequence, chunk-in-sequence),
        # then look up where that sequence starts and ends.
        pair_ptr = indices + pid_chunk * 2
        seq_id = tl.load(pair_ptr).to(tl.int32)
        pid_chunk = tl.load(pair_ptr + 1).to(tl.int32)
        bos = tl.load(offsets + seq_id).to(tl.int32)
        eos = tl.load(offsets + seq_id + 1).to(tl.int32)
        T = eos - bos
    else:
        bos = batch_id * T

    # First row (time step) covered by this chunk.
    row_start = pid_chunk * BT
    rows = tl.arange(0, BT)

    # Base addresses and block pointers for beta and A. The head-first layout
    # is addressed by the flattened batch*head index; the other layout by the
    # sequence start `bos` with heads interleaved along the time axis.
    if HEAD_FIRST:
        k_base = k + pid_bh * T * K
        p_beta = tl.make_block_ptr(
            base=beta + pid_bh * T,
            shape=(T,),
            strides=(1,),
            offsets=(row_start,),
            block_shape=(BT,),
            order=(0,),
        )
        p_A = tl.make_block_ptr(
            base=A + pid_bh * T * BT,
            shape=(T, BT),
            strides=(BT, 1),
            offsets=(row_start, 0),
            block_shape=(BT, BT),
            order=(1, 0),
        )
    else:
        k_base = k + (bos * H + head_id) * K
        p_beta = tl.make_block_ptr(
            base=beta + bos * H + head_id,
            shape=(T,),
            strides=(H,),
            offsets=(row_start,),
            block_shape=(BT,),
            order=(0,),
        )
        p_A = tl.make_block_ptr(
            base=A + (bos * H + head_id) * BT,
            shape=(T, BT),
            strides=(BT * H, 1),
            offsets=(row_start, 0),
            block_shape=(BT, BT),
            order=(1, 0),
        )

    beta_tile = tl.load(p_beta, boundary_check=(0,))

    # Accumulate the [BT, BT] product in fp32, one K-block at a time.
    acc = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(
                base=k_base,
                shape=(T, K),
                strides=(K, 1),
                offsets=(row_start, i_k * BK),
                block_shape=(BT, BK),
                order=(1, 0),
            )
        else:
            p_k = tl.make_block_ptr(
                base=k_base,
                shape=(T, K),
                strides=(H * K, 1),
                offsets=(row_start, i_k * BK),
                block_shape=(BT, BK),
                order=(1, 0),
            )
        k_tile = tl.load(p_k, boundary_check=(0, 1))
        # Scale each key row by its beta, cast back to the key dtype for the dot.
        scaled = (k_tile * beta_tile[:, None]).to(k_tile.dtype)
        acc += tl.dot(scaled, tl.trans(k_tile))

    # Keep only entries whose row index is strictly greater than the column index.
    strictly_lower = rows[:, None] > rows[None, :]
    acc = tl.where(strictly_lower, acc, 0)
    tl.store(p_A, acc.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    beta: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    r"""
    Compute the chunk-wise product `beta * K * K^T`.

    For every chunk of `chunk_size` consecutive time steps, the kernel forms
    `(beta[:, None] * K) @ K^T` and keeps only the entries whose row index
    inside the chunk is strictly greater than the column index; the diagonal
    and everything above it are written as zero.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]` if not `head_first` else `[B, H, T, K]`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]` if not `head_first` else `[B, H, T]`.
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor. When given,
            chunks are laid out per sequence instead of per batch row.
            Default: None
        head_first (bool):
            If False, `k` is in the shape of `[B, T, H, K]` and the output `[B, T, H, BT]`.
            If True, `k` is in the shape of `[B, H, T, K]` and the output `[B, H, T, BT]`.
            Default: False
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` if not `head_first` else `[B, H, T, BT]`,
        where `BT` is the chunk size.
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    BT = chunk_size

    if cu_seqlens is None:
        # Fixed-length input: one program per chunk of every batch row.
        indices = None
        NT = triton.cdiv(T, BT)
    else:
        # Variable-length input: one program per entry of the chunk index table
        # (read by the kernel as (sequence id, chunk id) pairs).
        indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(indices)

    out_shape = (B, H, T, BT) if head_first else (B, T, H, BT)
    # torch.empty is sufficient here: NT is chosen so the chunks span all T rows.
    A = torch.empty(*out_shape, device=k.device, dtype=output_dtype)

    # Grid: axis 0 walks the chunks, axis 1 walks the flattened (batch, head) pairs.
    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
        k=k,
        beta=beta,
        A=A,
        offsets=cu_seqlens,
        indices=indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        HEAD_FIRST=head_first
    )
    return A
|
|