# GLM2NSA / compressed_attention.py
# Last commit (Maxtimer97): "Added num stages tuning" (42f4907)
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import Any, Tuple, Union
import torch
import triton
import triton.language as tl
# Prefer the package-relative helper module; fall back to the absolute
# `ops.utils` path when the relative import fails.
try:
    from .utils import get_num_warps_stages, is_hopper_gpu
except ImportError:
    from ops.utils import get_num_warps_stages, is_hopper_gpu

# Result of is_hopper_gpu(), evaluated once when this module is imported.
IS_HOPPER_GPU = is_hopper_gpu()
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [1, 2, 3]
    ],
    # Every key must name an actual argument of the kernel. The head-dimension
    # tile of this kernel is BLOCK_SIZE_D; it has no BLOCK_SIZE_V parameter.
    key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D'],
)
@triton.jit
def forward_kernel(
    q_ptr,  # Q: n x h x d
    k_ptr,  # K: n x h x d
    v_ptr,  # V: n x h x d
    o_ptr,  # O: n x h x d
    lse_ptr,  # LSE: h x n
    # size and stride at compression
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_on,
    stride_oh,
    stride_od,
    stride_lh,
    stride_ln,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Forward pass of attention between queries and compressed keys/values.

    Grid: (batch, query head, query block). Each program handles one block of
    BLOCK_SIZE_Q queries of one sequence and streams over the compressed keys
    in tiles of BLOCK_SIZE_K with an online softmax.

    Compressed key ``j`` is visible to the query at in-sequence position ``p``
    when ``p >= j * kernel_stride + kernel_size - 1``. Query blocks therefore
    start at position ``kernel_size - 1``; earlier positions are never written
    by this kernel.

    The softmax is evaluated in base 2 (``exp2`` / ``log2`` with the scale
    multiplied by log2(e)), so the stored LSE is a base-2 log-sum-exp.
    """
    # 1.44269504 ~= log2(e): folds the natural-exp softmax into exp2
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    # KV head shared by this query head (GQA)
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # skip the first kernel_size - 1 queries, because they do not attend to any keys
    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
    if q_start_in_seq >= q_len:
        return
    # init qkv pointer
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_qn, stride_qd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # K is addressed transposed ([d, k]) so that q @ k needs no explicit transpose
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # load q
    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init statistics
    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
    # position of the last token covered by each compressed key of the first tile
    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
    m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)
    # attention
    lo = 0
    # number of compressed keys visible to the last query of this block, capped at k_len
    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
    for i in range(lo, hi, BLOCK_SIZE_K):
        i = tl.multiple_of(i, BLOCK_SIZE_K)
        # load k
        k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
        # compute qk, starting from the causal mask (0 where visible, -inf elsewhere)
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf"))
        qk += tl.dot(q, k) * qk_scale
        # compute m_ij and l_ij
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp2(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        # scale acc_o
        acc_o_scale = tl.exp2(m_i - m_ij)
        acc_o = acc_o * acc_o_scale[:, None]
        # load v and update acc_o
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # update statistics
        m_i = m_ij
        lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
        # update ptrs
        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
    # final scale
    acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
    # save output
    o_ptrs = tl.make_block_ptr(
        base=o_ptr + q_start * stride_on + pid_h * stride_oh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_on, stride_od),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
    # save lse
    l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
    tl.store(l_ptrs, lse_i, mask=off_q < q_len)
@triton.autotune(
    configs=[triton.Config({}, num_warps=warps) for warps in [1, 2, 4, 8]],
    key=['HEAD_DIM', 'BLOCK_SIZE_O', 'BLOCK_SIZE_D'],
)
@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Compute delta[h, n] = sum_d O[n, h, d] * dO[n, h, d].

    Grid: (row block, head). Each program reduces BLOCK_SIZE_O rows of one
    head over the head dimension in float32 and stores one value per row.
    """
    row_block = tl.program_id(0)
    head = tl.program_id(1)
    rows = row_block * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    cols = tl.arange(0, BLOCK_SIZE_D)
    # rows past o_len and columns past HEAD_DIM are read as zero
    in_bounds = (rows[:, None] < o_len) & (cols[None, :] < HEAD_DIM)
    o_addr = o_ptr + rows[:, None] * stride_on + head * stride_oh + cols[None, :] * stride_od
    do_addr = do_ptr + rows[:, None] * stride_don + head * stride_doh + cols[None, :] * stride_dod
    o_tile = tl.load(o_addr, mask=in_bounds, other=0).to(tl.float32)
    do_tile = tl.load(do_addr, mask=in_bounds, other=0).to(tl.float32)
    row_dot = tl.sum(o_tile * do_tile, axis=1)
    delta_addr = delta_ptr + head * stride_dh + rows * stride_dn
    tl.store(delta_addr, row_dot, mask=rows < o_len)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D'],
)
@triton.jit
def backward_dkdv(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dk_ptr,  # DK: sh x n x kh x d
    dv_ptr,  # DV: sh x n x kh x d
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dks,
    stride_dkn,
    stride_dkh,
    stride_dkd,
    stride_dvs,
    stride_dvn,
    stride_dvh,
    stride_dvd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Backward pass for the compressed keys and values (dK, dV).

    Grid: (batch, query head, key block). Each program loads one block of
    BLOCK_SIZE_K compressed keys/values once and loops over the query blocks
    that can see it, accumulating dK and dV in float32.

    All score tiles are computed transposed, i.e. shaped
    [BLOCK_SIZE_K, BLOCK_SIZE_Q]. The result is written into the slice of
    dk/dv selected by ``pid_h % NUM_SHARE_Q_HEADS`` (leading ``sh`` axis), so
    query heads sharing a KV head never write to the same location.
    """
    # 1.44269504 ~= log2(e): probabilities are recomputed with exp2 against a base-2 LSE
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    # KV head used by this query head, and its index within the sharing group
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    pid_sh = pid_h % NUM_SHARE_Q_HEADS
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # this key block lies entirely past the end of the sequence
    if BLOCK_SIZE_K * pid_k >= k_len:
        return
    # init pointers
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_kn, stride_kd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dk_ptrs = tl.make_block_ptr(
        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dkn, stride_dkd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dvn, stride_dvd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q)
    # in-sequence position of the last token covered by each compressed key of this block
    off_k = pid_k * BLOCK_SIZE_K * kernel_stride + tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
    # load k v and keep in SRAM
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dk dv
    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    # first query position that can see the first key of this block
    q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1
    # Q and dO are addressed transposed ([d, q]) so the dots below yield [k, q] tiles
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(HEAD_DIM, q_len),
        strides=(stride_qd, stride_qn),
        offsets=(0, q_lo),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(HEAD_DIM, q_len),
        strides=(stride_dod, stride_don),
        offsets=(0, q_lo),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    # delta and lse are read as [1, BLOCK_SIZE_Q] rows that broadcast over the key axis
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(1, q_len),
        strides=(0, stride_dn),
        offsets=(0, q_lo),
        block_shape=(1, BLOCK_SIZE_Q),
        order=(1, 0),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(1, q_len),
        strides=(0, stride_ln),
        offsets=(0, q_lo),
        block_shape=(1, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    # loop for q blocks
    for i in range(q_lo, q_len, BLOCK_SIZE_Q):
        # load
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk, starting from the causal mask (0 where visible, -inf elsewhere)
        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
        qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf"))
        qk += tl.dot(k, q) * qk_scale
        # compute p, ds
        # [BLOCK_SIZE_K, BLOCK_SIZE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
        p = tl.exp2(qk - lse)
        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
        dp = tl.dot(v, do)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        p = p.to(do.dtype)
        ds = ds.to(q.dtype)
        # update dk and dv
        # [BLOCK_SIZE_K, BLOCK_SIZE_Q] @ [BLOCK_SIZE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM]
        dk += tl.dot(ds, tl.trans(q))
        dv += tl.dot(p, tl.trans(do))
        # increment pointers
        q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q))
        do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q))
        lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q))
        d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q))
    # save dk dv
    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D'],
)
@triton.jit
def backward_dq(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dq_ptr,
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dqn,
    stride_dqh,
    stride_dqd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Backward pass for the queries (dQ).

    Grid: (batch, query head, query block). Each program loads one block of
    BLOCK_SIZE_Q queries together with its dO, LSE and delta rows once, then
    loops over the compressed keys visible to that block in tiles of
    BLOCK_SIZE_K, accumulating dQ in float32.

    Query blocks start at in-sequence position ``kernel_size - 1``; dq rows
    before that position are not written by this kernel.
    """
    # 1.44269504 ~= log2(e): probabilities are recomputed with exp2 against a base-2 LSE
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    # KV head shared by this query head (GQA)
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # skip the first kernel_size - 1 queries, because they do not attend to any keys
    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
    if q_start_in_seq >= q_len:
        return
    # init pointers
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_qn, stride_qd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dq_ptrs = tl.make_block_ptr(
        base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_dqn, stride_dqd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_kn, stride_kd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # V is addressed transposed ([d, k]) so that do @ v yields a [q, k] tile
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_vd, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_don, stride_dod),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # delta and lse are read as [BLOCK_SIZE_Q, 1] columns that broadcast over the key axis
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(q_len, 1),
        strides=(stride_dn, stride_dh),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(q_len, 1),
        strides=(stride_ln, stride_lh),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
    # position of the last token covered by each compressed key of the first tile
    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
    # load q, do, lse, delta, and keep in SRAM
    q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
    lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
    d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dq
    dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)
    lo = 0
    # number of compressed keys visible to the last query of this block, capped at k_len
    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
    for i in range(lo, hi, BLOCK_SIZE_K):
        # load
        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk, starting from the causal mask (0 where visible, -inf elsewhere)
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf"))
        qk += tl.dot(q, tl.trans(k)) * qk_scale
        # compute p, ds
        p = tl.exp2(qk - lse)
        dp = tl.dot(do, v)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        ds = ds.to(q.dtype)
        # update dq
        dq += tl.dot(ds, k)
        # increment pointers
        k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
        v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K))
    # save dq
    tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
