#!/usr/bin/env python3
"""
Triton-fused Chunked Sparse Backward Pass — v2.

Fixes from review:
  1. Bias folded into dW kernel (kills the uncoalesced column-striding bias kernel)
  2. block_ptr / TMA for dW and dX loads (hardware-accelerated 2D tile fetch)
  3. No autotune (fixed config to eliminate compilation overhead + divergence risk)
  4. Boundary-checked block_ptr loads are zero-padded, so ragged M / d_in and a
     partial last chunk (d_out % chunk_size != 0) give correct results.

Benchmarks v1 (manual ptrs + separate bias) vs v2 (block_ptr + fused bias)
vs Python-loop baseline vs Dense.
"""
import math
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl


# ═══════════════════════════════════════════════════════════════════
# V2 KERNELS — block_ptr + fused bias
# ═══════════════════════════════════════════════════════════════════
# Fixed tile sizes — no autotune. CS=64 means one N-block covers the whole chunk.
#   BM=64: token tile for the M-reduction loop
#   BK=64: tile along d_in
#   BN=64: tile along chunk (== CS for chunk_size=64, so 1 block per chunk)


@triton.jit
def _v2_sparse_bwd_dW_kernel(
    X_ptr, dY_ptr, dW_ptr, dB_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_xm, stride_xk,
    stride_dym, stride_dyn,
    stride_dwn, stride_dwk,
    HAS_BIAS: tl.constexpr,
    CS: tl.constexpr, BK: tl.constexpr, BM: tl.constexpr,
):
    """
    Each program computes one [CS, BK] tile of dW for one active chunk,
    plus the [CS] bias slice if HAS_BIAS (only the k_block_id == 0 program
    accumulates/stores the bias, to avoid redundant work).

    Grid: (num_active, ceil(d_in / BK))
    Since CS covers the whole chunk, pid0 indexes the active-chunk list directly.
    Out-of-range rows/cols (ragged M or d_in, partial last chunk) are zero-padded
    on load and dropped on store.
    """
    chunk_linear_id = tl.program_id(0)
    k_block_id = tl.program_id(1)
    if chunk_linear_id >= num_active:
        return

    chunk_idx = tl.load(chunk_ids_ptr + chunk_linear_id)
    chunk_start = chunk_idx * CS
    k_offset = k_block_id * BK

    # dY.T restricted to this chunk: dY is (M, d_out) with strides
    # (stride_dym, stride_dyn), so the transposed view is (d_out, M) with the
    # strides swapped. When dY is row-major (stride_dyn == 1) the contiguous
    # axis of this view is dim 0, hence order=(0, 1).
    dy_block_ptr = tl.make_block_ptr(
        base=dY_ptr,
        shape=(d_out, M),
        strides=(stride_dyn, stride_dym),
        offsets=(chunk_start, 0),
        block_shape=(CS, BM),
        order=(0, 1),
    )
    # X: shape (M, d_in), read in (BM, BK) tiles.
    x_block_ptr = tl.make_block_ptr(
        base=X_ptr,
        shape=(M, d_in),
        strides=(stride_xm, stride_xk),
        offsets=(0, k_offset),
        block_shape=(BM, BK),
        order=(1, 0),
    )

    acc_dw = tl.zeros((CS, BK), dtype=tl.float32)
    # Bias: only the first k-block does it, so it is computed exactly once per chunk.
    compute_bias = HAS_BIAS and (k_block_id == 0)
    acc_db = tl.zeros((CS,), dtype=tl.float32)

    # Reduction over the token dimension M.
    for _ in range(0, M, BM):
        # padding_option="zero": without it, out-of-bounds lanes of a
        # boundary-checked block-pointer load are undefined and would corrupt
        # both the dot product and the bias sum on ragged tail tiles.
        dy_t = tl.load(dy_block_ptr, boundary_check=(0, 1), padding_option="zero")  # (CS, BM)
        x = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")     # (BM, BK)
        # dW += dY.T @ X -> (CS, BM) @ (BM, BK) = (CS, BK)
        acc_dw = tl.dot(dy_t, x, acc=acc_dw)
        if compute_bias:
            # Sum over tokens (axis=1 of the transposed chunk tile).
            acc_db += tl.sum(dy_t, axis=1)
        dy_block_ptr = tl.advance(dy_block_ptr, (0, BM))
        x_block_ptr = tl.advance(x_block_ptr, (BM, 0))

    # Store dW tile: dW[chunk_start:chunk_start+CS, k_offset:k_offset+BK]
    dw_block_ptr = tl.make_block_ptr(
        base=dW_ptr,
        shape=(d_out, d_in),
        strides=(stride_dwn, stride_dwk),
        offsets=(chunk_start, k_offset),
        block_shape=(CS, BK),
        order=(1, 0),
    )
    tl.store(dw_block_ptr, acc_dw.to(dW_ptr.dtype.element_ty), boundary_check=(0, 1))

    if compute_bias:
        rn = chunk_start + tl.arange(0, CS)
        tl.store(dB_ptr + rn, acc_db.to(dB_ptr.dtype.element_ty), mask=rn < d_out)


