#!/usr/bin/env python3
"""
Triton-fused Chunked Sparse Backward Pass.

Replaces the Python for-loop over active chunks with fused Triton kernels:
1. sparse_bwd_dW: grad_W[c*CS:(c+1)*CS, :] = grad_Y[:, c*CS:(c+1)*CS].T @ X  for active c
2. sparse_bwd_dX: grad_X += grad_Y[:, c*CS:(c+1)*CS] @ W[c*CS:(c+1)*CS, :]   for active c
3. sparse_fwd:    Y[:, c*CS:(c+1)*CS] = X @ W[c*CS:(c+1)*CS, :].T            for active c

Benchmark against the Python-loop baseline at various d_model sizes.
"""
import math
import os
import random
import time
import urllib.request
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl

try:
    import tiktoken
except ImportError:
    raise ImportError("pip install tiktoken")


# ═══════════════════════════════════════════════════════════════════
# TRITON KERNELS
# ═══════════════════════════════════════════════════════════════════

# ── Kernel 1: Sparse dW ──────────────────────────────────────────
# For each active chunk c:
#   grad_W[c*CS:(c+1)*CS, :] = grad_Y[:, c*CS:(c+1)*CS].T @ X
#
# In terms of shapes:
#   grad_Y: (M, d_out), X: (M, d_in), W: (d_out, d_in)
#   For chunk c: rows c*CS..(c+1)*CS of W get grad from cols c*CS..(c+1)*CS of grad_Y
#
# Grid: (num_active * ceil(CS/BN), ceil(d_in/BK))
#   pid0 encodes (active_chunk_linear_id, N-block within CS)
#   pid1 encodes K-block within d_in
@triton.autotune(
    configs=[
        triton.Config({'BN': 32, 'BK': 64, 'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 64, 'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 128, 'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 32, 'BK': 128, 'BM': 64}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 64, 'BM': 64}, num_stages=4, num_warps=4),
    ],
    key=['M', 'd_in', 'CS'],
)
@triton.jit
def _sparse_bwd_dW_kernel(
    X_ptr, dY_ptr, dW_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_xm, stride_xk,
    stride_dym, stride_dyn,
    stride_dwn, stride_dwk,
    CS: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr, BM: tl.constexpr,
):
    """Compute dW tiles for active chunks. Each program writes one [BN, BK] tile."""
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    # pid0 linearizes (active chunk, N-block inside the chunk); pid1 is the K-block.
    N_BLOCKS_PER_CHUNK = tl.cdiv(CS, BN)
    chunk_linear_id = pid0 // N_BLOCKS_PER_CHUNK
    n_block_id = pid0 % N_BLOCKS_PER_CHUNK
    k_block_id = pid1
    if chunk_linear_id >= num_active:
        return
    # Map linear id -> actual chunk index via the active-chunk index list.
    chunk_idx = tl.load(chunk_ids_ptr + chunk_linear_id)
    chunk_start = chunk_idx * CS

    # Tile ranges
    rn = n_block_id * BN + tl.arange(0, BN)  # rows of dW (= cols of chunk in dY)
    rk = k_block_id * BK + tl.arange(0, BK)  # cols of dW (= cols of X)
    n_abs = chunk_start + rn  # absolute column indices in dY
    n_mask = rn < CS
    k_mask = rk < d_in

    # Accumulate dY[:, chunk_cols].T @ X[:, k_cols] over M-tiles (fp32 accumulator).
    acc = tl.zeros((BN, BK), dtype=tl.float32)
    for m_start in range(0, M, BM):
        rm = m_start + tl.arange(0, BM)
        m_mask = rm < M
        # Load X tile: (BM, BK)
        x = tl.load(
            X_ptr + rm[:, None] * stride_xm + rk[None, :] * stride_xk,
            mask=m_mask[:, None] & k_mask[None, :],
            other=0.0,
        )
        # Load dY tile: (BM, BN)
        dy = tl.load(
            dY_ptr + rm[:, None] * stride_dym + n_abs[None, :] * stride_dyn,
            mask=m_mask[:, None] & n_mask[None, :],
            other=0.0,
        )
        # dY.T @ X -> (BN, BK)
        acc = tl.dot(tl.trans(dy), x, acc=acc)

    # Write to dW: row = chunk_start + rn, col = rk
    # dW layout: (d_out, d_in)
    dw_ptrs = dW_ptr + n_abs[:, None] * stride_dwn + rk[None, :] * stride_dwk
    tl.store(dw_ptrs, acc.to(dW_ptr.dtype.element_ty), mask=n_mask[:, None] & k_mask[None, :])


def sparse_bwd_dW(X, dY, active_chunks, chunk_size, d_out):
    """Fused Triton kernel for sparse dW computation.

    X: (M, d_in), dY: (M, d_out), active_chunks: 1-D tensor of chunk indices.
    Returns dW of shape (d_out, d_in), zero outside the active chunks.
    """
    M, d_in = X.shape
    num_active = active_chunks.shape[0]
    CS = chunk_size
    # Inactive rows stay zero; only active-chunk rows are written by the kernel.
    dW = torch.zeros(d_out, d_in, device=X.device, dtype=X.dtype)
    if num_active == 0:
        return dW
    chunk_ids = active_chunks.to(torch.int32).contiguous()
    grid = lambda META: (
        num_active * triton.cdiv(CS, META['BN']),
        triton.cdiv(d_in, META['BK']),
    )
    _sparse_bwd_dW_kernel[grid](
        X, dY, dW, chunk_ids,
        M, d_in, d_out, num_active,
        X.stride(0), X.stride(1),
        dY.stride(0), dY.stride(1),
        dW.stride(0), dW.stride(1),
        CS=CS,
    )
    return dW


# ── Kernel 2: Sparse dX ──────────────────────────────────────────
# For each active chunk c:
#   grad_X += grad_Y[:, c*CS:(c+1)*CS] @ W[c*CS:(c+1)*CS, :]
#
# Grid: (ceil(M/BM), ceil(d_in/BK))
# Each program accumulates contributions from ALL active chunks.
@triton.autotune(
    configs=[
        triton.Config({'BM': 32, 'BK': 64, 'BN': 32}, num_stages=3, num_warps=4),
        triton.Config({'BM': 64, 'BK': 64, 'BN': 32}, num_stages=3, num_warps=4),
        triton.Config({'BM': 64, 'BK': 128, 'BN': 64}, num_stages=3, num_warps=4),
        triton.Config({'BM': 32, 'BK': 128, 'BN': 32}, num_stages=4, num_warps=4),
    ],
    key=['M', 'd_in', 'CS'],
)
@triton.jit
def _sparse_bwd_dX_kernel(
    dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_dym, stride_dyn,
    stride_wn, stride_wk,
    stride_dxm, stride_dxk,
    CS: tl.constexpr, BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr,
):
    """Compute dX tiles by summing over active chunks."""
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    rm = pid_m * BM + tl.arange(0, BM)
    rk = pid_k * BK + tl.arange(0, BK)
    m_mask = rm < M
    k_mask = rk < d_in
    acc = tl.zeros((BM, BK), dtype=tl.float32)
    # Sum over all active chunks
    for i in range(num_active):
        chunk_idx = tl.load(chunk_ids_ptr + i)
        chunk_start = chunk_idx * CS
        # Tile over BN within the chunk
        for n_start in range(0, CS, BN):
            rn = n_start + tl.arange(0, BN)
            n_abs = chunk_start + rn
            n_mask = rn < CS
            # Load dY tile: (BM, BN)
            dy = tl.load(
                dY_ptr + rm[:, None] * stride_dym + n_abs[None, :] * stride_dyn,
                mask=m_mask[:, None] & n_mask[None, :],
                other=0.0,
            )
            # Load W tile: (BN, BK) — W[chunk_start+rn, rk]
            w = tl.load(
                W_ptr + n_abs[:, None] * stride_wn + rk[None, :] * stride_wk,
                mask=n_mask[:, None] & k_mask[None, :],
                other=0.0,
            )
            # dY @ W -> (BM, BK)
            acc = tl.dot(dy, w, acc=acc)
    # Write dX
    dx_ptrs = dX_ptr + rm[:, None] * stride_dxm + rk[None, :] * stride_dxk
    tl.store(dx_ptrs, acc.to(dX_ptr.dtype.element_ty), mask=m_mask[:, None] & k_mask[None, :])


def sparse_bwd_dX(dY, W, active_chunks, chunk_size, M, d_in):
    """Fused Triton kernel for sparse dX computation.

    Returns dX of shape (M, d_in): the sum of grad_Y-chunk @ W-chunk over
    the active chunks only (zero when no chunk is active).
    """
    num_active = active_chunks.shape[0]
    CS = chunk_size
    dX = torch.zeros(M, d_in, device=dY.device, dtype=dY.dtype)
    if num_active == 0:
        return dX
    chunk_ids = active_chunks.to(torch.int32).contiguous()
    grid = lambda META: (
        triton.cdiv(M, META['BM']),
        triton.cdiv(d_in, META['BK']),
    )
    _sparse_bwd_dX_kernel[grid](
        dY, W, dX, chunk_ids,
        M, d_in, dY.shape[1], num_active,
        dY.stride(0), dY.stride(1),
        W.stride(0), W.stride(1),
        dX.stride(0), dX.stride(1),
        CS=CS,
    )
    return dX


# ── Kernel 3: Sparse dBias ──────────────────────────────────────
# Simple: bias_grad[c*CS:(c+1)*CS] = dY[:, c*CS:(c+1)*CS].sum(dim=0)
@triton.jit
def _sparse_bwd_dbias_kernel(
    dY_ptr, dB_ptr, chunk_ids_ptr,
    M, d_out, num_active,
    stride_dym, stride_dyn,
    CS: tl.constexpr, BM: tl.constexpr,
):
    # One program per (active_chunk, column-within-chunk): each reduces
    # a single dY column over the M rows.
    pid = tl.program_id(0)  # one per (active_chunk, col_within_chunk)
    chunk_linear = pid // CS
    col_in_chunk = pid % CS
    if chunk_linear >= num_active:
        return
    chunk_idx = tl.load(chunk_ids_ptr + chunk_linear)
    col_abs = chunk_idx * CS + col_in_chunk
    acc = 0.0  # scalar fp32 accumulator
    for m_start in range(0, M, BM):
        rm = m_start + tl.arange(0, BM)
        m_mask = rm < M
        vals = tl.load(dY_ptr + rm * stride_dym + col_abs * stride_dyn, mask=m_mask, other=0.0)
        acc += tl.sum(vals)
    tl.store(dB_ptr + col_abs, acc.to(dB_ptr.dtype.element_ty))


def sparse_bwd_dbias(dY, active_chunks, chunk_size, d_out):
    """Fused Triton kernel for the bias gradient: per-column sums of dY
    restricted to the active chunks. Returns a (d_out,) tensor."""
    M = dY.shape[0]
    num_active = active_chunks.shape[0]
    dB = torch.zeros(d_out, device=dY.device, dtype=dY.dtype)
    if num_active == 0:
        return dB
    chunk_ids = active_chunks.to(torch.int32).contiguous()
    BM = 128
    grid = (num_active * chunk_size,)
    _sparse_bwd_dbias_kernel[grid](
        dY, dB, chunk_ids,
        M, d_out, num_active,
        dY.stride(0), dY.stride(1),
        CS=chunk_size, BM=BM,
    )
    return dB


# ═══════════════════════════════════════════════════════════════════
# AUTOGRAD FUNCTION: Triton-fused
═══════════════════════════════════════════════════════════════════ class TritonChunkedSparseLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, bias, active_chunks, chunk_size, sparse_dx): ctx.save_for_backward(x, weight, active_chunks) ctx.has_bias = bias is not None ctx.sparse_dx = sparse_dx ctx.chunk_size = chunk_size return F.linear(x, weight, bias) @staticmethod def backward(ctx, grad_y): x, weight, active_chunks = ctx.saved_tensors cs = ctx.chunk_size d_out, d_in = weight.shape x_flat = x.reshape(-1, d_in) gy_flat = grad_y.reshape(-1, d_out) M = x_flat.shape[0] # grad_W via Triton grad_w = sparse_bwd_dW(x_flat, gy_flat, active_chunks, cs, d_out) # grad_bias via Triton grad_b = sparse_bwd_dbias(gy_flat, active_chunks, cs, d_out) if ctx.has_bias else None # grad_X if ctx.sparse_dx: grad_x_flat = sparse_bwd_dX(gy_flat, weight, active_chunks, cs, M, d_in) else: grad_x_flat = gy_flat @ weight # dense dX return grad_x_flat.reshape(x.shape), grad_w, grad_b, None, None, None # ═══════════════════════════════════════════════════════════════════ # AUTOGRAD FUNCTION: Python-loop baseline (for comparison) # ═══════════════════════════════════════════════════════════════════ class PythonLoopSparseLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, bias, active_chunks, chunk_size, sparse_dx): ctx.save_for_backward(x, weight, active_chunks) ctx.has_bias = bias is not None ctx.sparse_dx = sparse_dx ctx.chunk_size = chunk_size return F.linear(x, weight, bias) @staticmethod def backward(ctx, grad_y): x, weight, active_chunks = ctx.saved_tensors cs = ctx.chunk_size x_flat = x.reshape(-1, x.shape[-1]) gy_flat = grad_y.reshape(-1, grad_y.shape[-1]) grad_w = torch.zeros_like(weight) grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if ctx.has_bias else None if ctx.sparse_dx: grad_x_flat = torch.zeros_like(x_flat) else: grad_x_flat = gy_flat @ weight for c in active_chunks.tolist(): s, e = c * cs, (c + 1) 
* cs gy_slice = gy_flat[:, s:e] grad_w[s:e, :] = gy_slice.t() @ x_flat if ctx.has_bias: grad_b[s:e] = gy_slice.sum(0) if ctx.sparse_dx: grad_x_flat += gy_slice @ weight[s:e, :] return grad_x_flat.reshape(x.shape), grad_w, grad_b, None, None, None # ═══════════════════════════════════════════════════════════════════ # CORRECTNESS TEST # ═══════════════════════════════════════════════════════════════════ def test_correctness(): print("Testing correctness...") torch.manual_seed(42) device = "cuda" for d_in, d_out, cs in [(512, 2048, 64), (1024, 4096, 64), (256, 1024, 32)]: M = 2048 # B*T n_chunks = d_out // cs n_active = max(1, int(0.1 * n_chunks)) active = torch.randperm(n_chunks, device=device)[:n_active].sort().values x = torch.randn(M, d_in, device=device, requires_grad=False) w = torch.randn(d_out, d_in, device=device, requires_grad=False) b = torch.randn(d_out, device=device, requires_grad=False) gy = torch.randn(M, d_out, device=device, requires_grad=False) # Reference: Python loop ref_dw = torch.zeros_like(w) ref_db = torch.zeros_like(b) ref_dx = gy @ w # dense dX for c in active.tolist(): s, e = c * cs, (c + 1) * cs ref_dw[s:e] = gy[:, s:e].t() @ x ref_db[s:e] = gy[:, s:e].sum(0) # Triton tri_dw = sparse_bwd_dW(x, gy, active, cs, d_out) tri_db = sparse_bwd_dbias(gy, active, cs, d_out) tri_dx_sparse = sparse_bwd_dX(gy, w, active, cs, M, d_in) # Compare dw_err = (tri_dw - ref_dw).abs().max().item() db_err = (tri_db - ref_db).abs().max().item() # For sparse dX, reference ref_dx_sparse = torch.zeros_like(x) for c in active.tolist(): s, e = c * cs, (c + 1) * cs ref_dx_sparse += gy[:, s:e] @ w[s:e] dx_err = (tri_dx_sparse - ref_dx_sparse).abs().max().item() status = "✓" if dw_err < 1e-2 and db_err < 1e-2 and dx_err < 1e-2 else "✗" print(f" {status} d_in={d_in}, d_out={d_out}, cs={cs}: dW_err={dw_err:.6f}, dB_err={db_err:.6f}, dX_err={dx_err:.6f}") print() # ═══════════════════════════════════════════════════════════════════ # BENCHMARK # 
# ═══════════════════════════════════════════════════════════════════


def benchmark():
    """Time dense vs. Python-loop vs. Triton backward at several d_model sizes.

    Prints a table of mean per-iteration wall-clock times (CUDA-synchronized)
    and the Triton speedup ratios. Requires a CUDA device.
    """
    print("="*80)
    print("BENCHMARK: Triton Fused vs Python Loop vs Dense")
    print("="*80)
    device = "cuda"
    B, T = 8, 256
    M = B * T          # tokens per batch
    cs = 64            # chunk size
    af = 0.10          # fraction of chunks active
    warmup_iters = 10
    bench_iters = 50
    print(f"\nM={M} (B={B}, T={T}), chunk_size={cs}, active_frac={af}")
    print(f"{'d_model':>7} | {'d_out':>7} | {'active':>6} | {'Dense':>10} | {'PyLoop':>10} | {'Triton':>10} | {'Tri/Dense':>10} | {'Tri/PyLoop':>10}")
    print("-" * 95)
    for d_in in [256, 512, 768, 1024, 1536, 2048]:
        d_out = 4 * d_in
        n_chunks = d_out // cs
        n_active = max(1, int(af * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values
        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        b = torch.randn(d_out, device=device)
        gy = torch.randn(M, d_out, device=device)

        # Dense backward (dW + dX + dB)
        def dense_bwd():
            dw = gy.t() @ x
            dx = gy @ w
            db = gy.sum(0)
            return dw, dx, db

        # Python loop backward
        def pyloop_bwd():
            dw = torch.zeros_like(w)
            db = torch.zeros_like(b)
            dx = gy @ w  # dense dX
            for c in active.tolist():
                s, e = c * cs, (c + 1) * cs
                dw[s:e] = gy[:, s:e].t() @ x
                db[s:e] = gy[:, s:e].sum(0)
            return dw, dx, db

        # Triton fused backward
        def triton_bwd():
            dw = sparse_bwd_dW(x, gy, active, cs, d_out)
            dx = gy @ w  # dense dX (same as pyloop)
            db = sparse_bwd_dbias(gy, active, cs, d_out)
            return dw, dx, db

        # Warmup (also triggers Triton autotuning outside the timed region)
        for _ in range(warmup_iters):
            dense_bwd(); pyloop_bwd(); triton_bwd()
        torch.cuda.synchronize()

        # Bench dense
        torch.cuda.synchronize(); t0 = time.perf_counter()
        for _ in range(bench_iters):
            dense_bwd()
        torch.cuda.synchronize(); dense_time = (time.perf_counter() - t0) / bench_iters

        # Bench pyloop
        torch.cuda.synchronize(); t0 = time.perf_counter()
        for _ in range(bench_iters):
            pyloop_bwd()
        torch.cuda.synchronize(); pyloop_time = (time.perf_counter() - t0) / bench_iters

        # Bench triton
        torch.cuda.synchronize(); t0 = time.perf_counter()
        for _ in range(bench_iters):
            triton_bwd()
        torch.cuda.synchronize(); triton_time = (time.perf_counter() - t0) / bench_iters

        tri_vs_dense = dense_time / triton_time
        tri_vs_pyloop = pyloop_time / triton_time
        print(f"{d_in:>7} | {d_out:>7} | {n_active:>6} | {dense_time*1000:>9.2f}ms | {pyloop_time*1000:>9.2f}ms | {triton_time*1000:>9.2f}ms | {tri_vs_dense:>9.2f}x | {tri_vs_pyloop:>9.2f}x")

    # Also benchmark with sparse_dX (Triton dX kernel)
    print(f"\n{'='*80}")
    print("With Triton sparse_dX (both dW and dX are sparse):")
    print(f"{'d_model':>7} | {'Dense':>10} | {'Triton_all':>10} | {'Speedup':>10}")
    print("-" * 50)
    for d_in in [512, 1024, 2048]:
        d_out = 4 * d_in
        n_chunks = d_out // cs
        n_active = max(1, int(af * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values
        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        gy = torch.randn(M, d_out, device=device)

        def dense_full():
            dw = gy.t() @ x; dx = gy @ w; return dw, dx

        def triton_full():
            dw = sparse_bwd_dW(x, gy, active, cs, d_out)
            dx = sparse_bwd_dX(gy, w, active, cs, M, d_in)
            return dw, dx

        for _ in range(warmup_iters):
            dense_full(); triton_full()
        torch.cuda.synchronize()

        torch.cuda.synchronize(); t0 = time.perf_counter()
        for _ in range(bench_iters):
            dense_full()
        torch.cuda.synchronize(); dt = (time.perf_counter() - t0) / bench_iters

        torch.cuda.synchronize(); t0 = time.perf_counter()
        for _ in range(bench_iters):
            triton_full()
        torch.cuda.synchronize(); tt = (time.perf_counter() - t0) / bench_iters

        print(f"{d_in:>7} | {dt*1000:>9.2f}ms | {tt*1000:>9.2f}ms | {dt/tt:>9.2f}x")


# ═══════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    test_correctness()
    benchmark()