def _compressed_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch ``forward_kernel`` and return the attention output and its LSE.

    Args:
        q: queries, shape [total_q_len, num_q_heads, head_dim].
        k: compressed keys, shape [total_k_len, num_k_heads, head_dim].
        v: compressed values, same shape as ``k``.
        kernel_size: forwarded to the kernel (compression window size).
        kernel_stride: forwarded to the kernel (compression stride).
        cu_seqlens_q: int32 cumulative query lengths, shape [batch_size + 1].
        cu_seqlens_k: int32 cumulative key lengths, shape [batch_size + 1].
        max_seqlen_q: longest query sequence; only sizes the launch grid.
        max_seqlen_k: accepted for signature symmetry; not used here.
        sm_scale: softmax scale.

    Returns:
        ``(o, lse)``: ``o`` has the shape and dtype of ``q``; ``lse`` is
        float32 of shape [num_q_heads, total_q_len]. Positions the kernel does
        not write keep their initial values (0 for ``o``, -inf for ``lse``).
    """
    # dtype check
    assert k.dtype == q.dtype and v.dtype == q.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    # shape
    q_len, num_q_heads, head_dim = q.shape
    k_len, num_k_heads, k_head_dim = k.shape
    v_len, num_v_heads, v_head_dim = v.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    assert k_len == v_len and q_len >= k_len
    # the kernel uses a single HEAD_DIM for q, k and v
    assert head_dim == k_head_dim and head_dim == v_head_dim
    # gqa
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads
    # output tensor
    o = torch.zeros_like(q)
    lse = torch.full(
        (num_q_heads, q_len),
        fill_value=-torch.inf,
        dtype=torch.float32,
        device=q.device,
    )
    # launch kernel; num_warps / num_stages are chosen by the kernel's autotuner
    grid = lambda META: (
        batch_size,
        num_q_heads,
        triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
    )
    forward_kernel[grid](
        q,
        k,
        v,
        o,
        lse,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        lse.stride(0),
        lse.stride(1),
        BLOCK_SIZE_Q=128,
        BLOCK_SIZE_K=128,
        BLOCK_SIZE_D=triton.next_power_of_2(head_dim),
    )
    return o, lse
def _compressed_attention_bwd(
    o: torch.Tensor,
    do: torch.Tensor,
    lse: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the three backward kernels and return ``(dq, dk, dv)``.

    Steps:
        1. ``backward_sum_o_do``: delta[h, n] = sum_d o * do.
        2. ``backward_dkdv``: per-query-head partial dK/dV, stored along a
           leading axis of size ``num_q_heads // num_k_heads`` and summed here.
        3. ``backward_dq``: dQ.

    Args:
        o: forward output, shape [total_q_len, num_q_heads, head_dim].
        do: gradient with respect to ``o``.
        lse: log-sum-exp saved by the forward pass, [num_q_heads, total_q_len].
        q, k, v: the forward inputs.
        kernel_size, kernel_stride: forwarded to the kernels.
        cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths.
        max_seqlen_q: sizes the grid of the dQ kernel.
        max_seqlen_k: sizes the grid of the dK/dV kernel.
        sm_scale: softmax scale used in the forward pass.

    Returns:
        ``dq`` shaped like ``q`` and ``dk`` / ``dv`` shaped like ``k``.
    """
    num_q_heads = q.shape[1]
    k_len, num_k_heads, _ = k.shape
    o_len, num_o_heads, head_dim = o.shape
    num_share_q_heads = num_q_heads // num_k_heads
    batch_size = cu_seqlens_q.shape[0] - 1
    # head-dimension tile shared by all three kernels
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)

    # compute D; num_warps of every kernel below is chosen by its autotuner
    delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
    grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads)
    backward_sum_o_do[grid](
        o,
        do,
        delta,
        o_len,
        head_dim,
        o.stride(0),
        o.stride(1),
        o.stride(2),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        delta.stride(0),
        delta.stride(1),
        BLOCK_SIZE_O=256,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
    )

    # compute dk dv: one partial result per query head of each sharing group
    dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
    dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
    grid = lambda META: (
        batch_size,
        num_q_heads,
        triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
    )
    backward_dkdv[grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dk,
        dv,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        lse.stride(0),
        lse.stride(1),
        delta.stride(0),
        delta.stride(1),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        dk.stride(0),
        dk.stride(1),
        dk.stride(2),
        dk.stride(3),
        dv.stride(0),
        dv.stride(1),
        dv.stride(2),
        dv.stride(3),
        BLOCK_SIZE_Q=64,
        BLOCK_SIZE_K=128,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
    )
    # reduce the partial gradients over the shared-query-head axis
    dk = dk.sum(0)
    dv = dv.sum(0)

    # compute dq
    dq = torch.zeros_like(q)
    grid = lambda META: (
        batch_size,
        num_q_heads,
        triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
    )
    backward_dq[grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dq,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        lse.stride(0),
        lse.stride(1),
        delta.stride(0),
        delta.stride(1),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        dq.stride(0),
        dq.stride(1),
        dq.stride(2),
        BLOCK_SIZE_Q=128,
        BLOCK_SIZE_K=64,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
    )
    return dq, dk, dv
