Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
bc1b8eb | #!/usr/bin/env python3 | |
| """ | |
| Triton-fused Chunked Sparse Backward Pass — v2. | |
| Fixes from review: | |
| 1. Bias folded into dW kernel (kills the uncoalesced column-striding bias kernel) | |
| 2. block_ptr / TMA for dW and dX loads (hardware-accelerated 2D tile fetch) | |
| 3. No autotune (fixed config to eliminate compilation overhead + divergence risk) | |
| Benchmarks v1 (manual ptrs + separate bias) vs v2 (block_ptr + fused bias) | |
| vs Python-loop baseline vs Dense. | |
| """ | |
| import math, os, time | |
| import torch, torch.nn as nn, torch.nn.functional as F | |
| import triton, triton.language as tl | |
# ═══════════════════════════════════════════════════════════════════
# V2 KERNELS — block_ptr + fused bias
# ═══════════════════════════════════════════════════════════════════
# Fixed tile sizes — no autotune. CS=64 means one N-block covers the whole chunk.
#   BM=64: token tile for the M-reduction loop
#   BK=64: tile along d_in
#   BN=64: tile along chunk (== CS for chunk_size=64, so 1 block per chunk)
@triton.jit
def _v2_sparse_bwd_dW_kernel(
    X_ptr, dY_ptr, dW_ptr, dB_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_xm, stride_xk,
    stride_dym, stride_dyn,
    stride_dwn, stride_dwk,
    HAS_BIAS: tl.constexpr,
    CS: tl.constexpr,
    BK: tl.constexpr,
    BM: tl.constexpr,
):
    """
    Compute one [CS, BK] tile of dW for one active chunk, plus the [CS] bias
    slice when HAS_BIAS.

        dW[chunk rows, k tile] = dY[:, chunk cols].T @ X[:, k tile]
        dB[chunk rows]         = dY[:, chunk cols].sum(0)

    Both are reduced over all M rows in BM-sized steps.

    Grid: (num_active, ceil(d_in / BK)). pid0 is an index into chunk_ids_ptr
    (it selects the pid0-th active chunk); pid1 selects the d_in tile.
    CS is used as a block shape / tl.arange extent and as a tl.dot dimension,
    so it must be a power of two and at least 16. Accumulation is fp32; the
    results are cast to the output pointers' dtypes on store.
    """
    chunk_linear_id = tl.program_id(0)
    k_block_id = tl.program_id(1)
    if chunk_linear_id >= num_active:
        return
    chunk_start = tl.load(chunk_ids_ptr + chunk_linear_id) * CS
    k_offset = k_block_id * BK

    # dY viewed transposed: shape (d_out, M) with swapped strides, so the
    # (CS, BM) tile at (chunk_start, m) is dY[m:m+BM, chunk cols].T.
    dy_block_ptr = tl.make_block_ptr(
        base=dY_ptr,
        shape=(d_out, M),
        strides=(stride_dyn, stride_dym),
        offsets=(chunk_start, 0),
        block_shape=(CS, BM),
        order=(1, 0),
    )
    # X read as (BM, BK) tiles starting at column k_offset.
    x_block_ptr = tl.make_block_ptr(
        base=X_ptr,
        shape=(M, d_in),
        strides=(stride_xm, stride_xk),
        offsets=(0, k_offset),
        block_shape=(BM, BK),
        order=(1, 0),
    )

    acc_dw = tl.zeros((CS, BK), dtype=tl.float32)
    acc_db = tl.zeros((CS,), dtype=tl.float32)
    for m_start in range(0, M, BM):
        # M is the reduction axis. When M % BM != 0 the last tile hangs past
        # the end; those entries must load as 0 (the default block-pointer
        # padding is undefined) or they would leak into tl.dot and tl.sum.
        dy_t = tl.load(dy_block_ptr, boundary_check=(0, 1), padding_option="zero")  # (CS, BM)
        x = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")  # (BM, BK)
        acc_dw = tl.dot(dy_t, x, acc=acc_dw)  # (CS, BM) @ (BM, BK) -> (CS, BK)
        # Only the k_block_id == 0 programs reduce the bias, so each chunk's
        # dB is computed and written exactly once. HAS_BIAS is a constexpr,
        # so the whole branch is compiled out when bias is disabled.
        if HAS_BIAS:
            if k_block_id == 0:
                acc_db += tl.sum(dy_t.to(tl.float32), axis=1)
        dy_block_ptr = tl.advance(dy_block_ptr, (0, BM))
        x_block_ptr = tl.advance(x_block_ptr, (BM, 0))

    # Store dW[chunk_start:chunk_start+CS, k_offset:k_offset+BK];
    # boundary_check drops rows past d_out and columns past d_in.
    dw_block_ptr = tl.make_block_ptr(
        base=dW_ptr,
        shape=(d_out, d_in),
        strides=(stride_dwn, stride_dwk),
        offsets=(chunk_start, k_offset),
        block_shape=(CS, BK),
        order=(1, 0),
    )
    tl.store(dw_block_ptr, acc_dw.to(dW_ptr.dtype.element_ty), boundary_check=(0, 1))

    if HAS_BIAS:
        if k_block_id == 0:
            rn = chunk_start + tl.arange(0, CS)
            tl.store(dB_ptr + rn, acc_db.to(dB_ptr.dtype.element_ty), mask=rn < d_out)