def _check_chunk_size(chunk_size):
    """Validate chunk_size for the v2 kernels and return it as an int.

    CS is used as a tl.arange extent / block_shape dimension (power of two
    required) and as a tl.dot operand dimension (must be >= 16).
    """
    cs = int(chunk_size)
    if cs < 16 or cs & (cs - 1):
        raise ValueError(f"chunk_size must be a power of two >= 16, got {chunk_size}")
    return cs


def v2_sparse_bwd_dW(X, dY, active_chunks, chunk_size, d_out, bias=True):
    """Fused dW + dBias via block_ptr kernel.

    Args:
        X: (M, d_in) tensor.
        dY: (M, d_out) upstream gradient (as read by the kernel's block pointer).
        active_chunks: 1-D tensor of chunk indices to compute; moved to X's
            device and cast to int32 for the kernel.
        chunk_size: rows of dW per chunk; must be a power of two >= 16.
        d_out: number of rows of dW / entries of dB.
        bias: when True, also compute dB (token-sum of dY) for active chunks.

    Returns:
        (dW, dB): dW is (d_out, d_in), zero outside active chunks; dB is
        (d_out,) likewise, or None when bias is False.

    Raises:
        ValueError: if chunk_size is not a power of two >= 16.
    """
    M, d_in = X.shape
    num_active = active_chunks.shape[0]
    CS = _check_chunk_size(chunk_size)

    dW = torch.zeros(d_out, d_in, device=X.device, dtype=X.dtype)
    dB = torch.zeros(d_out, device=X.device, dtype=X.dtype) if bias else None
    if num_active == 0:
        return dW, dB

    chunk_ids = active_chunks.to(device=X.device, dtype=torch.int32).contiguous()
    BK = 64
    BM = 64
    grid = (num_active, triton.cdiv(d_in, BK))
    _v2_sparse_bwd_dW_kernel[grid](
        X, dY, dW,
        dB if bias else X,  # dummy ptr if no bias (never written: HAS_BIAS=False)
        chunk_ids,
        M, d_in, d_out, num_active,
        X.stride(0), X.stride(1),
        dY.stride(0), dY.stride(1),
        dW.stride(0), dW.stride(1),
        HAS_BIAS=bias, CS=CS, BK=BK, BM=BM,
    )
    return dW, dB