class CompressedAttention(torch.autograd.Function):
    """Autograd wrapper around the compressed-attention forward/backward launchers.

    ``forward`` returns ``(o, lse)``. ``backward`` only consumes the gradient
    of ``o``; any further incoming gradients are ignored, and only ``q``,
    ``k`` and ``v`` receive gradients.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kernel_size: int,
        kernel_stride: int,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: torch.Tensor,
        max_seqlen_k: torch.Tensor,
        sm_scale=None,
    ):
        # only half-precision inputs are accepted, and q/k/v must agree
        assert q.dtype in (torch.bfloat16, torch.float16)
        assert q.dtype == k.dtype == v.dtype
        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
        # default softmax scale: 1 / sqrt(head_dim)
        scale = 1 / math.sqrt(q.shape[-1]) if sm_scale is None else sm_scale
        out, log_sum_exp = _compressed_attention_fwd(
            q,
            k,
            v,
            kernel_size,
            kernel_stride,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            scale,
        )
        # tensors go through save_for_backward, plain values are kept on ctx
        ctx.save_for_backward(q, k, v, out, log_sum_exp, cu_seqlens_q, cu_seqlens_k)
        ctx.kernel_size = kernel_size
        ctx.kernel_stride = kernel_stride
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.sm_scale = scale
        return out, log_sum_exp

    @staticmethod
    def backward(ctx, do: torch.Tensor, *args) -> Any:
        q, k, v, out, log_sum_exp, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        grads = _compressed_attention_bwd(
            out,
            do,
            log_sum_exp,
            q,
            k,
            v,
            ctx.kernel_size,
            ctx.kernel_stride,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.sm_scale,
        )
        dq, dk, dv = grads
        # one entry per forward input: the seven non-tensor/index inputs get None
        return (dq, dk, dv) + (None,) * 7
@triton.jit
def score_kernel(
    q_ptr,
    k_ptr,
    lse_ptr,
    s_ptr,
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_lh,
    stride_ln,
    stride_sh,
    stride_sq,
    stride_sk,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Recompute attention probabilities and accumulate them per KV head.

    Grid: (batch * kv head, query block, key block). For every query head
    that shares the program's KV head, the kernel recomputes
    ``exp2(q @ k * scale - lse)`` with the causal mask of the compressed keys
    and sums the results, writing one [BLOCK_SIZE_Q, BLOCK_SIZE_K] tile of the
    score tensor (indexed [kv head, query, key]).

    ``lse`` must be the base-2 log-sum-exp, since probabilities are rebuilt
    with ``exp2``.
    """
    # 1.44269504 ~= log2(e)
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_bkh = tl.program_id(0)
    pid_b = pid_bkh // NUM_KV_HEADS
    pid_kh = pid_bkh % NUM_KV_HEADS
    pid_q = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
        return
    # init k pointer and load k (addressed transposed, [d, k])
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, pid_k * BLOCK_SIZE_K),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
    # a query sees compressed key j once it has passed the last token j covers
    causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]
    # init score
    s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
    # Loop over the query heads of this KV head's group. Query heads are laid
    # out group by group, so the group of KV head `pid_kh` starts at
    # `pid_kh * NUM_SHARE_Q_HEADS`; indexing Q/LSE with `pid_kh` itself would
    # read a head from the wrong group whenever NUM_SHARE_Q_HEADS > 1.
    for h in range(NUM_SHARE_Q_HEADS):
        pid_h = pid_kh * NUM_SHARE_Q_HEADS + h
        q_ptrs = tl.make_block_ptr(
            base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
            shape=(q_len, HEAD_DIM),
            strides=(stride_qn, stride_qd),
            offsets=(pid_q * BLOCK_SIZE_Q, 0),
            block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
            order=(1, 0),
        )
        lse_ptrs = tl.make_block_ptr(
            base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
            shape=(q_len, 1),
            strides=(stride_ln, stride_lh),
            offsets=(pid_q * BLOCK_SIZE_Q, 0),
            block_shape=(BLOCK_SIZE_Q, 1),
            order=(0, 1),
        )
        # load q and lse
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.dot(q, k) * qk_scale
        # compute score; masked entries contribute exactly 0
        s += tl.where(causal_mask, tl.exp2(qk - lse), 0)
    # save output
    s_ptrs = tl.make_block_ptr(
        base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,
        shape=(q_len, k_len),
        strides=(stride_sq, stride_sk),
        offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),
        order=(1, 0),
    )
    tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))