def v2_sparse_bwd_dW(X, dY, active_chunks, chunk_size, d_out, bias=True):
    """Fused dW + dBias for the active output chunks via the block_ptr kernel.

    Args:
        X: input activations of shape (M, d_in).
        dY: upstream gradient of shape (M, d_out).
        active_chunks: 1-D tensor of chunk indices; chunk ``c`` covers rows
            ``[c * chunk_size, (c + 1) * chunk_size)`` of dW.
        chunk_size: rows per chunk. Must be a power of two >= 16 because the
            kernel uses it as a block shape / ``tl.arange`` extent and as a
            ``tl.dot`` dimension.
        d_out: number of output features (rows of dW).
        bias: if True, also compute dB.

    Returns:
        ``(dW, dB)``. dW has shape (d_out, d_in) and X's dtype, and is zero
        outside the active chunks. dB has shape (d_out,), or is None when
        ``bias`` is False.

    Raises:
        ValueError: when at least one chunk is active and ``chunk_size`` is
            not a power of two >= 16, or ``dY`` is not of shape (M, d_out).
    """
    M, d_in = X.shape
    num_active = active_chunks.shape[0]
    dW = torch.zeros(d_out, d_in, device=X.device, dtype=X.dtype)
    dB = torch.zeros(d_out, device=X.device, dtype=X.dtype) if bias else None
    if num_active == 0:
        return dW, dB

    # Validate only when a launch is needed: an empty call never compiles the
    # kernel, so it keeps accepting any chunk_size.
    if chunk_size < 16 or chunk_size & (chunk_size - 1):
        raise ValueError(f"chunk_size must be a power of two >= 16, got {chunk_size}")
    if tuple(dY.shape) != (M, d_out):
        raise ValueError(f"dY must have shape {(M, d_out)}, got {tuple(dY.shape)}")

    chunk_ids = active_chunks.to(torch.int32).contiguous()
    BK = 64
    BM = 64
    grid = (num_active, triton.cdiv(d_in, BK))
    _v2_sparse_bwd_dW_kernel[grid](
        X, dY, dW,
        dB if bias else X,  # placeholder pointer; the kernel never writes dB_ptr without HAS_BIAS
        chunk_ids,
        M, d_in, d_out, num_active,
        X.stride(0), X.stride(1),
        dY.stride(0), dY.stride(1),
        dW.stride(0), dW.stride(1),
        HAS_BIAS=bias,
        CS=chunk_size, BK=BK, BM=BM,
    )
    return dW, dB