# ── V2 dX kernel with block_ptr ──
@triton.jit
def _v2_sparse_bwd_dX_kernel(
    dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_dym, stride_dyn,
    stride_wn, stride_wk,
    stride_dxm, stride_dxk,
    CS: tl.constexpr, BM: tl.constexpr, BK: tl.constexpr,
):
    """
    Each program computes one [BM, BK] tile of dX by accumulating
    dY[:, chunk] @ W[chunk, :] over the active chunks.

    Grid: (ceil(M/BM), ceil(d_in/BK))
    Out-of-range rows/cols are zero-padded on load and dropped on store.
    """
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    m_offset = pid_m * BM
    k_offset = pid_k * BK

    acc = tl.zeros((BM, BK), dtype=tl.float32)
    for i in range(num_active):
        chunk_start = tl.load(chunk_ids_ptr + i) * CS
        # dY tile: (BM, CS) at [m_offset, chunk_start]
        dy_block_ptr = tl.make_block_ptr(
            base=dY_ptr,
            shape=(M, d_out),
            strides=(stride_dym, stride_dyn),
            offsets=(m_offset, chunk_start),
            block_shape=(BM, CS),
            order=(1, 0),
        )
        # W tile: (CS, BK) at [chunk_start, k_offset]
        w_block_ptr = tl.make_block_ptr(
            base=W_ptr,
            shape=(d_out, d_in),
            strides=(stride_wn, stride_wk),
            offsets=(chunk_start, k_offset),
            block_shape=(CS, BK),
            order=(1, 0),
        )
        # Zero padding keeps ragged M / d_in / partial-chunk tails out of the dot.
        dy = tl.load(dy_block_ptr, boundary_check=(0, 1), padding_option="zero")  # (BM, CS)
        w = tl.load(w_block_ptr, boundary_check=(0, 1), padding_option="zero")    # (CS, BK)
        acc = tl.dot(dy, w, acc=acc)  # (BM, BK)

    dx_block_ptr = tl.make_block_ptr(
        base=dX_ptr,
        shape=(M, d_in),
        strides=(stride_dxm, stride_dxk),
        offsets=(m_offset, k_offset),
        block_shape=(BM, BK),
        order=(1, 0),
    )
    tl.store(dx_block_ptr, acc.to(dX_ptr.dtype.element_ty), boundary_check=(0, 1))


def v2_sparse_bwd_dX(dY, W, active_chunks, chunk_size, M, d_in):
    """Fused dX via block_ptr kernel.

    Args:
        dY: (M, d_out) upstream gradient; d_out is taken from dY.shape[1].
        W: (d_out, d_in) weight.
        active_chunks: 1-D tensor of chunk indices; moved to dY's device and
            cast to int32 for the kernel.
        chunk_size: rows of W per chunk; must be a power of two >= 16.
        M, d_in: shape of the returned dX.

    Returns:
        (M, d_in) tensor with dtype of dY; zero when no chunk is active.

    Raises:
        ValueError: if chunk_size is not a power of two >= 16.
    """
    num_active = active_chunks.shape[0]
    d_out = dY.shape[1]
    CS = _check_chunk_size(chunk_size)

    dX = torch.zeros(M, d_in, device=dY.device, dtype=dY.dtype)
    if num_active == 0:
        return dX

    chunk_ids = active_chunks.to(device=dY.device, dtype=torch.int32).contiguous()
    BM = 64
    BK = 64
    grid = (triton.cdiv(M, BM), triton.cdiv(d_in, BK))
    _v2_sparse_bwd_dX_kernel[grid](
        dY, W, dX, chunk_ids,
        M, d_in, d_out, num_active,
        dY.stride(0), dY.stride(1),
        W.stride(0), W.stride(1),
        dX.stride(0), dX.stride(1),
        CS=CS, BM=BM, BK=BK,
    )
    return dX


# ═══════════════════════════════════════════════════════════════════
# V1 KERNELS (old, for comparison) — import from triton_sparse.py
# ═══════════════════════════════════════════════════════════════════
from triton_sparse import (
    sparse_bwd_dW as v1_sparse_bwd_dW,
    sparse_bwd_dX as v1_sparse_bwd_dX,
    sparse_bwd_dbias as v1_sparse_bwd_dbias,
)


