"""
Flash Attention 1 — Triton Kernel
==================================

Pure Python/Triton implementation of Flash Attention (Dao et al., 2022).
Reference: https://arxiv.org/abs/2205.14135

Algorithm overview
------------------
Flash Attention fuses the softmax(Q Kᵀ / √d) · V computation into a single
GPU kernel pass using online softmax (Milakov & Gimelshein, 2018) so that the
full N×N attention matrix is never materialised in HBM.

This file contains:
  1. `_flash_attn_fwd_kernel`   forward pass Triton kernel
  2. `_attn_bwd_preprocess`     backward pre-pass computing D = rowsum(O * dO)
  3. `_flash_attn_bwd_kernel`   backward pass Triton kernel (dQ, dK, dV)
  4. `_attention`               torch.autograd.Function launching the kernels
  5. `flash_attention_forward`  public Python entry point
"""

import math

import torch
import triton
import triton.language as tl


# ---------------------------------------------------------------------------
# Forward kernel
# ---------------------------------------------------------------------------

@triton.jit
def _flash_attn_fwd_kernel(
    # --- pointers ---
    Q_ptr, K_ptr, V_ptr,        # [B, H, N, d_head]
    O_ptr,                      # [B, H, N, d_head]  output
    L_ptr,                      # [B, H, N]          log-sum-exp (for bwd)
    # --- strides, in elements, one set per tensor ---
    stride_qb, stride_qh, stride_qn, stride_qd,
    stride_kb, stride_kh, stride_kn, stride_kd,
    stride_vb, stride_vh, stride_vn, stride_vd,
    stride_ob, stride_oh, stride_on, stride_od,
    # --- problem dims ---
    N: tl.constexpr,            # sequence length
    d_head: tl.constexpr,       # head dimension
    H: tl.constexpr,            # number of heads
    CAUSAL: tl.constexpr,       # whether to apply causal mask
    scale,                      # 1 / sqrt(d_head)
    # --- tile sizes ---
    BLOCK_M: tl.constexpr,      # rows of Q processed per program
    BLOCK_N: tl.constexpr,      # cols of K/V tile
):
    """
    Forward pass: O = softmax(scale * Q Kᵀ) V plus the per-row log-sum-exp.

    Each program handles one (batch, head, tile-of-rows) triple. The inner
    loop tiles over K/V columns and accumulates running O, m, l using the
    online-softmax update from the Flash Attention paper.

    L_ptr is indexed as a dense [B*H, N] float32 buffer. Rows and key columns
    beyond N are masked, so N need not be a multiple of BLOCK_M or BLOCK_N.
    """
    # -----------------------------------------------------------------------
    # Identify this program's slice
    # -----------------------------------------------------------------------
    start_m = tl.program_id(0)                  # tile index along N (rows)
    # 64-bit batch/head index so base offsets cannot overflow int32.
    off_bh = tl.program_id(1).to(tl.int64)      # flat batch×head index
    off_b = off_bh // H
    off_h = off_bh % H

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)   # [BLOCK_M]
    offs_d = tl.arange(0, d_head)                        # [d_head]
    row_ok = offs_m < N                                  # rows inside the sequence

    # -----------------------------------------------------------------------
    # Load Q tile  [BLOCK_M, d_head]
    # -----------------------------------------------------------------------
    q_ptrs = (
        Q_ptr
        + off_b * stride_qb
        + off_h * stride_qh
        + offs_m[:, None] * stride_qn
        + offs_d[None, :] * stride_qd
    )
    q = tl.load(q_ptrs, mask=row_ok[:, None], other=0.0)

    # -----------------------------------------------------------------------
    # Running accumulators (online softmax)
    # -----------------------------------------------------------------------
    m_i = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32)  # row max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)                      # normaliser
    o_i = tl.zeros([BLOCK_M, d_head], dtype=tl.float32)              # output acc

    k_base = K_ptr + off_b * stride_kb + off_h * stride_kh
    v_base = V_ptr + off_b * stride_vb + off_h * stride_vh

    # -----------------------------------------------------------------------
    # Iterate over K/V tiles
    # -----------------------------------------------------------------------
    for start_n in range(0, N, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)         # [BLOCK_N]
        col_ok = offs_n < N                              # keys inside the sequence

        # K tile read as [d_head, BLOCK_N] (transposed via the pointer layout)
        k_ptrs = k_base + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
        k = tl.load(k_ptrs, mask=col_ok[None, :], other=0.0)

        # V tile [BLOCK_N, d_head]
        v_ptrs = v_base + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
        v = tl.load(v_ptrs, mask=col_ok[:, None], other=0.0)

        # --- QKᵀ scaled dot-product  [BLOCK_M, BLOCK_N] ---
        s = tl.dot(q, k) * scale

        # Padded key columns were loaded as zeros; without this they would
        # enter the softmax with score 0 instead of being excluded.
        s = tl.where(col_ok[None, :], s, float("-inf"))
        if CAUSAL:
            s = tl.where(offs_m[:, None] >= offs_n[None, :], s, float("-inf"))

        # --- online-softmax update ---
        m_ij = tl.max(s, axis=1)                         # [BLOCK_M]
        m_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_new)                      # rescale old acc
        p = tl.exp(s - m_new[:, None])                   # [BLOCK_M, BLOCK_N]

        l_i = alpha * l_i + tl.sum(p, axis=1)            # [BLOCK_M]
        o_i = alpha[:, None] * o_i + tl.dot(p.to(tl.float16), v)
        m_i = m_new

    # -----------------------------------------------------------------------
    # Normalise and write output
    # -----------------------------------------------------------------------
    o_i = o_i / l_i[:, None]

    o_ptrs = (
        O_ptr
        + off_b * stride_ob
        + off_h * stride_oh
        + offs_m[:, None] * stride_on
        + offs_d[None, :] * stride_od
    )
    tl.store(o_ptrs, o_i.to(tl.float16), mask=row_ok[:, None])

    # Store log-sum-exp for the backward pass
    lse = m_i + tl.log(l_i)
    tl.store(L_ptr + off_bh * N + offs_m, lse, mask=row_ok)