# ── V2 dX kernel with block_ptr ──
@triton.jit
def _v2_sparse_bwd_dX_kernel(
    dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_dym, stride_dyn,
    stride_wn, stride_wk,
    stride_dxm, stride_dxk,
    CS: tl.constexpr,
    BM: tl.constexpr,
    BK: tl.constexpr,
):
    """
    Compute one [BM, BK] tile of dX by accumulating dY[:, chunk] @ W[chunk, :]
    over every active chunk.

    Grid: (ceil(M / BM), ceil(d_in / BK)). Each program walks all num_active
    entries of chunk_ids_ptr. Accumulation is fp32 and is cast to dX's dtype
    on store. CS is a block shape and the tl.dot reduction dimension, so it
    must be a power of two and at least 16.
    """
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    m_offset = pid_m * BM
    k_offset = pid_k * BK
    acc = tl.zeros((BM, BK), dtype=tl.float32)
    for i in range(num_active):
        chunk_start = tl.load(chunk_ids_ptr + i) * CS
        # (BM, CS) tile of dY: rows m_offset.., this chunk's columns.
        dy_block_ptr = tl.make_block_ptr(
            base=dY_ptr,
            shape=(M, d_out),
            strides=(stride_dym, stride_dyn),
            offsets=(m_offset, chunk_start),
            block_shape=(BM, CS),
            order=(1, 0),
        )
        # (CS, BK) tile of W: this chunk's rows, columns k_offset..
        w_block_ptr = tl.make_block_ptr(
            base=W_ptr,
            shape=(d_out, d_in),
            strides=(stride_wn, stride_wk),
            offsets=(chunk_start, k_offset),
            block_shape=(CS, BK),
            order=(1, 0),
        )
        # CS is the reduction axis of tl.dot. If a chunk runs past d_out,
        # the out-of-range dY columns / W rows must load as 0 (the default
        # block-pointer padding is undefined) so they add nothing to acc.
        dy = tl.load(dy_block_ptr, boundary_check=(0, 1), padding_option="zero")
        w = tl.load(w_block_ptr, boundary_check=(0, 1), padding_option="zero")
        acc = tl.dot(dy, w, acc=acc)  # (BM, CS) @ (CS, BK) -> (BM, BK)

    # Store the dX tile; boundary_check drops rows past M and columns past d_in.
    dx_block_ptr = tl.make_block_ptr(
        base=dX_ptr,
        shape=(M, d_in),
        strides=(stride_dxm, stride_dxk),
        offsets=(m_offset, k_offset),
        block_shape=(BM, BK),
        order=(1, 0),
    )
    tl.store(dx_block_ptr, acc.to(dX_ptr.dtype.element_ty), boundary_check=(0, 1))
def v2_sparse_bwd_dX(dY, W, active_chunks, chunk_size, M, d_in):
    """Sparse dX via the block_ptr kernel: sum over active chunks of dY[:, c] @ W[c, :].

    Args:
        dY: upstream gradient of shape (M, d_out); d_out is taken from it.
        W: weight of shape (d_out, d_in).
        active_chunks: 1-D tensor of chunk indices; chunk ``c`` covers output
            features ``[c * chunk_size, (c + 1) * chunk_size)``.
        chunk_size: features per chunk. Must be a power of two >= 16 (kernel
            block shape and ``tl.dot`` reduction dimension).
        M: number of rows of dX (must equal ``dY.shape[0]``).
        d_in: number of columns of dX (must equal ``W.shape[1]``).

    Returns:
        dX of shape (M, d_in) in dY's dtype; all zeros when no chunk is active.

    Raises:
        ValueError: when at least one chunk is active and ``chunk_size`` is
            not a power of two >= 16, ``dY`` does not have M rows, or ``W``
            is not of shape (d_out, d_in).
    """
    num_active = active_chunks.shape[0]
    d_out = dY.shape[1]
    dX = torch.zeros(M, d_in, device=dY.device, dtype=dY.dtype)
    if num_active == 0:
        return dX

    # Validate only when a launch is needed: an empty call never compiles the
    # kernel, so it keeps accepting any chunk_size.
    if chunk_size < 16 or chunk_size & (chunk_size - 1):
        raise ValueError(f"chunk_size must be a power of two >= 16, got {chunk_size}")
    if dY.shape[0] != M:
        raise ValueError(f"dY has {dY.shape[0]} rows but M={M}")
    if tuple(W.shape) != (d_out, d_in):
        raise ValueError(f"W must have shape {(d_out, d_in)}, got {tuple(W.shape)}")

    chunk_ids = active_chunks.to(torch.int32).contiguous()
    BM = 64
    BK = 64
    grid = (triton.cdiv(M, BM), triton.cdiv(d_in, BK))
    _v2_sparse_bwd_dX_kernel[grid](
        dY, W, dX, chunk_ids,
        M, d_in, d_out, num_active,
        dY.stride(0), dY.stride(1),
        W.stride(0), W.stride(1),
        dX.stride(0), dX.stride(1),
        CS=chunk_size, BM=BM, BK=BK,
    )
    return dX
