# Kernels
# sigmoid-neuron's picture
# feat: add support for ROCm backend and update device validation for HIP and XPU compatibility
# a391959
"""
Flash Attention 1 — Triton Kernel
==================================
Pure Python/Triton implementation of Flash Attention (Dao et al., 2022).
Reference: https://arxiv.org/abs/2205.14135
Algorithm overview
------------------
Flash Attention fuses the softmax(Q Kᵀ / √d) · V computation into a single
GPU kernel pass using online softmax (Milakov & Gimelshein, 2018) so that the
full N×N attention matrix is never materialised in HBM.
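This file contains:
1. `_flash_attn_fwd_kernel` — forward-pass Triton kernel (writes O and the row log-sum-exp).
2. `_attn_bwd_preprocess` — backward pre-pass computing D = rowsum(O ⊙ dO).
3. `_flash_attn_bwd_kernel` — backward-pass Triton kernel producing dQ, dK and dV.
4. `_attention` — torch.autograd.Function that launches the kernels above.
5. `flash_attention_forward` — public entry point wrapping `_attention.apply`.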
"""
import math
import torch
import triton
import triton.language as tl
# ---------------------------------------------------------------------------
# Forward kernel
# ---------------------------------------------------------------------------
@triton.jit
def _flash_attn_fwd_kernel(
    # --- pointers ---
    Q_ptr, K_ptr, V_ptr,  # [B, H, N, d_head]
    O_ptr,  # [B, H, N, d_head] output
    L_ptr,  # [B, H, N] log-sum-exp (for bwd), indexed as a contiguous buffer
    # --- element strides of each 4-D tensor (layouts may differ) ---
    stride_qb, stride_qh, stride_qn, stride_qd,
    stride_kb, stride_kh, stride_kn, stride_kd,
    stride_vb, stride_vh, stride_vn, stride_vd,
    stride_ob, stride_oh, stride_on, stride_od,
    # --- problem dims ---
    N: tl.constexpr,  # sequence length
    d_head: tl.constexpr,  # head dimension (tl.arange requires a power of two)
    H: tl.constexpr,  # number of heads
    CAUSAL: tl.constexpr,  # whether to apply causal mask
    scale,  # 1 / sqrt(d_head)
    # --- tile sizes ---
    BLOCK_M: tl.constexpr,  # rows of Q processed per program
    BLOCK_N: tl.constexpr,  # cols of K/V tile
):
    """
    Flash Attention forward pass for one (row-tile, batch*head) pair.

    Grid axis 0 tiles the N query rows in chunks of BLOCK_M; grid axis 1
    enumerates the B*H (batch, head) pairs. For its BLOCK_M query rows the
    program streams over K/V in BLOCK_N-column tiles, keeping the online-softmax
    state (running row max m, normaliser l, unnormalised output o), and finally
    writes
        O   = o / l          (cast to fp16)
        LSE = m + log(l)     (fp32, consumed by the backward pass)
    Query rows and key columns at index >= N are masked out, so N does not need
    to be a multiple of BLOCK_M or BLOCK_N.
    """
    # -----------------------------------------------------------------------
    # Identify this program's slice
    # -----------------------------------------------------------------------
    start_m = tl.program_id(0)  # tile index along N (row dimension)
    off_bh = tl.program_id(1)  # flat batch×head index
    off_b = off_bh // H
    off_h = off_bh % H
    # Row offsets for this program
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)  # [BLOCK_M]
    offs_d = tl.arange(0, d_head)  # [d_head]
    # Base pointer for Q rows assigned to this program
    Q_blk_ptr = (
        Q_ptr
        + off_b * stride_qb
        + off_h * stride_qh
        + offs_m[:, None] * stride_qn  # [BLOCK_M, 1]
        + offs_d[None, :] * stride_qd  # [1, d_head]
    )
    # -----------------------------------------------------------------------
    # Load Q tile [BLOCK_M, d_head]; rows past N read as zeros
    # -----------------------------------------------------------------------
    q = tl.load(Q_blk_ptr, mask=offs_m[:, None] < N, other=0.0)
    # -----------------------------------------------------------------------
    # Running accumulators (online softmax)
    # -----------------------------------------------------------------------
    m_i = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32)  # row max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # normaliser
    o_i = tl.zeros([BLOCK_M, d_head], dtype=tl.float32)  # output acc
    # -----------------------------------------------------------------------
    # Iterate over K/V tiles
    # -----------------------------------------------------------------------
    for start_n in range(0, N, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)  # [BLOCK_N]
        # Load K tile [d_head, BLOCK_N] (transposed for matmul)
        K_blk_ptr = (
            K_ptr
            + off_b * stride_kb
            + off_h * stride_kh
            + offs_n[None, :] * stride_kn
            + offs_d[:, None] * stride_kd
        )
        k = tl.load(K_blk_ptr, mask=offs_n[None, :] < N, other=0.0)
        # Load V tile [BLOCK_N, d_head]
        V_blk_ptr = (
            V_ptr
            + off_b * stride_vb
            + off_h * stride_vh
            + offs_n[:, None] * stride_vn
            + offs_d[None, :] * stride_vd
        )
        v = tl.load(V_blk_ptr, mask=offs_n[:, None] < N, other=0.0)
        # --- QKᵀ scaled dot-product [BLOCK_M, BLOCK_N] ---
        s = tl.dot(q, k) * scale
        # Key columns past N were loaded as zeros, which yields s == 0 (not
        # -inf) and would leak exp(-m) mass into the normaliser; drop them.
        s = tl.where(offs_n[None, :] < N, s, float("-inf"))
        if CAUSAL:
            s = tl.where(offs_m[:, None] >= offs_n[None, :], s, float("-inf"))
        # --- online-softmax update ---
        # Column 0 of the first tile is always kept (0 < N, and under CAUSAL
        # 0 <= every row index), so m_new is finite from the first iteration
        # onwards and alpha never evaluates exp(-inf - (-inf)).
        m_ij = tl.max(s, axis=1)  # [BLOCK_M]
        m_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_new)  # rescale old acc
        p = tl.exp(s - m_new[:, None])  # [BLOCK_M, BLOCK_N]
        l_i = alpha * l_i + tl.sum(p, axis=1)  # [BLOCK_M]
        o_i = alpha[:, None] * o_i + tl.dot(p.to(tl.float16), v)
        m_i = m_new
    # -----------------------------------------------------------------------
    # Normalise and write output
    # -----------------------------------------------------------------------
    o_i = o_i / l_i[:, None]
    O_blk_ptr = (
        O_ptr
        + off_b * stride_ob
        + off_h * stride_oh
        + offs_m[:, None] * stride_on
        + offs_d[None, :] * stride_od
    )
    tl.store(O_blk_ptr, o_i.to(tl.float16), mask=offs_m[:, None] < N)
    # Store log-sum-exp for backward pass
    L_blk_ptr = L_ptr + off_b * H * N + off_h * N + offs_m
    lse = m_i + tl.log(l_i)
    tl.store(L_blk_ptr, lse, mask=offs_m < N)
# ---------------------------------------------------------------------------
# Backward kernels
# ---------------------------------------------------------------------------
@triton.jit
def _attn_bwd_preprocess(
    # pointers
    O_ptr, DO_ptr, D_ptr,
    # strides (shared by O and dO)
    stride_ob, stride_oh, stride_on, stride_od,
    # dims
    BLOCK_M: tl.constexpr, d_head: tl.constexpr, H: tl.constexpr, N: tl.constexpr,
):
    """
    Backward pre-pass: D[b, h, i] = sum_j O[b, h, i, j] * dO[b, h, i, j].

    One program per (row-tile of BLOCK_M rows, batch*head pair). dO is
    addressed with O's strides, so the caller must pass dO in the same memory
    layout as O. D is indexed as a contiguous [B, H, N] buffer. Rows at index
    >= N are masked on load and store, so N need not be a multiple of BLOCK_M.
    """
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_bh = tl.program_id(1)
    off_b = off_bh // H
    off_h = off_bh % H
    off_d = tl.arange(0, d_head)
    row_valid = off_m < N  # [BLOCK_M]
    # Load O and dO (padded rows read as zero and contribute nothing)
    o_ptrs = O_ptr + off_b * stride_ob + off_h * stride_oh + off_m[:, None] * stride_on + off_d[None, :] * stride_od
    do_ptrs = DO_ptr + off_b * stride_ob + off_h * stride_oh + off_m[:, None] * stride_on + off_d[None, :] * stride_od
    o = tl.load(o_ptrs, mask=row_valid[:, None], other=0.0).to(tl.float32)
    do = tl.load(do_ptrs, mask=row_valid[:, None], other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)  # [BLOCK_M]
    d_ptrs = D_ptr + off_b * H * N + off_h * N + off_m
    tl.store(d_ptrs, delta, mask=row_valid)
@triton.jit
def _flash_attn_bwd_kernel(
    # pointers
    Q_ptr, K_ptr, V_ptr, O_ptr, DO_ptr,
    dQ_ptr, dK_ptr, dV_ptr,
    L_ptr, D_ptr,
    # strides
    stride_qb, stride_qh, stride_qn, stride_qd,
    stride_kb, stride_kh, stride_kn, stride_kd,
    stride_vb, stride_vh, stride_vn, stride_vd,
    stride_ob, stride_oh, stride_on, stride_od,
    # dims
    N: tl.constexpr,
    d_head: tl.constexpr,
    H: tl.constexpr,
    CAUSAL: tl.constexpr,
    scale,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Flash Attention backward pass for one (key-tile, batch*head) pair.

    Each program owns BLOCK_N keys/values and loops over every BLOCK_M query
    tile. It recomputes P = exp(scale * Q Kᵀ - LSE) from the forward
    log-sum-exp in L, reads D from the pre-pass, and accumulates
        dV += Pᵀ dO
        dS  = P ⊙ (dO Vᵀ - D) * scale
        dK += dSᵀ Q
        dQ += dS K      (atomic add: every key tile touches each query row,
                         so dQ must be zero-initialised by the caller)
    dQ/dK/dV are addressed with the Q/K/V strides and dO with the O strides,
    so each gradient buffer must share its primal tensor's layout. O_ptr is
    accepted but not read. Rows/columns at index >= N are masked on every
    load, store and atomic add, so N need not be a multiple of the tile sizes.
    """
    start_n = tl.program_id(0)
    off_bh = tl.program_id(1)
    off_b = off_bh // H
    off_h = off_bh % H
    # Tile indices
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, d_head)
    col_valid = offs_n < N  # [BLOCK_N]
    # Pointers to K and V tiles [BLOCK_N, d_head]
    k_ptrs = K_ptr + off_b * stride_kb + off_h * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
    v_ptrs = V_ptr + off_b * stride_vb + off_h * stride_vh + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
    # Load K and V (padded key rows read as zero)
    k = tl.load(k_ptrs, mask=col_valid[:, None], other=0.0)
    v = tl.load(v_ptrs, mask=col_valid[:, None], other=0.0)
    # Accumulators for dK and dV
    dk = tl.zeros([BLOCK_N, d_head], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, d_head], dtype=tl.float32)
    # Loop over Q
    for start_m in range(0, N, BLOCK_M):
        offs_m = start_m + tl.arange(0, BLOCK_M)
        row_valid = offs_m < N  # [BLOCK_M]
        # Load Q and dO
        q_ptrs = Q_ptr + off_b * stride_qb + off_h * stride_qh + offs_m[:, None] * stride_qn + offs_d[None, :] * stride_qd
        do_ptrs = DO_ptr + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_on + offs_d[None, :] * stride_od
        q = tl.load(q_ptrs, mask=row_valid[:, None], other=0.0)
        do = tl.load(do_ptrs, mask=row_valid[:, None], other=0.0)
        # Load L and D
        l_ptrs = L_ptr + off_b * H * N + off_h * N + offs_m
        d_ptrs = D_ptr + off_b * H * N + off_h * N + offs_m
        l_i = tl.load(l_ptrs, mask=row_valid, other=0.0)
        d_i = tl.load(d_ptrs, mask=row_valid, other=0.0)
        # Recompute P = softmax(QK^T)
        qk = tl.dot(q, tl.trans(k)) * scale  # [BLOCK_M, BLOCK_N]
        if CAUSAL:
            qk = tl.where(offs_m[:, None] >= offs_n[None, :], qk, float("-inf"))
        p = tl.exp(qk - l_i[:, None])  # [BLOCK_M, BLOCK_N]
        # Padded rows/columns were loaded as zeros (qk == 0 there, not -inf),
        # so zero their probabilities explicitly.
        p = tl.where(row_valid[:, None] & col_valid[None, :], p, 0.0)
        # compute dv = P^T @ do
        dv += tl.dot(tl.trans(p.to(tl.float16)), do)  # [BLOCK_N, d_head]
        # compute dp = do @ v.T
        dp = tl.dot(do, tl.trans(v)).to(tl.float32)  # [BLOCK_M, BLOCK_N]
        # compute ds = P * (dp - D), times scale for the gradient w.r.t. raw QK^T
        ds = p * (dp - d_i[:, None]) * scale  # [BLOCK_M, BLOCK_N]
        # compute dk = ds^T @ q
        dk += tl.dot(tl.trans(ds.to(tl.float16)), q)  # [BLOCK_N, d_head]
        # compute dq = ds @ k
        dq = tl.dot(ds.to(tl.float16), k)  # [BLOCK_M, d_head]
        # Write back dQ using atomic add to avoid race conditions across K blocks
        dq_ptrs = dQ_ptr + off_b * stride_qb + off_h * stride_qh + offs_m[:, None] * stride_qn + offs_d[None, :] * stride_qd
        tl.atomic_add(dq_ptrs, dq, mask=row_valid[:, None])
    # Write back dK and dV (only rows belonging to real keys)
    dk_ptrs = dK_ptr + off_b * stride_kb + off_h * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
    dv_ptrs = dV_ptr + off_b * stride_vb + off_h * stride_vh + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
    tl.store(dk_ptrs, dk.to(tl.float16), mask=col_valid[:, None])
    tl.store(dv_ptrs, dv.to(tl.float16), mask=col_valid[:, None])
# ---------------------------------------------------------------------------
# Python launcher & Autograd Function
# ---------------------------------------------------------------------------
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, block_m=64, block_n=64):
assert q.device.type in ["cuda", "hip", "xpu"] and k.device.type in ["cuda", "hip", "xpu"] and v.device.type in ["cuda", "hip", "xpu"], "Tensors must be on a supported GPU (CUDA, HIP, XPU)"
assert q.dtype == torch.float16, "Only fp16 is currently supported"
assert q.shape == k.shape == v.shape, "Q, K, V must have identical shapes"
B, H, N, d = q.shape
scale = 1.0 / math.sqrt(d)
o = torch.empty_like(q)
l = torch.empty(B, H, N, dtype=torch.float32, device=q.device) # lse buffer
grid = (triton.cdiv(N, block_m), B * H)
_flash_attn_fwd_kernel[grid](
q, k, v, o, l,
# Q strides
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
# K strides
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
# V strides
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
# O strides
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
N=N,
d_head=d,
H=H,
CAUSAL=causal,
scale=scale,
BLOCK_M=block_m,
BLOCK_N=block_n,
)
ctx.save_for_backward(q, k, v, o, l)
ctx.causal = causal
ctx.scale = scale
ctx.block_m = block_m
ctx.block_n = block_n
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, l = ctx.saved_tensors
B, H, N, d = q.shape
# dQ requires atomic adds, so it must be initialized to zero
dq = torch.zeros_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
D = torch.empty(B, H, N, dtype=torch.float32, device=q.device)
blkm = ctx.block_m
blkn = ctx.block_n
# 1. Preprocess: compute D = sum(o * do, axis=-1)
grid_pre = (triton.cdiv(N, blkm), B * H)
_attn_bwd_preprocess[grid_pre](
o, do, D,
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
BLOCK_M=blkm, d_head=d, H=H, N=N
)
# 2. Main backward kernel
grid_bwd = (triton.cdiv(N, blkn), B * H)
_flash_attn_bwd_kernel[grid_bwd](
q, k, v, o, do,
dq, dk, dv,
l, D,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
N=N, d_head=d, H=H, CAUSAL=ctx.causal, scale=ctx.scale,
BLOCK_M=blkm, BLOCK_N=blkn,
)
return dq, dk, dv, None, None, None
def flash_attention_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
    block_m: int = 64,
    block_n: int = 64,
) -> torch.Tensor:
    """
    Run Flash Attention 1 on (q, k, v) through the `_attention` autograd Function.

    The result participates in autograd: gradients with respect to q, k and v
    are produced by `_attention.backward`.

    Args:
        q: Query tensor of shape [batch, heads, seq_len, head_dim].
        k: Key tensor, same shape as ``q``.
        v: Value tensor, same shape as ``q``.
        causal: If True, each query position attends only to keys at or
            before it.
        block_m: Query-tile size (rows handled per kernel program).
        block_n: Key/value-tile size (columns per inner-loop step).

    Returns:
        Attention output of shape [batch, heads, seq_len, head_dim].
    """
    # Arguments are passed positionally, in the order of _attention.forward.
    launch_args = (q, k, v, causal, block_m, block_n)
    return _attention.apply(*launch_args)