# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import triton
import triton.language as tl

from fla.ops.common.utils import prepare_chunk_indices


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'BT', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    A,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    # Each program handles one (chunk, batch * head) pair.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable-length mode: map the flat chunk index to its (sequence, chunk)
        # pair, then derive the sequence's [bos, eos) range and its length T.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_t = tl.arange(0, BT)
    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    # Accumulate the chunk-local matrix A = (beta * K) @ K^T over the key
    # dimension in blocks of size BK.
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_beta[:, None]
        b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))

    # Keep only the strictly lower-triangular part of each chunk-local matrix.
    b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
    if HEAD_FIRST:
        p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
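

# A minimal pure-PyTorch reference sketch of what the kernel above computes,
# handy for testing. It is illustrative only: the name `chunk_scaled_dot_kkt_ref`
# and its restriction to fixed-length `[B, T, H, K]` inputs (head_first=False,
# no cu_seqlens) are assumptions, not part of the library API.
def chunk_scaled_dot_kkt_ref(
    k: torch.Tensor,
    beta: torch.Tensor,
    chunk_size: int = 64,
) -> torch.Tensor:
    B, T, H, K = k.shape
    A = k.new_zeros(B, T, H, chunk_size, dtype=torch.float32)
    for i in range(0, T, chunk_size):
        j = min(i + chunk_size, T)
        # [B, H, L, K] keys of this chunk and their per-position beta scaling.
        b_k = k[:, i:j].transpose(1, 2).float()
        b_beta = beta[:, i:j].transpose(1, 2).float()
        b_A = (b_k * b_beta[..., None]) @ b_k.transpose(-1, -2)
        # Keep only the strictly lower-triangular part within the chunk.
        A[:, i:j, :, :j - i] = b_A.tril(-1).transpose(1, 2)
    return A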


def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    beta: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32
) -> torch.Tensor:
r"""
Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]` if not `head_first` else `[B, H, T, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]` if not `head_first` else `[B, H, T]`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
head_first (bool):
If False, the input/output tensor is in the shape of `[B, T, H, K]`.
If True, the input/output tensor is in the shape of `[B, H, T, K]`.
Default: False
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` if not `head_first` else `[B, H, T, BT]`,
where `BT` is the chunk size.
"""
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    BT = chunk_size
    # With variable-length inputs, precompute (sequence, chunk) index pairs so
    # each grid program knows which sequence and which chunk it owns.
    indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices)
    A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=output_dtype)
    # One program per (chunk, batch * head) pair; BK is chosen by the autotuner.
    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
        k=k,
        beta=beta,
        A=A,
        offsets=cu_seqlens,
        indices=indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        HEAD_FIRST=head_first,
    )
    return A
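

# A minimal usage sketch (assumes a CUDA device; the shapes and tolerances are
# illustrative, and `chunk_scaled_dot_kkt_ref` is the hypothetical reference
# defined above, not part of the library):
if __name__ == '__main__':
    B, T, H, K = 2, 256, 4, 128
    k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
    beta = torch.rand(B, T, H, device='cuda', dtype=torch.bfloat16)
    A = chunk_scaled_dot_kkt_fwd(k, beta, chunk_size=64)
    A_ref = chunk_scaled_dot_kkt_ref(k, beta, chunk_size=64)
    assert A.shape == (B, T, H, 64)
    torch.testing.assert_close(A, A_ref, atol=5e-2, rtol=5e-2)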