| # ═══════════════════════════════════════════════════════════════════ | |
| # V1 KERNELS (old, for comparison) — imported from triton_sparse.py | |
| # ═══════════════════════════════════════════════════════════════════ | |
| from triton_sparse import ( | |
| sparse_bwd_dW as v1_sparse_bwd_dW, | |
| sparse_bwd_dX as v1_sparse_bwd_dX, | |
| sparse_bwd_dbias as v1_sparse_bwd_dbias, | |
| ) | |
# ═══════════════════════════════════════════════════════════════════
# CORRECTNESS TEST
# ═══════════════════════════════════════════════════════════════════
def test_correctness():
    """Compare the v2 kernels against a per-chunk PyTorch reference on CUDA.

    For each (M, d_in, d_out, chunk_size) case, ~10% of the output chunks are
    marked active. dW, dB and dX from the v2 path are compared with the
    reference. Errors are measured relative to the reference's largest
    magnitude: tl.dot may evaluate fp32 inputs at reduced (TF32) precision,
    so an absolute bound on sums over thousands of terms is too strict. A
    wrong or missing chunk still gives a relative error near 1.

    Returns:
        True if every case passed, False otherwise.
    """
    print("V2 Correctness Tests")
    print("=" * 60)
    device = "cuda"
    torch.manual_seed(42)
    rel_tol = 1e-2

    def rel_err(got, ref):
        scale = ref.abs().max().clamp_min(1e-6)
        return ((got - ref).abs().max() / scale).item()

    # (M, d_in, d_out, chunk_size). The last case has M and d_in that are not
    # multiples of the 64-wide kernel tiles, to exercise boundary handling.
    cases = [
        (2048, 512, 2048, 64),
        (2048, 1024, 4096, 64),
        (2048, 256, 1024, 64),
        (1000, 200, 640, 64),
    ]
    all_ok = True
    for M, d_in, d_out, cs in cases:
        n_chunks = d_out // cs
        n_active = max(1, int(0.1 * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values
        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        gy = torch.randn(M, d_out, device=device)

        # Reference: plain PyTorch, one chunk at a time.
        ref_dw = torch.zeros_like(w)
        ref_db = torch.zeros(d_out, device=device)
        ref_dx = torch.zeros_like(x)
        for c in active.tolist():
            s, e = c * cs, (c + 1) * cs
            ref_dw[s:e] = gy[:, s:e].t() @ x
            ref_db[s:e] = gy[:, s:e].sum(0)
            ref_dx += gy[:, s:e] @ w[s:e]

        v2_dw, v2_db = v2_sparse_bwd_dW(x, gy, active, cs, d_out, bias=True)
        v2_dx = v2_sparse_bwd_dX(gy, w, active, cs, M, d_in)

        dw_err = rel_err(v2_dw, ref_dw)
        db_err = rel_err(v2_db, ref_db)
        dx_err = rel_err(v2_dx, ref_dx)
        ok = dw_err < rel_tol and db_err < rel_tol and dx_err < rel_tol
        all_ok = all_ok and ok
        print(
            f"  {'PASS' if ok else 'FAIL'} M={M} d_in={d_in} d_out={d_out} cs={cs}: "
            f"rel err dW={dw_err:.2e} dB={db_err:.2e} dX={dx_err:.2e}"
        )
    print()
    return all_ok
# ═══════════════════════════════════════════════════════════════════
# BENCHMARK
# ═══════════════════════════════════════════════════════════════════
def _time_per_iter(fn, iters):
    """Return the mean wall-clock seconds per call of ``fn`` over ``iters`` calls.

    The CUDA device is synchronized before the clock starts and after the last
    call, so asynchronously queued GPU work is included in the measurement.
    """
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters


def benchmark():
    """Time dense vs. chunk-sparse backward variants over a range of widths.

    Section 1 times dW + dB with dense dX for: Dense, a Python per-chunk
    loop, the v1 Triton kernels, and the fused v2 kernel.
    Section 2 times the fully sparse path (sparse dW and dX) for v1 and v2
    against dense.
    Each variant gets ``warmup`` untimed calls before ``iters`` timed calls.
    The warmup calls also absorb Triton compilation.
    """
    print("=" * 100)
    print("BENCHMARK: Dense vs PyLoop vs V1-Triton vs V2-Triton (block_ptr + fused bias)")
    print("=" * 100)
    device = "cuda"
    B, T = 8, 256
    M = B * T
    cs = 64
    af = 0.10
    warmup = 20
    iters = 100
    print(f"\nM={M}, chunk_size={cs}, active_frac={af}, {iters} iters after {warmup} warmup")
    print(f"{'d':>5} | {'ffn':>5} | {'act':>3} | {'Dense':>9} | {'PyLoop':>9} | {'V1-Tri':>9} | {'V2-Tri':>9} | {'V2/Dense':>9} | {'V2/PyLoop':>9} | {'V2/V1':>9}")
    print("-" * 105)
    for d_in in [256, 512, 768, 1024, 1536, 2048]:
        d_out = 4 * d_in
        nc = d_out // cs
        na = max(1, int(af * nc))
        active = torch.randperm(nc, device=device)[:na].sort().values
        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        gy = torch.randn(M, d_out, device=device)

        def dense():
            return gy.t() @ x, gy @ w, gy.sum(0)

        def pyloop():
            dw = torch.zeros_like(w)
            db = torch.zeros(d_out, device=device)
            dx = gy @ w
            for c in active.tolist():
                s, e = c * cs, (c + 1) * cs
                dw[s:e] = gy[:, s:e].t() @ x
                db[s:e] = gy[:, s:e].sum(0)
            return dw, dx, db

        def v1_tri():
            dw = v1_sparse_bwd_dW(x, gy, active, cs, d_out)
            dx = gy @ w
            db = v1_sparse_bwd_dbias(gy, active, cs, d_out)
            return dw, dx, db

        def v2_tri():
            dw, db = v2_sparse_bwd_dW(x, gy, active, cs, d_out, bias=True)
            dx = gy @ w
            return dw, dx, db

        variants = {"dense": dense, "pyloop": pyloop, "v1": v1_tri, "v2": v2_tri}
        for _ in range(warmup):
            for fn in variants.values():
                fn()
        times = {name: _time_per_iter(fn, iters) for name, fn in variants.items()}
        td, tp, t1, t2 = times["dense"], times["pyloop"], times["v1"], times["v2"]
        print(f"{d_in:>5} | {d_out:>5} | {na:>3} | {td*1000:>8.2f}ms | {tp*1000:>8.2f}ms | {t1*1000:>8.2f}ms | {t2*1000:>8.2f}ms | {td/t2:>8.2f}x | {tp/t2:>8.2f}x | {t1/t2:>8.2f}x")

    # Sparse dX comparison: V1 vs V2
    print(f"\n{'='*80}")
    print("Sparse dX (both dW+dX sparse): V1 vs V2")
    print(f"{'d':>5} | {'Dense':>9} | {'V1-all':>9} | {'V2-all':>9} | {'V2/Dense':>9}")
    print("-" * 55)
    for d_in in [512, 1024, 2048]:
        d_out = 4 * d_in
        nc = d_out // cs
        na = max(1, int(af * nc))
        active = torch.randperm(nc, device=device)[:na].sort().values
        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        gy = torch.randn(M, d_out, device=device)

        def dense_all():
            return gy.t() @ x, gy @ w

        def v1_all():
            return v1_sparse_bwd_dW(x, gy, active, cs, d_out), v1_sparse_bwd_dX(gy, w, active, cs, M, d_in)

        def v2_all():
            dw, _ = v2_sparse_bwd_dW(x, gy, active, cs, d_out, bias=False)
            return dw, v2_sparse_bwd_dX(gy, w, active, cs, M, d_in)

        variants = (dense_all, v1_all, v2_all)
        for _ in range(warmup):
            for fn in variants:
                fn()
        td, t1, t2 = (_time_per_iter(fn, iters) for fn in variants)
        print(f"{d_in:>5} | {td*1000:>8.2f}ms | {t1*1000:>8.2f}ms | {t2*1000:>8.2f}ms | {td/t2:>8.2f}x")
if __name__ == "__main__":
    # Check numerical correctness first, then measure performance.
    for stage in (test_correctness, benchmark):
        stage()