# ═══════════════════════════════════════════════════════════════════
# CORRECTNESS TEST
# ═══════════════════════════════════════════════════════════════════
def _rel_err(actual, expected):
    """Max absolute error of `actual` vs `expected`, normalized by max |expected|."""
    scale = expected.abs().max().clamp_min(1e-6)
    return ((actual - expected).abs().max() / scale).item()


def test_correctness():
    print("V2 Correctness Tests")
    print("=" * 60)
    device = "cuda"
    torch.manual_seed(42)
    # tl.dot may run fp32 inputs at TF32 precision, so dW/dX entries (sums of
    # M products, magnitude ~sqrt(M)) cannot meet a tight absolute bound;
    # compare errors relative to the reference's magnitude instead.
    tol = 1e-2
    all_ok = True

    # (M, d_in, d_out, cs). The last case is deliberately ragged: M and d_in
    # are not tile multiples and the last chunk is partial.
    for M, d_in, d_out, cs in [
        (2048, 512, 2048, 64),
        (2048, 1024, 4096, 64),
        (2048, 256, 1024, 64),
        (1000, 200, 1000, 64),
    ]:
        n_chunks = math.ceil(d_out / cs)
        n_active = max(1, int(0.1 * n_chunks))
        picked = torch.randperm(n_chunks, device=device)[:n_active]
        # Always include the last (possibly partial) chunk so tail handling is tested.
        active = torch.unique(torch.cat([picked, picked.new_tensor([n_chunks - 1])]))

        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        gy = torch.randn(M, d_out, device=device)

        # Reference
        ref_dw = torch.zeros_like(w)
        ref_db = torch.zeros(d_out, device=device)
        ref_dx = torch.zeros_like(x)
        for c in active.tolist():
            s, e = c * cs, min((c + 1) * cs, d_out)
            ref_dw[s:e] = gy[:, s:e].t() @ x
            ref_db[s:e] = gy[:, s:e].sum(0)
            ref_dx += gy[:, s:e] @ w[s:e]

        # V2
        v2_dw, v2_db = v2_sparse_bwd_dW(x, gy, active, cs, d_out, bias=True)
        v2_dx = v2_sparse_bwd_dX(gy, w, active, cs, M, d_in)

        dw_err = _rel_err(v2_dw, ref_dw)
        db_err = _rel_err(v2_db, ref_db)
        dx_err = _rel_err(v2_dx, ref_dx)
        ok = dw_err < tol and db_err < tol and dx_err < tol
        all_ok = all_ok and ok
        print(f"  {'✓' if ok else '✗'} M={M} d_in={d_in} d_out={d_out} cs={cs}: "
              f"rel_err dW={dw_err:.2e} dB={db_err:.2e} dX={dx_err:.2e}")
    print(f"  {'ALL PASSED' if all_ok else 'SOME FAILED'} (rel tol={tol})")
    print()


# ═══════════════════════════════════════════════════════════════════
# BENCHMARK
# ═══════════════════════════════════════════════════════════════════
def _time_fn(fn, iters):
    """Return mean wall-clock seconds per call of `fn` over `iters` calls.

    The device is synchronized before starting and after the last call so
    queued CUDA work is included in the measurement.
    """
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters


