| |
| """ |
| Sparse Transformer: Definitive Ablation Suite |
| |
| Builds on v18_fast_knn_triton.py. Addresses all three structural gaps |
| identified in the critique: |
| |
| 1. PHANTOM MOMENTUM ABLATION |
| - "phantom": standard Adam β inactive chunks' moments decay on zero grad (default) |
| - "frozen": inactive chunks' Adam state (m, v) is completely frozen |
| Compare across all schedulers to isolate whether convergence is driven |
| by the chunking algorithm or by phantom momentum acting as regularization. |
| |
| 2. COMPUTE-MATCHED BASELINES |
| - Dense at same steps (standard comparison) |
| - Dense at fewer steps matching sparse FLOPs |
| - Natively smaller dense model matching sparse active capacity |
| |
| 3. UNIFIED HARDWARE |
| Everything on CUDA (A10G). Single hardware stack. |
| |
| Plus: KNN vs EMA vs Random vs Oracle predictor comparison with proper |
| oracle overlap measurement. |
| |
| Run: |
| python ablations.py --device cuda --steps 1000 --n_embd 1024 --experiment all |
| python ablations.py --device cuda --experiment phantom_momentum |
| python ablations.py --device cuda --experiment compute_matched |
| python ablations.py --device cuda --experiment predictor_accuracy |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import random |
| import sys |
| import time |
| from collections import defaultdict |
| from typing import Dict, List, Literal, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| try: |
| import triton |
| import triton.language as tl |
| HAS_TRITON = True |
| except ImportError: |
| HAS_TRITON = False |
|
|
| try: |
| import tiktoken |
| HAS_TIKTOKEN = True |
| except ImportError: |
| HAS_TIKTOKEN = False |
|
|
| |
| |
| |
|
|
| if HAS_TRITON: |
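| # Chunk-sparse backward in math terms, for each active output chunk c (rows [c*CS, (c+1)*CS)): |
| #   dW[c] = dY[:, c].T @ X     dB[c] = dY[:, c].sum(0)     dX = sum_c dY[:, c] @ W[c] |
| # The dW/dB kernel launches one program per (active chunk, d_in block) and loops over M; |
| # the dX kernel launches one program per (M block, d_in block) and loops over active chunks. |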
| @triton.jit |
| def _sparse_bwd_dW_db_kernel( |
| X_ptr, dY_ptr, dW_ptr, dB_ptr, chunk_ids_ptr, |
| M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr, |
| num_active: tl.constexpr, |
| stride_xm: tl.constexpr, stride_xk: tl.constexpr, |
| stride_dym: tl.constexpr, stride_dyn: tl.constexpr, |
| stride_dwn: tl.constexpr, stride_dwk: tl.constexpr, |
| HAS_BIAS: tl.constexpr, |
| CS: tl.constexpr, BK: tl.constexpr, BM: tl.constexpr, |
| ): |
| cli = tl.program_id(0) |
| kbi = tl.program_id(1) |
| cidx = tl.load(chunk_ids_ptr + cli) |
| cs0 = cidx * CS |
| ko = kbi * BK |
|
|
| dy_bp = tl.make_block_ptr(dY_ptr, (d_out, M), (stride_dyn, stride_dym), |
| (cs0, 0), (CS, BM), (1, 0)) |
| x_bp = tl.make_block_ptr(X_ptr, (M, d_in), (stride_xm, stride_xk), |
| (0, ko), (BM, BK), (1, 0)) |
|
|
| acc = tl.zeros((CS, BK), dtype=tl.float32) |
| do_bias = HAS_BIAS and (kbi == 0) |
| acc_b = tl.zeros((CS,), dtype=tl.float32) |
|
|
| for _ in range(0, M, BM): |
| dy_t = tl.load(dy_bp, boundary_check=(0, 1)) |
| x = tl.load(x_bp, boundary_check=(0, 1)) |
| acc = tl.dot(dy_t, x, acc=acc) |
| if do_bias: |
| acc_b += tl.sum(dy_t, axis=1) |
| dy_bp = tl.advance(dy_bp, (0, BM)) |
| x_bp = tl.advance(x_bp, (BM, 0)) |
|
|
| dw_bp = tl.make_block_ptr(dW_ptr, (d_out, d_in), (stride_dwn, stride_dwk), |
| (cs0, ko), (CS, BK), (1, 0)) |
| tl.store(dw_bp, acc.to(dW_ptr.dtype.element_ty), boundary_check=(0, 1)) |
|
|
| if do_bias: |
| rn = cs0 + tl.arange(0, CS) |
| tl.store(dB_ptr + rn, acc_b.to(dB_ptr.dtype.element_ty), mask=rn < d_out) |
|
|
| @triton.jit |
| def _sparse_bwd_dX_kernel( |
| dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr, |
| M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr, |
| num_active: tl.constexpr, |
| stride_dym: tl.constexpr, stride_dyn: tl.constexpr, |
| stride_wn: tl.constexpr, stride_wk: tl.constexpr, |
| stride_dxm: tl.constexpr, stride_dxk: tl.constexpr, |
| CS: tl.constexpr, BM: tl.constexpr, BK: tl.constexpr, |
| ): |
| pm = tl.program_id(0) |
| pk = tl.program_id(1) |
| mo = pm * BM |
| ko = pk * BK |
| acc = tl.zeros((BM, BK), dtype=tl.float32) |
| for i in range(0, num_active): |
| cidx = tl.load(chunk_ids_ptr + i) |
| cs0 = cidx * CS |
| dy_bp = tl.make_block_ptr(dY_ptr, (M, d_out), (stride_dym, stride_dyn), |
| (mo, cs0), (BM, CS), (1, 0)) |
| w_bp = tl.make_block_ptr(W_ptr, (d_out, d_in), (stride_wn, stride_wk), |
| (cs0, ko), (CS, BK), (1, 0)) |
| dy = tl.load(dy_bp, boundary_check=(0, 1)) |
| w = tl.load(w_bp, boundary_check=(0, 1)) |
| acc = tl.dot(dy, w, acc=acc) |
| dx_bp = tl.make_block_ptr(dX_ptr, (M, d_in), (stride_dxm, stride_dxk), |
| (mo, ko), (BM, BK), (1, 0)) |
| tl.store(dx_bp, acc.to(dX_ptr.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
| def triton_bwd_dW_db(xf, gyf, active, cs, d_out, has_bias): |
| M, d_in = xf.shape |
| na = active.numel() |
| dW = torch.zeros(d_out, d_in, device=xf.device, dtype=xf.dtype) |
| dB = torch.zeros(d_out, device=xf.device, dtype=xf.dtype) if has_bias else None |
| if na == 0: return dW, dB |
| cids = active.to(torch.int32).contiguous() |
| BK, BM = 64, 64 |
| _sparse_bwd_dW_db_kernel[(na, triton.cdiv(d_in, BK))]( |
| xf, gyf, dW, dB if has_bias else dW, cids, |
| M, d_in, d_out, na, |
| xf.stride(0), xf.stride(1), gyf.stride(0), gyf.stride(1), |
| dW.stride(0), dW.stride(1), |
| HAS_BIAS=has_bias, CS=cs, BK=BK, BM=BM, num_warps=4) |
| return dW, dB |
|
|
| def triton_bwd_dX(gyf, w, active, cs, M, d_in): |
| na = active.numel() |
| d_out = gyf.shape[1] |
| dX = torch.zeros(M, d_in, device=gyf.device, dtype=gyf.dtype) |
| if na == 0: return dX |
| cids = active.to(torch.int32).contiguous() |
| BM, BK = 64, 64 |
| _sparse_bwd_dX_kernel[(triton.cdiv(M, BM), triton.cdiv(d_in, BK))]( |
| gyf, w, dX, cids, |
| M, d_in, d_out, na, |
| gyf.stride(0), gyf.stride(1), w.stride(0), w.stride(1), |
| dX.stride(0), dX.stride(1), |
| CS=cs, BM=BM, BK=BK, num_warps=4) |
| return dX |
|
|
| |
| |
| |
|
|
| class TritonSparseLinearFn(torch.autograd.Function): |
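| """Dense forward, chunk-sparse backward: dW/dB are computed only for rows in the |
| active output chunks; dX is dense (gy @ W) unless sparse_dx=True, in which case it |
| sums over active chunks only. PyLoopSparseLinearFn below is the reference fallback.""" |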
| @staticmethod |
| def forward(ctx, x, w, b, active, cs, sparse_dx): |
| ctx.save_for_backward(x, w, active) |
| ctx.has_bias = b is not None |
| ctx.sparse_dx = sparse_dx |
| ctx.cs = cs |
| return F.linear(x, w, b) |
|
|
| @staticmethod |
| def backward(ctx, gy): |
| x, w, active = ctx.saved_tensors |
| cs = ctx.cs |
| do, di = w.shape |
| xf = x.reshape(-1, di).contiguous() |
| gf = gy.reshape(-1, do).contiguous() |
| M = xf.shape[0] |
| gw, gb = triton_bwd_dW_db(xf, gf, active, cs, do, ctx.has_bias) |
| gx = triton_bwd_dX(gf, w.contiguous(), active, cs, M, di) if ctx.sparse_dx else gf @ w |
| return gx.reshape(x.shape), gw, gb, None, None, None |
|
|
| class PyLoopSparseLinearFn(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x, w, b, active, cs, sparse_dx): |
| ctx.save_for_backward(x, w, active) |
| ctx.has_bias = b is not None |
| ctx.sparse_dx = sparse_dx |
| ctx.cs = cs |
| return F.linear(x, w, b) |
|
|
| @staticmethod |
| def backward(ctx, gy): |
| x, w, active = ctx.saved_tensors |
| cs = ctx.cs |
| xf = x.reshape(-1, x.shape[-1]) |
| gf = gy.reshape(-1, gy.shape[-1]) |
| gw = torch.zeros_like(w) |
| gb = torch.zeros(w.shape[0], device=w.device, dtype=w.dtype) if ctx.has_bias else None |
| gx = torch.zeros_like(xf) if ctx.sparse_dx else gf @ w |
| for c in active.tolist(): |
| s, e = c * cs, (c+1) * cs |
| sl = gf[:, s:e] |
| gw[s:e] = sl.t() @ xf |
| if gb is not None: gb[s:e] = sl.sum(0) |
| if ctx.sparse_dx: gx += sl @ w[s:e] |
| return gx.reshape(x.shape), gw, gb, None, None, None |
|
|
| |
| |
| |
|
|
| class SparseLinear(nn.Linear): |
| def __init__(self, inf, outf, bias=True): |
| super().__init__(inf, outf, bias=bias) |
| self.sparse_enabled = False |
| self.sparse_dx = False |
| self.active_chunks = None |
| self.chunk_size = 64 |
| self.backend = "triton" |
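| # sparse_enabled / active_chunks / chunk_size / backend are set externally by |
| # ChunkScheduler and run(); with sparsity disabled this is a plain nn.Linear. |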
|
|
| def forward(self, x): |
| if not self.sparse_enabled or self.active_chunks is None: |
| return F.linear(x, self.weight, self.bias) |
| fn = TritonSparseLinearFn if (self.backend == "triton" and HAS_TRITON) else PyLoopSparseLinearFn |
| return fn.apply(x, self.weight, self.bias, self.active_chunks, self.chunk_size, self.sparse_dx) |
|
|
| class Attn(nn.Module): |
| def __init__(self, d, nh, bs, do): |
| super().__init__() |
| self.nh, self.hd = nh, d // nh |
| self.c_attn = SparseLinear(d, 3*d) |
| self.c_proj = SparseLinear(d, d) |
| self.drop = nn.Dropout(do) |
| self.register_buffer("mask", torch.tril(torch.ones(bs,bs)).view(1,1,bs,bs)) |
|
|
| def forward(self, x): |
| B,T,C = x.shape |
| q,k,v = self.c_attn(x).split(C, 2) |
| q = q.view(B,T,self.nh,self.hd).transpose(1,2) |
| k = k.view(B,T,self.nh,self.hd).transpose(1,2) |
| v = v.view(B,T,self.nh,self.hd).transpose(1,2) |
| a = (q @ k.transpose(-2,-1)) / math.sqrt(self.hd) |
| a = a.masked_fill(self.mask[:,:,:T,:T]==0, float("-inf")) |
| a = self.drop(F.softmax(a, dim=-1)) |
| return self.c_proj((a @ v).transpose(1,2).contiguous().view(B,T,C)) |
|
|
| class FFN(nn.Module): |
| def __init__(self, d, do, ffn_mult=4): |
| super().__init__() |
| self.c_fc = SparseLinear(d, ffn_mult * d) |
| self.c_proj = SparseLinear(ffn_mult * d, d) |
| self.drop = nn.Dropout(do) |
| def forward(self, x): |
| return self.drop(self.c_proj(F.gelu(self.c_fc(x)))) |
|
|
| class Block(nn.Module): |
| def __init__(self, d, nh, bs, do, ffn_mult=4): |
| super().__init__() |
| self.ln1 = nn.LayerNorm(d); self.attn = Attn(d, nh, bs, do) |
| self.ln2 = nn.LayerNorm(d); self.mlp = FFN(d, do, ffn_mult) |
| def forward(self, x): |
| x = x + self.attn(self.ln1(x)) |
| return x + self.mlp(self.ln2(x)) |
|
|
| class GPT(nn.Module): |
| def __init__(self, V, bs, nl, nh, d, do, ffn_mult=4): |
| super().__init__() |
| self.te = nn.Embedding(V, d); self.pe = nn.Embedding(bs, d) |
| self.blocks = nn.Sequential(*[Block(d, nh, bs, do, ffn_mult) for _ in range(nl)]) |
| self.ln = nn.LayerNorm(d); self.head = nn.Linear(d, V) |
| def forward(self, idx, tgt=None): |
| B,T = idx.shape |
| x = self.te(idx) + self.pe(torch.arange(T, device=idx.device))[None] |
| lo = self.head(self.ln(self.blocks(x))) |
| loss = F.cross_entropy(lo.view(-1, lo.size(-1)), tgt.view(-1)) if tgt is not None else None |
| return lo, loss |
| def nparams(self): return sum(p.numel() for p in self.parameters()) |
|
|
| def get_sparse_linears(m): return [x for x in m.modules() if isinstance(x, SparseLinear)] |
|
|
| |
| |
| |
|
|
| class Corpus: |
| """Uses tiktoken GPT-2 BPE on Tiny Shakespeare if available, else char-level synthetic.""" |
| _inst = None |
| @classmethod |
| def get(cls, bs, dev): |
| if cls._inst is None or cls._inst.block_size != bs: |
| cls._inst = cls(bs, dev) |
| return cls._inst |
|
|
| def __init__(self, block_size, device): |
| self.block_size, self.device = block_size, device |
| import urllib.request |
| p = "input.txt" |
| if not os.path.exists(p): |
| urllib.request.urlretrieve("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", p) |
| text = open(p, encoding="utf-8").read() |
| if HAS_TIKTOKEN: |
| enc = tiktoken.get_encoding("gpt2") |
| tokens = enc.encode(text) |
| self.vocab_size = enc.n_vocab |
| else: |
| chars = sorted(set(text)) |
| stoi = {c:i for i,c in enumerate(chars)} |
| tokens = [stoi[c] for c in text] |
| self.vocab_size = len(chars) |
| data = torch.tensor(tokens, dtype=torch.long) |
| si = int(0.9 * len(data)) |
| self.train_data, self.val_data = data[:si], data[si:] |
| print(f"Corpus: V={self.vocab_size}, train={len(self.train_data):,}, val={len(self.val_data):,}") |
|
|
| def get_batch(self, split, bs, gen=None): |
| d = self.train_data if split == "train" else self.val_data |
| ix = torch.randint(len(d)-self.block_size-1, (bs,), generator=gen) |
| x = torch.stack([d[i:i+self.block_size] for i in ix]) |
| y = torch.stack([d[i+1:i+self.block_size+1] for i in ix]) |
| return x.to(self.device), y.to(self.device) |
|
|
| def make_gen(s): |
| g = torch.Generator(device="cpu"); g.manual_seed(s); return g |
|
|
| |
| |
| |
|
|
| class ChunkScheduler: |
| def __init__(self, model, policy, frac, cs, dev, beta=0.95, knn_k=3, |
| sim_hist=128, min_sim_hist=8): |
| self.policy, self.frac, self.cs, self.dev = policy, frac, cs, dev |
| self.beta, self.knn_k = beta, knn_k |
| self.sim_hist, self.min_sim_hist = sim_hist, min_sim_hist |
| self.linears = get_sparse_linears(model) |
| self.m2ids, self.m2loc = {}, {} |
| off = 0 |
| for m in self.linears: |
| m.chunk_size = cs |
| nc = m.out_features // cs |
| assert m.out_features % cs == 0 |
| self.m2ids[m] = torch.arange(off, off+nc, device=dev) |
| self.m2loc[m] = torch.arange(nc, device=dev) |
| off += nc |
| self.nc = off |
| self.ema = torch.zeros(self.nc, device=dev) |
| self.active = torch.zeros(self.nc, dtype=torch.bool, device=dev) |
| self.mass_history = [] |
| self.similarity = None |
| self.scores = torch.zeros(self.nc, device=dev) |
|
|
| def get_frac(self, step, wu, an): |
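| """Active-fraction schedule: 1.0 for the first `wu` warmup steps, cosine anneal |
| from 1.0 down to `frac` over the next `an` steps, then constant at `frac`.""" |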
| if step < wu: return 1.0 |
| if an > 0 and step < wu + an: |
| p = (step - wu) / an |
| return self.frac + (1-self.frac) * 0.5 * (1 + math.cos(math.pi * p)) |
| return self.frac |
|
|
| def choose(self, step, wu, an): |
| f = self.get_frac(step, wu, an) |
| if f >= 0.999: |
| self.active.fill_(True) |
| self._install(); return |
| k = max(1, int(f * self.nc)) |
| self.active.fill_(False) |
| if self.policy == "random": |
| idx = torch.randperm(self.nc, device=self.dev)[:k] |
| elif self.policy == "ema": |
| idx = torch.topk(self.ema + 1e-9*torch.rand_like(self.ema), k=k).indices |
| elif self.policy == "knn": |
| base = self.scores if self.scores.sum() > 1e-12 else self.ema |
| idx = torch.topk(base + 1e-9*torch.rand_like(base), k=k).indices |
| else: |
| raise ValueError(self.policy) |
| self.active[idx] = True |
| self._install() |
|
|
| def _install(self): |
| for m, gids in self.m2ids.items(): |
| m.active_chunks = self.m2loc[m][self.active[gids]] |
|
|
| @torch.no_grad() |
| def update(self, step, wu): |
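| """Refresh per-chunk gradient-mass stats from the grads just computed. Only chunks |
| active this step are observed: first observations seed the EMA directly, later ones |
| are blended with weight (1 - beta). During warmup the raw masses are also logged to |
| build the chunk-similarity matrix used by the knn policy.""" |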
| cur = torch.zeros_like(self.ema) |
| for m, ids in self.m2ids.items(): |
| if m.weight.grad is None: continue |
| s = m.weight.grad.square().view(len(ids), self.cs, -1).sum((1,2)) |
| if m.bias is not None and m.bias.grad is not None: |
| s += m.bias.grad.square().view(len(ids), self.cs).sum(1) |
| cur[ids] = torch.sqrt(s + 1e-30) |
| obs = self.active |
| new = obs & (self.ema == 0) |
| old = obs & ~new |
| self.ema[new] = cur[new] |
| self.ema[old] = self.beta*self.ema[old] + (1-self.beta)*cur[old] |
| |
| if step < wu: |
| self.mass_history.append(cur.clone()) |
| if len(self.mass_history) > self.sim_hist: |
| self.mass_history = self.mass_history[-self.sim_hist:] |
| if len(self.mass_history) >= self.min_sim_hist: |
| self.similarity = self._build_sim() |
| if self.policy == "knn": |
| self.scores = self._knn_scores(self.active, cur) |
| else: |
| self.scores = self.ema.clone() |
| return cur |
|
|
| def _build_sim(self): |
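| """Nonnegative correlation of per-chunk gradient mass over the warmup history, |
| masked so only chunks within the same linear layer can be neighbors.""" |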
| H = torch.stack(self.mass_history) |
| H = (H - H.mean(0, keepdim=True)) / (H.std(0, keepdim=True) + 1e-6) |
| S = torch.clamp((H.T @ H) / max(1, H.shape[0]-1), min=0) |
| S.fill_diagonal_(0) |
| ok = torch.zeros_like(S, dtype=torch.bool) |
| for _, ids in self.m2ids.items(): |
| ok[ids[:,None], ids[None,:]] = True |
| return torch.where(ok, S, torch.zeros_like(S)) |
|
|
| def _knn_scores(self, active_mask, cur): |
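| """Impute scores for inactive chunks: each gets a similarity-weighted average of |
| the fresh gradient mass of its knn_k most similar *active* chunks; active chunks |
| keep their freshly observed mass.""" |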
| if self.similarity is None: return self.ema.clone() |
| sc = self.ema.clone() |
| sc[active_mask] = cur[active_mask] |
| aidx = active_mask.nonzero(as_tuple=False).flatten() |
| iidx = (~active_mask).nonzero(as_tuple=False).flatten() |
| if aidx.numel() == 0: return sc |
| S = self.similarity |
| for i in iidx.tolist(): |
| w = S[i, aidx] |
| if w.sum() <= 1e-12: continue |
| kk = min(self.knn_k, w.numel()) |
| top = torch.topk(w, k=kk) |
| sc[i] = (top.values * cur[aidx[top.indices]]).sum() / (top.values.sum() + 1e-12) |
| return sc |
|
|
| @torch.no_grad() |
| def oracle_scores(self): |
| """Compute dense gradient magnitudes per chunk (requires dense grads already computed).""" |
| sc = torch.zeros(self.nc, device=self.dev) |
| for m, ids in self.m2ids.items(): |
| if m.weight.grad is None: continue |
| s = m.weight.grad.square().view(len(ids), self.cs, -1).sum((1,2)) |
| if m.bias is not None and m.bias.grad is not None: |
| s += m.bias.grad.square().view(len(ids), self.cs).sum(1) |
| sc[ids] = torch.sqrt(s + 1e-30) |
| return sc |
|
|
| def measure_overlap(self, k): |
| """Jaccard and recall of current active vs oracle top-k.""" |
| oracle = set(torch.topk(self.oracle_scores(), k=k).indices.tolist()) |
| pred = set(self.active.nonzero(as_tuple=True)[0].tolist()) |
| if not oracle or not pred: return 0., 0. |
| inter = oracle & pred |
| return len(inter)/len(oracle|pred), len(inter)/len(oracle) |
|
|
| |
| |
| |
|
|
| class ChunkedAdam: |
| """ |
| Adam with two modes for inactive chunks: |
| phantom: standard β m,v decay even on zero grad (default, original behavior) |
| frozen: m,v state completely frozen for inactive chunks |
| """ |
| def __init__(self, model, lr=3e-4, cs=64, momentum_mode="phantom"): |
| self.model, self.lr, self.cs = model, lr, cs |
| self.momentum_mode = momentum_mode |
| self.state = {} |
| self.p2m = {} |
| for m in get_sparse_linears(model): |
| if m.weight is not None: self.p2m[m.weight] = m |
| if m.bias is not None: self.p2m[m.bias] = m |
|
|
| def zero_grad(self): |
| for p in self.model.parameters(): p.grad = None |
|
|
| @torch.no_grad() |
| def step(self): |
| for p in self.model.parameters(): |
| if p.grad is None: continue |
| if p not in self.state: |
| self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)} |
| m, v = self.state[p]["m"], self.state[p]["v"] |
| sm = self.p2m.get(p) |
| ac = getattr(sm, 'active_chunks', None) if sm else None |
|
|
| if ac is None: |
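| # non-chunked parameter (embeddings, layernorms, lm head) or dense run: plain Adam |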
| |
| m.mul_(0.9).add_(p.grad, alpha=0.1) |
| v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001) |
| p.sub_(m / (torch.sqrt(v) + 1e-8), alpha=self.lr) |
| else: |
| if self.momentum_mode == "phantom": |
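| # decay m, v over the whole tensor (inactive chunks see zero grads), but only step active chunks |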
| |
| |
| |
| m.mul_(0.9).add_(p.grad, alpha=0.1) |
| v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001) |
| |
| for c in ac.tolist(): |
| s, e = c*self.cs, (c+1)*self.cs |
| p.data[s:e].sub_(m[s:e] / (torch.sqrt(v[s:e]) + 1e-8), alpha=self.lr) |
| elif self.momentum_mode == "frozen": |
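| # touch optimizer state and weights only for the active chunks; inactive state is frozen |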
| |
| for c in ac.tolist(): |
| s, e = c*self.cs, (c+1)*self.cs |
| g = p.grad[s:e] |
| m[s:e].mul_(0.9).add_(g, alpha=0.1) |
| v[s:e].mul_(0.999).addcmul_(g, g, value=0.001) |
| p.data[s:e].sub_(m[s:e] / (torch.sqrt(v[s:e]) + 1e-8), alpha=self.lr) |
|
|
| |
| |
| |
|
|
| @torch.no_grad() |
| def evaluate(model, corpus, bs, n=20, seed=9999): |
| model.eval() |
| losses = [] |
| for i in range(n): |
| _, l = model(*corpus.get_batch("val", bs, make_gen(seed+i))) |
| losses.append(l.item()) |
| model.train() |
| avg = sum(losses)/len(losses) |
| return avg, math.exp(min(avg, 20)) |
|
|
| |
| |
| |
|
|
| def run(policy, bwd_mode, steps, bs, block_size, nl, nh, d, cs, |
| active_frac, wu, an, lr, device, seed, backend="triton", |
| momentum_mode="phantom", ffn_mult=4, |
| measure_oracle=False, oracle_interval=50): |
| """Run one training config. Returns dict of results.""" |
| torch.manual_seed(seed) |
| if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) |
| random.seed(seed) |
|
|
| corpus = Corpus.get(block_size, device) |
| model = GPT(corpus.vocab_size, block_size, nl, nh, d, 0.1, ffn_mult).to(device) |
| for m in get_sparse_linears(model): |
| m.chunk_size = cs |
| m.backend = backend |
|
|
| is_dense = (policy == "dense") |
| sched = None if is_dense else ChunkScheduler(model, policy, active_frac, cs, device) |
| opt = ChunkedAdam(model, lr=lr, cs=cs, momentum_mode=momentum_mode) |
|
|
| np_ = model.nparams() |
| overlaps = [] |
|
|
| if device == "cuda": torch.cuda.synchronize() |
| t0 = time.perf_counter() |
|
|
| for step in range(steps): |
| x, y = corpus.get_batch("train", bs, make_gen(step)) |
|
|
| if is_dense: |
| for m in get_sparse_linears(model): |
| m.sparse_enabled = False; m.active_chunks = None |
| else: |
| sched.choose(step, wu, an) |
| for m in get_sparse_linears(model): |
| m.sparse_enabled = True |
| m.sparse_dx = (bwd_mode == "sparse_dX") |
|
|
| opt.zero_grad() |
| _, loss = model(x, y) |
| loss.backward() |
|
|
| if sched: |
| sched.update(step, wu) |
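| # Periodic oracle check (below): rerun a dense backward on the same batch, compare |
| # the scheduler's active set against the true top-k chunks, then restore the sparse |
| # grads so the optimizer step is unaffected. |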
|
|
| |
| if measure_oracle and step % oracle_interval == 0 and step >= wu + an: |
| saved = {p: p.grad.clone() for p in model.parameters() if p.grad is not None} |
| for m in get_sparse_linears(model): m.sparse_enabled = False |
| for p in model.parameters(): p.grad = None |
| _, lo = model(x, y); lo.backward() |
| k = max(1, int(active_frac * sched.nc)) |
| j, r = sched.measure_overlap(k) |
| overlaps.append((step, j, r)) |
| for p in model.parameters(): |
| if p in saved: p.grad = saved[p] |
| for m in get_sparse_linears(model): m.sparse_enabled = True |
|
|
| opt.step() |
|
|
| if step % 200 == 0: |
| print(f" step {step}/{steps} loss={loss.item():.4f}") |
|
|
| if device == "cuda": torch.cuda.synchronize() |
| wall = time.perf_counter() - t0 |
|
|
| for m in get_sparse_linears(model): m.sparse_enabled = False |
| vl, vp = evaluate(model, corpus, bs, n=30) |
|
|
| del model |
| if device == "cuda": torch.cuda.empty_cache() |
|
|
| return { |
| "val_loss": vl, "val_ppl": vp, "wall_time": wall, |
| "ms_per_step": 1000*wall/steps, "n_params": np_, |
| "train_loss_final": loss.item(), "overlaps": overlaps, |
| } |
|
|
| def run_seeds(cfg, seeds): |
| results = [] |
| for s in seeds: |
| cfg["seed"] = s |
| results.append(run(**cfg)) |
| vls = [r["val_loss"] for r in results] |
| ml = sum(vls)/len(vls) |
| sl = (sum((x-ml)**2 for x in vls)/max(1,len(vls)-1))**0.5 |
| return {"mean_loss": ml, "std_loss": sl, "results": results, |
| "mean_ms": sum(r["ms_per_step"] for r in results)/len(results)} |
|
|
| |
| |
| |
|
|
| def exp_phantom_momentum(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend): |
| print("\n" + "="*80) |
| print("EXPERIMENT 1: Phantom Momentum Ablation") |
| print("="*80) |
|
|
| base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size, |
| nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an, |
| lr=lr, device=device, backend=backend) |
|
|
| configs = [ |
| ("dense", "dense", "phantom"), |
| ("ema+phantom", "ema", "phantom"), |
| ("ema+frozen", "ema", "frozen"), |
| ("knn+phantom", "knn", "phantom"), |
| ("knn+frozen", "knn", "frozen"), |
| ("random+phantom", "random", "phantom"), |
| ("random+frozen", "random", "frozen"), |
| ] |
|
|
| results = {} |
| for name, policy, mm in configs: |
| print(f"\n--- {name} ---") |
| cfg = {**base, "policy": policy, "momentum_mode": mm} |
| results[name] = run_seeds(cfg, seeds) |
|
|
| print(f"\n{'Method':<22} | {'Val Loss':>18} | {'ms/step':>10}") |
| print("-"*55) |
| for name, _, _ in configs: |
| r = results[name] |
| print(f"{name:<22} | {r['mean_loss']:.4f} Β± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}") |
|
|
| return results |
|
|
| |
| |
| |
|
|
| def exp_compute_matched(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend): |
| print("\n" + "="*80) |
| print("EXPERIMENT 2: Compute-Matched Baselines") |
| print("="*80) |
|
|
| base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size, |
| nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an, |
| lr=lr, device=device, backend=backend, momentum_mode="phantom") |
|
|
| |
| print("\n--- Sparse (EMA, reference) ---") |
| sparse_r = run_seeds({**base, "policy": "ema"}, seeds) |
|
|
| |
| print("\n--- Dense (same steps) ---") |
| dense_same = run_seeds({**base, "policy": "dense"}, seeds) |
|
|
| |
| |
| ratio = (1.0 + 1.0 + af) / 3.0 |
| matched_steps = int(steps * ratio) |
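| # Approximate FLOP matching: with bwd_mode="full_dX" the forward and dX matmuls stay |
| # dense while dW/dB shrink to fraction af, so one sparse step costs ~(1 + 1 + af)/3 of |
| # a dense step; the dense baseline therefore runs that fraction of the steps. |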
| print(f"\n--- Dense (compute-matched, {matched_steps} steps) ---") |
| dense_matched = run_seeds({**base, "policy": "dense", "steps": matched_steps}, seeds) |
|
|
| |
| |
| small_ffn_mult = max(1, round(4 * af)) |
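| # Capacity-matched baseline: shrink the FFN expansion toward 4*af (floored at 1) so |
| # the dense model's FFN width is comparable to the sparse model's active width. |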
| print(f"\n--- Small dense (ffn_mult={small_ffn_mult}, capacity-matched) ---") |
| dense_small = run_seeds({**base, "policy": "dense", "ffn_mult": small_ffn_mult}, seeds) |
|
|
| results = { |
| "sparse_ema": sparse_r, |
| "dense_same_steps": dense_same, |
| f"dense_matched_{matched_steps}steps": dense_matched, |
| f"dense_small_ffn{small_ffn_mult}": dense_small, |
| } |
|
|
| print(f"\n{'Method':<35} | {'Steps':>6} | {'Params':>8} | {'Val Loss':>18} | {'ms/step':>10}") |
| print("-"*90) |
| for name, r in results.items(): |
| np_ = r["results"][0]["n_params"] |
| st = r["results"][0].get("steps", steps) if "steps" in name else steps |
| |
| print(f"{name:<35} | {st if 'matched' not in name else matched_steps:>6} | {np_/1e6:>7.1f}M | {r['mean_loss']:.4f} Β± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}") |
|
|
| return results |
|
|
| |
| |
| |
|
|
| def exp_predictor_accuracy(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend): |
| print("\n" + "="*80) |
| print("EXPERIMENT 3: Predictor Accuracy (EMA vs KNN vs Oracle)") |
| print("="*80) |
|
|
| base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size, |
| nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an, |
| lr=lr, device=device, backend=backend, momentum_mode="phantom", |
| measure_oracle=True, oracle_interval=25) |
|
|
| results = {} |
| for policy in ["ema", "knn", "random"]: |
| print(f"\n--- {policy} ---") |
| results[policy] = run_seeds({**base, "policy": policy}, seeds) |
|
|
| |
| for policy in ["ema", "knn", "random"]: |
| print(f"\n{policy.upper()} predictor overlap:") |
| print(f" {'Step':>6} | {'Jaccard':>10} | {'Recall':>10}") |
| sd = defaultdict(lambda: {"j": [], "r": []}) |
| for res in results[policy]["results"]: |
| for s, j, r in res["overlaps"]: |
| sd[s]["j"].append(j); sd[s]["r"].append(r) |
| for s in sorted(sd): |
| mj = sum(sd[s]["j"])/len(sd[s]["j"]) |
| mr = sum(sd[s]["r"])/len(sd[s]["r"]) |
| print(f" {s:>6} | {mj:>10.4f} | {mr:>10.4f}") |
|
|
| print(f"\n{'Policy':<10} | {'Val Loss':>18} | {'ms/step':>10}") |
| print("-"*45) |
| for p in ["ema", "knn", "random"]: |
| r = results[p] |
| print(f"{p:<10} | {r['mean_loss']:.4f} Β± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}") |
|
|
| return results |
|
|
| |
| |
| |
|
|
| ALL_EXPS = { |
| "phantom_momentum": exp_phantom_momentum, |
| "compute_matched": exp_compute_matched, |
| "predictor_accuracy": exp_predictor_accuracy, |
| } |
|
|
| def main(): |
| p = argparse.ArgumentParser() |
| p.add_argument("--experiment", default="all", choices=list(ALL_EXPS)+["all"]) |
| p.add_argument("--device", default="cuda") |
| p.add_argument("--steps", type=int, default=1000) |
| p.add_argument("--seeds", default="42,123,456") |
| p.add_argument("--n_embd", type=int, default=1024) |
| p.add_argument("--n_layer", type=int, default=4) |
| p.add_argument("--n_head", type=int, default=8) |
| p.add_argument("--batch_size", type=int, default=8) |
| p.add_argument("--block_size", type=int, default=256) |
| p.add_argument("--chunk_size", type=int, default=64) |
| p.add_argument("--active_fraction", type=float, default=0.10) |
| p.add_argument("--warmup_steps", type=int, default=50) |
| p.add_argument("--anneal_steps", type=int, default=200) |
| p.add_argument("--lr", type=float, default=3e-4) |
| p.add_argument("--backend", default="triton", choices=["triton", "torch"]) |
| p.add_argument("--output_dir", default="results") |
| args = p.parse_args() |
|
|
| seeds = [int(s) for s in args.seeds.split(",")] |
| os.makedirs(args.output_dir, exist_ok=True) |
|
|
| if args.device == "cuda" and torch.cuda.is_available(): |
| print(f"GPU: {torch.cuda.get_device_name()} | VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB") |
| print(f"Config: d={args.n_embd} nl={args.n_layer} nh={args.n_head} steps={args.steps} seeds={seeds}") |
| print(f" cs={args.chunk_size} af={args.active_fraction} backend={args.backend}") |
|
|
| shared = dict(device=args.device, steps=args.steps, seeds=seeds, |
| d=args.n_embd, nl=args.n_layer, nh=args.n_head, |
| bs=args.batch_size, block_size=args.block_size, |
| cs=args.chunk_size, af=args.active_fraction, |
| wu=args.warmup_steps, an=args.anneal_steps, |
| lr=args.lr, backend=args.backend) |
|
|
| exps = ALL_EXPS if args.experiment == "all" else {args.experiment: ALL_EXPS[args.experiment]} |
| t0 = time.time() |
|
|
| for name, fn in exps.items(): |
| print(f"\n{'#'*80}\n# {name} ({(time.time()-t0)/60:.1f}m elapsed)\n{'#'*80}") |
| sys.stdout.flush() |
| result = fn(**shared) |
|
|
| def ser(o): |
| if isinstance(o, dict): return {str(k): ser(v) for k,v in o.items()} |
| if isinstance(o, list): return [ser(x) for x in o] |
| return o |
|
|
| with open(os.path.join(args.output_dir, f"{name}.json"), "w") as f: |
| json.dump(ser(result), f, indent=2, default=str) |
| print(f"β {name} saved to {args.output_dir}/{name}.json") |
|
|
| print(f"\n{'='*80}\nALL COMPLETE in {(time.time()-t0)/60:.1f} minutes\n{'='*80}") |
|
|
| if __name__ == "__main__": |
| main() |
|
|