# ---------------------------------------------------------------------------
# Backward kernels
# ---------------------------------------------------------------------------

@triton.jit
def _attn_bwd_preprocess(
    # pointers
    O_ptr, DO_ptr, D_ptr,
    # strides of O; DO must use the same layout (the launcher guarantees it)
    stride_ob, stride_oh, stride_on, stride_od,
    # dims
    BLOCK_M: tl.constexpr, d_head: tl.constexpr,
    H: tl.constexpr, N: tl.constexpr,
):
    """
    Precomputes D = sum(O * dO, axis=-1) for the backward pass.

    Each program handles one (batch, head, tile-of-rows) triple and writes
    BLOCK_M entries of D, which is indexed as a dense [B*H, N] buffer.
    Rows beyond N are masked on both load and store.
    """
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_bh = tl.program_id(1).to(tl.int64)
    off_b = off_bh // H
    off_h = off_bh % H
    offs_d = tl.arange(0, d_head)
    row_ok = offs_m < N

    # One offset tile serves both O and dO because they share a layout.
    tile_off = (
        off_b * stride_ob
        + off_h * stride_oh
        + offs_m[:, None] * stride_on
        + offs_d[None, :] * stride_od
    )
    o = tl.load(O_ptr + tile_off, mask=row_ok[:, None], other=0.0).to(tl.float32)
    do = tl.load(DO_ptr + tile_off, mask=row_ok[:, None], other=0.0).to(tl.float32)

    delta = tl.sum(o * do, axis=1)  # [BLOCK_M]
    tl.store(D_ptr + off_bh * N + offs_m, delta, mask=row_ok)