def _get_attention_score(
    q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]
    k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
    lse: torch.Tensor,  # [num_q_heads, total_query_len]
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
) -> torch.Tensor:
    """Launch ``score_kernel`` and return the recomputed attention scores.

    The result is a float32 tensor of shape
    [num_k_heads, total_query_len, max_seqlen_k], zero wherever the kernel
    does not write. ``sm_scale`` falls back to 1/sqrt(head_dim) when None.
    """
    # input validation
    assert q.dtype in (torch.bfloat16, torch.float16)
    assert q.dtype == k.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))
    assert lse.dtype == torch.float32
    total_q, n_q_heads, _ = q.shape
    total_k, n_kv_heads, head_dim = k.shape
    n_seqs = cu_seqlens_q.shape[0] - 1
    assert total_q >= total_k
    if sm_scale is None:
        sm_scale = 1 / math.sqrt(head_dim)
    # grouped-query attention: several query heads per KV head
    assert n_q_heads % n_kv_heads == 0
    heads_per_group = n_q_heads // n_kv_heads
    # one score plane per KV head
    score = torch.zeros(n_kv_heads, total_q, max_seqlen_k, dtype=torch.float32, device=q.device)
    # tile sizes; the grid is (batch * kv heads, query tiles, key tiles)
    tile_q = 128
    tile_k = 128
    tile_d = triton.next_power_of_2(head_dim)
    launch_grid = (
        n_seqs * n_kv_heads,
        triton.cdiv(max_seqlen_q, tile_q),
        triton.cdiv(max_seqlen_k, tile_k),
    )
    score_kernel[launch_grid](
        q,
        k,
        lse,
        score,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        n_kv_heads,
        heads_per_group,
        head_dim,
        sm_scale,
        *q.stride(),
        *k.stride(),
        *lse.stride(),
        *score.stride(),
        BLOCK_SIZE_Q=tile_q,
        BLOCK_SIZE_K=tile_k,
        BLOCK_SIZE_D=tile_d,
    )
    return score