def benchmark():
    print("=" * 100)
    print("BENCHMARK: Dense vs PyLoop vs V1-Triton vs V2-Triton (block_ptr + fused bias)")
    print("=" * 100)

    device = "cuda"
    B, T = 8, 256
    M = B * T
    cs = 64
    af = 0.10
    warmup = 20
    iters = 100

    print(f"\nM={M}, chunk_size={cs}, active_frac={af}, {iters} iters after {warmup} warmup")
    header = (f"{'d':>5} | {'ffn':>5} | {'act':>3} | {'Dense':>9} | {'PyLoop':>9} | "
              f"{'V1-Tri':>9} | {'V2-Tri':>9} | {'V2/Dense':>9} | {'V2/PyLoop':>9} | {'V2/V1':>9}")
    print(header)
    print("-" * len(header))

    for d_in in [256, 512, 768, 1024, 1536, 2048]:
        d_out = 4 * d_in
        nc = d_out // cs
        na = max(1, int(af * nc))
        active = torch.randperm(nc, device=device)[:na].sort().values
        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        gy = torch.randn(M, d_out, device=device)

        # Closures below are only called within this iteration, so capturing
        # the loop variables is safe.
        def dense():
            return gy.t() @ x, gy @ w, gy.sum(0)

        def pyloop():
            dw = torch.zeros_like(w)
            db = torch.zeros(d_out, device=device)
            dx = gy @ w
            for c in active.tolist():
                s, e = c * cs, (c + 1) * cs
                dw[s:e] = gy[:, s:e].t() @ x
                db[s:e] = gy[:, s:e].sum(0)
            return dw, dx, db

        def v1_tri():
            dw = v1_sparse_bwd_dW(x, gy, active, cs, d_out)
            dx = gy @ w
            db = v1_sparse_bwd_dbias(gy, active, cs, d_out)
            return dw, dx, db

        def v2_tri():
            dw, db = v2_sparse_bwd_dW(x, gy, active, cs, d_out, bias=True)
            dx = gy @ w
            return dw, dx, db

        fns = {"dense": dense, "pyloop": pyloop, "v1": v1_tri, "v2": v2_tri}
        for _ in range(warmup):
            for fn in fns.values():
                fn()
        torch.cuda.synchronize()

        times = {name: _time_fn(fn, iters) for name, fn in fns.items()}
        td, tp, t1, t2 = times["dense"], times["pyloop"], times["v1"], times["v2"]
        print(f"{d_in:>5} | {d_out:>5} | {na:>3} | {td*1000:>8.2f}ms | {tp*1000:>8.2f}ms | "
              f"{t1*1000:>8.2f}ms | {t2*1000:>8.2f}ms | {td/t2:>8.2f}x | {tp/t2:>8.2f}x | {t1/t2:>8.2f}x")

    # Sparse dX comparison: V1 vs V2
    print(f"\n{'='*80}")
    print("Sparse dX (both dW+dX sparse): V1 vs V2")
    print(f"{'d':>5} | {'Dense':>9} | {'V1-all':>9} | {'V2-all':>9} | {'V2/Dense':>9}")
    print("-" * 55)

    for d_in in [512, 1024, 2048]:
        d_out = 4 * d_in
        nc = d_out // cs
        na = max(1, int(af * nc))
        active = torch.randperm(nc, device=device)[:na].sort().values
        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        gy = torch.randn(M, d_out, device=device)

        def dense_all():
            return gy.t() @ x, gy @ w

        def v1_all():
            return (v1_sparse_bwd_dW(x, gy, active, cs, d_out),
                    v1_sparse_bwd_dX(gy, w, active, cs, M, d_in))

        def v2_all():
            dw, _ = v2_sparse_bwd_dW(x, gy, active, cs, d_out, bias=False)
            return dw, v2_sparse_bwd_dX(gy, w, active, cs, M, d_in)

        fns = {"dense": dense_all, "v1": v1_all, "v2": v2_all}
        for _ in range(warmup):
            for fn in fns.values():
                fn()
        torch.cuda.synchronize()

        # Each variant is timed exactly once (the previous locals() trick was a
        # no-op inside a function and forced a second, duplicate timing pass).
        times = {name: _time_fn(fn, iters) for name, fn in fns.items()}
        td, t1, t2 = times["dense"], times["v1"], times["v2"]
        print(f"{d_in:>5} | {td*1000:>8.2f}ms | {t1*1000:>8.2f}ms | {t2*1000:>8.2f}ms | {td/t2:>8.2f}x")


if __name__ == "__main__":
    test_correctness()
    benchmark()