| |
| |
|
|
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from ...ops.utils import prepare_chunk_indices
from ...ops.utils.op import safe_exp
|
|
|
|
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
    'USE_G': lambda args: args['g_cumsum'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'BT', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    g_cumsum,
    A,
    Ag,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
):
    # One program handles one BT-row chunk (grid axis 0) of one (batch, head)
    # pair (grid axis 1). It accumulates (beta * k) @ k^T over K in BK-wide
    # slices, keeps only the strictly lower triangle, and stores it to `A`.
    # When `g_cumsum` is given (USE_G), the same tile multiplied by
    # safe_exp(g[i] - g[j]) is additionally stored to `Ag`.
    pid_t = tl.program_id(0)
    pid_bh = tl.program_id(1)
    i_b = pid_bh // H
    i_h = pid_bh % H

    if IS_VARLEN:
        # chunk_indices is read as consecutive pairs: (sequence id, chunk id
        # inside that sequence); cu_seqlens supplies the sequence boundaries.
        p_pair = chunk_indices + pid_t * 2
        i_seq = tl.load(p_pair).to(tl.int32)
        i_chunk = tl.load(p_pair + 1).to(tl.int32)
        seq_start = tl.load(cu_seqlens + i_seq).to(tl.int32)
        seq_end = tl.load(cu_seqlens + i_seq + 1).to(tl.int32)
        # From here on T is the length of this particular sequence.
        T = seq_end - seq_start
    else:
        i_chunk = pid_t
        seq_start = i_b * T

    # First row of this chunk inside its sequence, and the flattened
    # (token, head) offset of the sequence start for head i_h.
    row_start = i_chunk * BT
    head_base = seq_start * H + i_h

    # Per-row beta for the chunk; rows past T are boundary-checked.
    p_beta = tl.make_block_ptr(beta + seq_start * H + i_h, (T,), (H,), (row_start,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    # float32 accumulator for the [BT, BT] tile.
    acc = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + head_base * K, (T, K), (H * K, 1), (row_start, i_k * BK), (BT, BK), (1, 0)
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # Scale rows by beta, then cast back to k's dtype before the dot.
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        acc += tl.dot(b_kb, tl.trans(b_k))

    # Zero the diagonal and everything above it.
    o_t = tl.arange(0, BT)
    m_strict_lower = o_t[:, None] > o_t[None, :]
    acc = tl.where(m_strict_lower, acc, 0)

    p_A = tl.make_block_ptr(A + head_base * BT, (T, BT), (BT * H, 1), (row_start, 0), (BT, BT), (1, 0))
    tl.store(p_A, acc.to(p_A.dtype.element_ty), boundary_check=(0, 1))

    if USE_G:
        p_g = tl.make_block_ptr(g_cumsum + seq_start * H + i_h, (T,), (H,), (row_start,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        # Pairwise row-minus-column differences of g, passed through safe_exp.
        b_decay = safe_exp(b_g[:, None] - b_g[None, :])
        b_Ag = acc * b_decay
        p_Ag = tl.make_block_ptr(Ag + head_base * BT, (T, BT), (BT * H, 1), (row_start, 0), (BT, BT), (1, 0))
        tl.store(p_Ag, b_Ag.to(p_Ag.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    r"""
    Compute the chunk-local, strictly lower-triangular part of `beta * K * K^T`.

    Within every chunk of `chunk_size` tokens, entry `[i, j]` holds
    `beta[i] * <k[i], k[j]>` for `i > j`; the diagonal and the upper triangle
    are zero. When `g_cumsum` is given, a second tensor is produced in which
    each entry is additionally multiplied by `safe_exp(g_cumsum[i] - g_cumsum[j])`.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        g_cumsum (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H]`.
            Default: None
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            When given, chunks are enumerated per sequence via
            `prepare_chunk_indices`.
            Default: None
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensors. Default: `torch.float32`

    Returns:
        A tuple `(A, Ag)`:

        - `A`: the masked `beta * K * K^T` of shape `[B, T, H, BT]`, where
          `BT` is the chunk size.
        - `Ag`: the gated variant of `A` with the same shape and dtype, or
          `None` when `g_cumsum` is None.

    Note:
        The kernel addresses `k`, `beta` and `g_cumsum` with hard-coded strides
        derived from `H` and `K`, i.e. it reads them as densely packed
        `[B, T, H, ...]` buffers; no stride information is passed along.
    """
    B, T, H, K = k.shape
    BT = chunk_size

    # Varlen mode launches one program per (sequence, chunk) pair; otherwise
    # one program per chunk of the padded length T.
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    # Rows are fully overwritten by the kernel, so uninitialised storage is fine.
    A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
    # The gated output is only allocated (and only written) when gates exist.
    Ag = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) if g_cumsum is not None else None

    # BK, IS_VARLEN and USE_G are supplied by the autotuner / heuristics.
    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
        k=k,
        beta=beta,
        g_cumsum=g_cumsum,
        A=A,
        Ag=Ag,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
    )
    return A, Ag
|
|