@triton.jit
def _transform_score_kernel(
    s_ptr,  # score, shape: [num_heads, q_len, k_len]
    bs_ptr,  # block wise score: [num_heads, q_len, num_k_block]
    offs,
    cu_seqlens_q,
    # shape
    num_heads,
    num_offs,
    max_k_len,
    max_blocks,
    pad_len,
    # kernel & block size
    block_size,
    block_stride,  # block_size // kernel_stride
    init_blocks,
    local_blocks,
    # stride
    stride_sh,
    stride_sq,
    stride_sk,
    stride_bsh,
    stride_bsq,
    stride_bsk,
    TOTAL_QUERY_LEN: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_O: tl.constexpr,
):
    """Turn per-compressed-key scores into per-selection-block scores.

    Grid: (batch * head, query tile, key-block tile). For key block ``b`` the
    block score is ``sum_o offs[o] * score[b * block_stride - pad_len + o]``
    over ``o < num_offs``, with out-of-range columns read as zero. Afterwards
    the score is forced to +inf for the first ``init_blocks`` key blocks and
    for the ``local_blocks`` key blocks ending at the query's own block
    (``query position // block_size``).
    """
    pid_bh = tl.program_id(0)
    pid_b = pid_bh // num_heads
    pid_h = pid_bh % num_heads
    pid_q = tl.program_id(1)
    pid_k = tl.program_id(2)
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    # index of the first key block handled by this program
    k_start = pid_k * BLOCK_SIZE_K
    if pid_q * BLOCK_SIZE_Q >= q_len:
        return
    # load weight
    off_o = tl.arange(0, BLOCK_SIZE_O)
    w = tl.load(offs + off_o, mask=off_o < num_offs, other=0)
    # load score
    off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    # first compressed-key column contributing to each key block (may be negative)
    off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len
    # [BO, BK]: every contributing column of every key block
    off_k = off_k[None, :] + off_o[:, None]
    s_ptrs = (
        s_ptr
        + q_start * stride_sq
        + pid_h * stride_sh
        + off_q[:, None, None] * stride_sq
        + off_k[None, :, :] * stride_sk
    )
    # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK]
    s = tl.load(
        s_ptrs,
        mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len),
        other=0,
    )
    s = s * w[None, :, None]
    s = tl.sum(s, axis=1)
    # init mask and local mask
    off_bq = off_q // block_size
    # absolute key-block index (already includes k_start)
    off_bk = k_start + tl.arange(0, BLOCK_SIZE_K)
    s = tl.where(
        # local blocks: key block in (q_block - local_blocks, q_block]
        ((off_bq[:, None] >= off_bk[None, :]) & (off_bq[:, None] < off_bk[None, :] + local_blocks))
        # init blocks: key block in [0, init_blocks). off_bk is absolute, so
        # it is compared with init_blocks directly; subtracting k_start again
        # would apply the tile offset twice for every tile after the first.
        | (off_bk[None, :] < init_blocks),
        float("inf"),
        s,
    )
    # store block wise score
    bs_ptrs = (
        bs_ptr + q_start * stride_bsq + pid_h * stride_bsh + off_q[:, None] * stride_bsq + off_bk[None, :] * stride_bsk
    )
    tl.store(
        bs_ptrs,
        s,
        mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :],
    )