@triton.jit
def _flash_attn_bwd_kernel(
    # pointers
    Q_ptr, K_ptr, V_ptr, O_ptr, DO_ptr,
    dQ_ptr, dK_ptr, dV_ptr,
    L_ptr, D_ptr,
    # strides
    stride_qb, stride_qh, stride_qn, stride_qd,
    stride_kb, stride_kh, stride_kn, stride_kd,
    stride_vb, stride_vh, stride_vn, stride_vd,
    # O strides; also used for DO, dQ, dK and dV, which the launcher
    # allocates with the same (contiguous) layout as O
    stride_ob, stride_oh, stride_on, stride_od,
    # dims
    N: tl.constexpr,
    d_head: tl.constexpr,
    H: tl.constexpr,
    CAUSAL: tl.constexpr,
    scale,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Backward pass kernel.

    Parallelises over key/value tiles (BLOCK_N) and batch×head. Each program
    loads one K/V tile, loops over every query tile, recomputes P from the
    saved log-sum-exp, accumulates dK and dV for its tile and atomically adds
    its partial dQ contribution.

    dQ_ptr must point to a zero-initialised float32 buffer; dK_ptr and dV_ptr
    receive float16 values. O_ptr is accepted but not read here. All accesses
    are masked, so N need not be a multiple of BLOCK_M or BLOCK_N.
    """
    start_n = tl.program_id(0)
    off_bh = tl.program_id(1).to(tl.int64)
    off_b = off_bh // H
    off_h = off_bh % H

    # Tile indices
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, d_head)
    # Always-true column predicate, used only to give masks the full tile shape.
    d_ok = offs_d[None, :] < d_head
    kv_mask = (offs_n[:, None] < N) & d_ok               # [BLOCK_N, d_head]

    q_base = Q_ptr + off_b * stride_qb + off_h * stride_qh
    # Offset shared by every tensor that uses the O layout.
    g_off = off_b * stride_ob + off_h * stride_oh
    stat_off = off_bh * N                                # into L and D

    # Load K and V tiles [BLOCK_N, d_head]
    k_ptrs = (
        K_ptr + off_b * stride_kb + off_h * stride_kh
        + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
    )
    v_ptrs = (
        V_ptr + off_b * stride_vb + off_h * stride_vh
        + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
    )
    k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
    v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

    # Accumulators for dK and dV
    dk = tl.zeros([BLOCK_N, d_head], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, d_head], dtype=tl.float32)

    # Loop over Q tiles
    for start_m in range(0, N, BLOCK_M):
        offs_m = start_m + tl.arange(0, BLOCK_M)
        row_ok = offs_m < N
        q_mask = row_ok[:, None] & d_ok                  # [BLOCK_M, d_head]

        # Load Q and dO
        q_ptrs = q_base + offs_m[:, None] * stride_qn + offs_d[None, :] * stride_qd
        g_tile = g_off + offs_m[:, None] * stride_on + offs_d[None, :] * stride_od
        q = tl.load(q_ptrs, mask=q_mask, other=0.0)
        do = tl.load(DO_ptr + g_tile, mask=q_mask, other=0.0)

        # Load per-row log-sum-exp (L) and D
        lse = tl.load(L_ptr + stat_off + offs_m, mask=row_ok, other=0.0)
        delta = tl.load(D_ptr + stat_off + offs_m, mask=row_ok, other=0.0)

        # Recompute P = softmax(QK^T); masked entries become exp(-inf) = 0
        qk = tl.dot(q, tl.trans(k)) * scale              # [BLOCK_M, BLOCK_N]
        keep = row_ok[:, None] & (offs_n[None, :] < N)
        if CAUSAL:
            keep = keep & (offs_m[:, None] >= offs_n[None, :])
        qk = tl.where(keep, qk, float("-inf"))
        p = tl.exp(qk - lse[:, None])                    # [BLOCK_M, BLOCK_N]

        # dv += P^T @ do
        dv += tl.dot(tl.trans(p.to(tl.float16)), do)     # [BLOCK_N, d_head]

        # dp = do @ v^T
        dp = tl.dot(do, tl.trans(v)).to(tl.float32)      # [BLOCK_M, BLOCK_N]

        # ds = P * (dp - D), with the softmax scale folded in
        ds = p * (dp - delta[:, None]) * scale           # [BLOCK_M, BLOCK_N]

        # dk += ds^T @ q
        dk += tl.dot(tl.trans(ds.to(tl.float16)), q)     # [BLOCK_N, d_head]

        # dq contribution of this K tile: ds @ k
        dq = tl.dot(ds.to(tl.float16), k)                # [BLOCK_M, d_head]

        # Several K-tile programs update the same dQ rows, hence the atomic add.
        tl.atomic_add(dQ_ptr + g_tile, dq, mask=q_mask)

    # Write back dK and dV
    kv_tile = g_off + offs_n[:, None] * stride_on + offs_d[None, :] * stride_od
    tl.store(dK_ptr + kv_tile, dk.to(tl.float16), mask=kv_mask)
    tl.store(dV_ptr + kv_tile, dv.to(tl.float16), mask=kv_mask)


# ---------------------------------------------------------------------------
# Python launcher & Autograd Function
# ---------------------------------------------------------------------------

_SUPPORTED_DEVICE_TYPES = ("cuda", "hip", "xpu")


def _check(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is true.

    Used instead of a bare ``assert`` so validation is not stripped under
    ``python -O`` while the exception type stays what callers already see.
    """
    if not condition:
        raise AssertionError(message)


def _is_power_of_two(value):
    """Return True if *value* is a positive power of two."""
    return value > 0 and (value & (value - 1)) == 0


class _attention(torch.autograd.Function):
    """Autograd wrapper that launches the Flash Attention Triton kernels."""

    @staticmethod
    def forward(ctx, q, k, v, causal, block_m=64, block_n=64):
        """
        Validate the inputs, launch the forward kernel and stash what the
        backward pass needs on *ctx*.

        Args:
            q, k, v: fp16 tensors of identical shape [B, H, N, d].
            causal:  whether to apply a causal mask.
            block_m: query tile size (power of two).
            block_n: key/value tile size (power of two).

        Returns:
            Contiguous fp16 output tensor of shape [B, H, N, d].

        Raises:
            AssertionError: if an input fails validation.
        """
        _check(
            all(t.device.type in _SUPPORTED_DEVICE_TYPES for t in (q, k, v)),
            "Tensors must be on a supported GPU (CUDA, HIP, XPU)",
        )
        _check(q.device == k.device == v.device, "Q, K, V must be on the same device")
        _check(
            q.dtype == torch.float16 and k.dtype == torch.float16 and v.dtype == torch.float16,
            "Only fp16 is currently supported",
        )
        _check(q.shape == k.shape == v.shape, "Q, K, V must have identical shapes")

        B, H, N, d = q.shape
        # tl.arange requires power-of-two extents.
        _check(_is_power_of_two(d), "head_dim must be a power of two")
        _check(
            _is_power_of_two(block_m) and _is_power_of_two(block_n),
            "block_m and block_n must be powers of two",
        )
        scale = 1.0 / math.sqrt(d)
        causal = bool(causal)

        # Allocated contiguous so that dO and the gradient buffers created in
        # backward() can share O's strides inside the backward kernels.
        o = torch.empty(q.shape, dtype=q.dtype, device=q.device)
        lse = torch.empty((B, H, N), dtype=torch.float32, device=q.device)

        # A zero-sized grid cannot be launched; an empty input has no work.
        if q.numel() > 0:
            grid = (triton.cdiv(N, block_m), B * H)
            _flash_attn_fwd_kernel[grid](
                q, k, v, o, lse,
                # Q strides
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                # K strides
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                # V strides
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                # O strides
                o.stride(0), o.stride(1), o.stride(2), o.stride(3),
                N=N, d_head=d, H=H,
                CAUSAL=causal,
                scale=scale,
                BLOCK_M=block_m,
                BLOCK_N=block_n,
            )

        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        ctx.scale = scale
        ctx.block_m = block_m
        ctx.block_n = block_n
        return o

    @staticmethod
    def backward(ctx, do):
        """
        Compute gradients with respect to q, k and v.

        Args:
            do: gradient of the loss with respect to the forward output.

        Returns:
            ``(dq, dk, dv, None, None, None)`` — one entry per forward
            argument; the non-tensor arguments receive no gradient.
        """
        q, k, v, o, lse = ctx.saved_tensors
        B, H, N, d = q.shape

        # The kernels address dO, dQ, dK and dV with O's strides, so all of
        # them must be contiguous like O.
        do = do.contiguous()
        # dQ is built from atomic adds: start from zero and accumulate in
        # fp32 so partial sums are not rounded to fp16 at every step.
        dq = torch.zeros(q.shape, dtype=torch.float32, device=q.device)
        dk = torch.empty(k.shape, dtype=k.dtype, device=k.device)
        dv = torch.empty(v.shape, dtype=v.dtype, device=v.device)
        delta = torch.empty((B, H, N), dtype=torch.float32, device=q.device)

        blkm = ctx.block_m
        blkn = ctx.block_n

        if q.numel() > 0:
            o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))

            # 1. Preprocess: compute D = sum(o * do, axis=-1)
            grid_pre = (triton.cdiv(N, blkm), B * H)
            _attn_bwd_preprocess[grid_pre](
                o, do, delta,
                *o_strides,
                BLOCK_M=blkm, d_head=d, H=H, N=N,
            )

            # 2. Main backward kernel
            grid_bwd = (triton.cdiv(N, blkn), B * H)
            _flash_attn_bwd_kernel[grid_bwd](
                q, k, v, o, do,
                dq, dk, dv,
                lse, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                *o_strides,
                N=N, d_head=d, H=H,
                CAUSAL=ctx.causal,
                scale=ctx.scale,
                BLOCK_M=blkm,
                BLOCK_N=blkn,
            )

        return dq.to(q.dtype), dk, dv, None, None, None


def flash_attention_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
    block_m: int = 64,
    block_n: int = 64,
) -> torch.Tensor:
    """
    Forward pass of Flash Attention 1 (differentiable via `_attention`).

    Args:
        q: Query tensor [batch, heads, seq_len, head_dim]
        k: Key tensor [batch, heads, seq_len, head_dim]
        v: Value tensor [batch, heads, seq_len, head_dim]
        causal: Whether to apply a causal mask.
        block_m: Tile size along the query (row) dimension.
        block_n: Tile size along the key/value (column) dimension.

    Returns:
        Output tensor [batch, heads, seq_len, head_dim]

    Raises:
        AssertionError: if the tensors are not fp16, not on a supported GPU,
            differ in shape or device, or if head_dim / block sizes are not
            powers of two.
    """
    return _attention.apply(q, k, v, causal, block_m, block_n)