def transform_score(
    score: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> torch.Tensor:
    """Launch ``_transform_score_kernel`` to build block-wise scores.

    Args:
        score: tensor of shape [num_k_heads, total_query_len, max_key_len].
        kernel_size, kernel_stride: compression window size and stride.
        block_size: selection block size.
        cu_seqlens_q: cumulative query lengths, shape [batch_size + 1].
        cu_seqlens_k, max_seqlen_k: accepted but not used by this function.
        max_seqlen_q: determines the number of key blocks,
            ceil(max_seqlen_q / block_size).
        init_blocks, local_blocks: forwarded to the kernel.

    Returns:
        float32 tensor of shape [num_k_heads, total_query_len, max_blocks].
    """
    n_heads, n_queries, n_keys = score.shape
    n_seqs = cu_seqlens_q.shape[0] - 1
    ratio_kernel = kernel_size // kernel_stride
    ratio_block = block_size // kernel_stride
    max_blocks = math.ceil(max_seqlen_q / block_size)
    block_score = torch.zeros(
        n_heads,
        n_queries,
        max_blocks,
        dtype=torch.float32,
        device=score.device,
    )
    # all pairwise sums i + j, i < kernel_size // kernel_stride, j < block_size // kernel_stride ...
    pair_sums = (
        torch.arange(ratio_kernel, device=score.device)[:, None]
        + torch.arange(ratio_block, device=score.device)[None, :]
    ).view(-1)
    # ... histogrammed: entry o counts how many pairs sum to o (the kernel's weights)
    weights = torch.histc(pair_sums, bins=pair_sums.max() + 1, min=0, max=pair_sums.max())
    n_weights = int(weights.shape[0])
    tile_q = 16
    tile_k = min(128, triton.next_power_of_2(max_blocks))
    tile_o = triton.next_power_of_2(n_weights)
    launch_grid = (
        n_heads * n_seqs,
        triton.cdiv(n_queries, tile_q),
        triton.cdiv(max_blocks, tile_k),
    )
    _transform_score_kernel[launch_grid](
        score,
        block_score,
        weights,
        cu_seqlens_q,
        n_heads,
        n_weights,
        n_keys,
        max_blocks,
        ratio_kernel - 1,
        block_size,
        ratio_block,
        init_blocks,
        local_blocks,
        *score.stride(),
        *block_score.stride(),
        TOTAL_QUERY_LEN=n_queries,
        BLOCK_SIZE_Q=tile_q,
        BLOCK_SIZE_K=tile_k,
        BLOCK_SIZE_O=tile_o,
    )
    return block_score
def _select_topk_blocks(
    q: torch.Tensor,
    k: torch.Tensor,
    lse: torch.Tensor,
    q_idx: torch.Tensor,
    topk: int,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: Union[float, None],
    init_blocks: int,
    local_blocks: int,
    low_memory: bool,
) -> torch.Tensor:
    """Score the given heads, pool the scores per block and pick the top-k blocks.

    ``low_memory=True`` reproduces the per-head path: float32 block scores are
    cast to bfloat16 and ``topk`` is called with ``sorted=False``. Returned
    indices are int32, ascending along the last axis, with every index larger
    than the query's own block index (``q_idx``) replaced by -1.
    """
    # recompute score
    score = _get_attention_score(
        q,
        k,
        lse,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        sm_scale,
    )
    # transform score to block-wise score
    score = transform_score(
        score,
        kernel_size,
        kernel_stride,
        block_size,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        init_blocks,
        local_blocks,
    )
    # get topk; never ask for more blocks than exist
    topk = min(topk, score.shape[-1])
    if low_memory:
        if score.dtype == torch.float32:
            score = score.to(torch.bfloat16)
        topk_idx = score.topk(topk, dim=-1, sorted=False).indices
    else:
        topk_idx = score.topk(topk, dim=-1).indices
    topk_idx = topk_idx.sort(-1).values
    # blocks after the query's own block are not selectable
    topk_idx[topk_idx > q_idx[None, :, None]] = -1
    return topk_idx.to(torch.int32)


def compressed_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: Union[int, None],
    max_seqlen_k: Union[int, None],
    sm_scale: Union[float, None] = None,
    init_blocks: int = 1,
    local_blocks: int = 2,
    parallel_topk_compute: Union[str, bool] = False,
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
    """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.

    Args:
        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        kernel_size (int): kernel size in compress_key_value
        kernel_stride (int): stride of compress_key_value
        block_size (int): key value block size for topk sparse attention.
        topk (int): number of blocks for each query. If ``topk <= 0`` no block
            selection is done and the returned index is None. Otherwise it
            must be at least ``init_blocks + local_blocks``.
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
        max_seqlen_q (int, optional): max q len of the batch. Computed from ``cu_seqlens_q`` when None.
        max_seqlen_k (int, optional): max k len of the batch. Computed from ``cu_seqlens_k`` when None.
        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
        local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
        parallel_topk_compute (str or bool, optional): True scores all KV heads in one launch;
            False (the default) processes one KV head at a time, which avoids a
            known issue with very long sequences. "auto" selects True when the
            total query length is at most 32768 and False otherwise.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention.
        topk_idx is int32 of shape [num_kv_heads, total_q_len, min(topk, num_blocks)], or None when ``topk <= 0``.
    """
    if max_seqlen_q is None:
        max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
    if max_seqlen_k is None:
        max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
    attn_output, lse = CompressedAttention.apply(
        q,
        k,
        v,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        sm_scale,
    )
    # do not select topk index
    if topk <= 0:
        warnings.warn("topk <= 0, returned topk_idx will be None")
        return attn_output, None
    assert topk >= init_blocks + local_blocks
    # block selection is not differentiable
    with torch.no_grad():
        num_k_heads, num_q_heads = k.shape[1], q.shape[1]
        num_shared_q_heads = num_q_heads // num_k_heads
        batch_size = cu_seqlens_q.shape[0] - 1
        # in-sequence position of every query, then the block it belongs to
        q_idx = torch.cat(
            [torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) for i in range(batch_size)],
            dim=0,
        )
        q_idx = q_idx // block_size
        # whether to use parallel version
        if parallel_topk_compute == "auto":
            parallel_topk_compute = cu_seqlens_q[-1] <= 32768
        # parallel version
        if parallel_topk_compute:
            topk_idx = _select_topk_blocks(
                q,
                k,
                lse,
                q_idx,
                topk,
                kernel_size,
                kernel_stride,
                block_size,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                sm_scale,
                init_blocks,
                local_blocks,
                low_memory=False,
            )
        # non parallel version, avoid some current bugs when sequence length is too long
        # FIXME: need to fix later
        else:
            topk_idx_list = []
            head_tile = 1
            assert num_k_heads % head_tile == 0, f"Num kv heads: {num_k_heads}, head_tile: {head_tile}"
            q_heads_per_tile = num_shared_q_heads * head_tile
            for h in range(num_k_heads // head_tile):
                topk_idx_list.append(
                    _select_topk_blocks(
                        q[:, h * q_heads_per_tile: (h + 1) * q_heads_per_tile],
                        k[:, h * head_tile: (h + 1) * head_tile],
                        lse[h * q_heads_per_tile: (h + 1) * q_heads_per_tile],
                        q_idx,
                        topk,
                        kernel_size,
                        kernel_stride,
                        block_size,
                        cu_seqlens_q,
                        cu_seqlens_k,
                        max_seqlen_q,
                        max_seqlen_k,
                        sm_scale,
                        init_blocks,
                        local_blocks,
                        low_memory=True,
                    )
                )
            topk_idx = torch.cat(topk_idx_list, dim=0)
    return attn_output, topk_idx