#!/usr/bin/env python3 """ Sparse Transformer: Definitive Ablation Suite Builds on v18_fast_knn_triton.py. Addresses all three structural gaps identified in the critique: 1. PHANTOM MOMENTUM ABLATION - "phantom": standard Adam — inactive chunks' moments decay on zero grad (default) - "frozen": inactive chunks' Adam state (m, v) is completely frozen Compare across all schedulers to isolate whether convergence is driven by the chunking algorithm or by phantom momentum acting as regularization. 2. COMPUTE-MATCHED BASELINES - Dense at same steps (standard comparison) - Dense at fewer steps matching sparse FLOPs - Natively smaller dense model matching sparse active capacity 3. UNIFIED HARDWARE Everything on CUDA (A10G). Single hardware stack. Plus: KNN vs EMA vs Random vs Oracle predictor comparison with proper oracle overlap measurement. Run: python ablations.py --device cuda --steps 1000 --n_embd 1024 --experiment all python ablations.py --device cuda --experiment phantom_momentum python ablations.py --device cuda --experiment compute_matched python ablations.py --device cuda --experiment predictor_accuracy """ from __future__ import annotations import argparse import json import math import os import random import sys import time from collections import defaultdict from typing import Dict, List, Literal, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F try: import triton import triton.language as tl HAS_TRITON = True except ImportError: HAS_TRITON = False try: import tiktoken HAS_TIKTOKEN = True except ImportError: HAS_TIKTOKEN = False # ═══════════════════════════════════════════════════════════════ # TRITON KERNELS (from v18_triton, no autotune, block_ptr) # ═══════════════════════════════════════════════════════════════ if HAS_TRITON: @triton.jit def _sparse_bwd_dW_db_kernel( X_ptr, dY_ptr, dW_ptr, dB_ptr, chunk_ids_ptr, M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr, num_active: tl.constexpr, stride_xm: tl.constexpr, stride_xk: tl.constexpr, stride_dym: tl.constexpr, stride_dyn: tl.constexpr, stride_dwn: tl.constexpr, stride_dwk: tl.constexpr, HAS_BIAS: tl.constexpr, CS: tl.constexpr, BK: tl.constexpr, BM: tl.constexpr, ): cli = tl.program_id(0) kbi = tl.program_id(1) cidx = tl.load(chunk_ids_ptr + cli) cs0 = cidx * CS ko = kbi * BK dy_bp = tl.make_block_ptr(dY_ptr, (d_out, M), (stride_dyn, stride_dym), (cs0, 0), (CS, BM), (1, 0)) x_bp = tl.make_block_ptr(X_ptr, (M, d_in), (stride_xm, stride_xk), (0, ko), (BM, BK), (1, 0)) acc = tl.zeros((CS, BK), dtype=tl.float32) do_bias = HAS_BIAS and (kbi == 0) acc_b = tl.zeros((CS,), dtype=tl.float32) for _ in range(0, M, BM): dy_t = tl.load(dy_bp, boundary_check=(0, 1)) x = tl.load(x_bp, boundary_check=(0, 1)) acc = tl.dot(dy_t, x, acc=acc) if do_bias: acc_b += tl.sum(dy_t, axis=1) dy_bp = tl.advance(dy_bp, (0, BM)) x_bp = tl.advance(x_bp, (BM, 0)) dw_bp = tl.make_block_ptr(dW_ptr, (d_out, d_in), (stride_dwn, stride_dwk), (cs0, ko), (CS, BK), (1, 0)) tl.store(dw_bp, acc.to(dW_ptr.dtype.element_ty), boundary_check=(0, 1)) if do_bias: rn = cs0 + tl.arange(0, CS) tl.store(dB_ptr + rn, acc_b.to(dB_ptr.dtype.element_ty), mask=rn < d_out) @triton.jit def _sparse_bwd_dX_kernel( dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr, M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr, num_active: tl.constexpr, stride_dym: tl.constexpr, stride_dyn: tl.constexpr, stride_wn: tl.constexpr, stride_wk: tl.constexpr, stride_dxm: tl.constexpr, stride_dxk: tl.constexpr, CS: tl.constexpr, BM: tl.constexpr, BK: tl.constexpr, ): pm = tl.program_id(0) pk = tl.program_id(1) mo = pm * BM ko = pk * BK acc = tl.zeros((BM, BK), dtype=tl.float32) for i in range(0, num_active): cidx = tl.load(chunk_ids_ptr + i) cs0 = cidx * CS dy_bp = tl.make_block_ptr(dY_ptr, (M, d_out), (stride_dym, stride_dyn), (mo, cs0), (BM, CS), (1, 0)) w_bp = tl.make_block_ptr(W_ptr, (d_out, d_in), (stride_wn, stride_wk), (cs0, ko), (CS, BK), (1, 0)) dy = tl.load(dy_bp, boundary_check=(0, 1)) w = tl.load(w_bp, boundary_check=(0, 1)) acc = tl.dot(dy, w, acc=acc) dx_bp = tl.make_block_ptr(dX_ptr, (M, d_in), (stride_dxm, stride_dxk), (mo, ko), (BM, BK), (1, 0)) tl.store(dx_bp, acc.to(dX_ptr.dtype.element_ty), boundary_check=(0, 1)) def triton_bwd_dW_db(xf, gyf, active, cs, d_out, has_bias): M, d_in = xf.shape na = active.numel() dW = torch.zeros(d_out, d_in, device=xf.device, dtype=xf.dtype) dB = torch.zeros(d_out, device=xf.device, dtype=xf.dtype) if has_bias else None if na == 0: return dW, dB cids = active.to(torch.int32).contiguous() BK, BM = 64, 64 _sparse_bwd_dW_db_kernel[(na, triton.cdiv(d_in, BK))]( xf, gyf, dW, dB if has_bias else dW, cids, M, d_in, d_out, na, xf.stride(0), xf.stride(1), gyf.stride(0), gyf.stride(1), dW.stride(0), dW.stride(1), HAS_BIAS=has_bias, CS=cs, BK=BK, BM=BM, num_warps=4) return dW, dB def triton_bwd_dX(gyf, w, active, cs, M, d_in): na = active.numel() d_out = gyf.shape[1] dX = torch.zeros(M, d_in, device=gyf.device, dtype=gyf.dtype) if na == 0: return dX cids = active.to(torch.int32).contiguous() BM, BK = 64, 64 _sparse_bwd_dX_kernel[(triton.cdiv(M, BM), triton.cdiv(d_in, BK))]( gyf, w, dX, cids, M, d_in, d_out, na, gyf.stride(0), gyf.stride(1), w.stride(0), w.stride(1), dX.stride(0), dX.stride(1), CS=cs, BM=BM, BK=BK, num_warps=4) return dX # ═══════════════════════════════════════════════════════════════ # AUTOGRAD # ═══════════════════════════════════════════════════════════════ class TritonSparseLinearFn(torch.autograd.Function): @staticmethod def forward(ctx, x, w, b, active, cs, sparse_dx): ctx.save_for_backward(x, w, active) ctx.has_bias = b is not None ctx.sparse_dx = sparse_dx ctx.cs = cs return F.linear(x, w, b) @staticmethod def backward(ctx, gy): x, w, active = ctx.saved_tensors cs = ctx.cs do, di = w.shape xf = x.reshape(-1, di).contiguous() gf = gy.reshape(-1, do).contiguous() M = xf.shape[0] gw, gb = triton_bwd_dW_db(xf, gf, active, cs, do, ctx.has_bias) gx = triton_bwd_dX(gf, w.contiguous(), active, cs, M, di) if ctx.sparse_dx else gf @ w return gx.reshape(x.shape), gw, gb, None, None, None class PyLoopSparseLinearFn(torch.autograd.Function): @staticmethod def forward(ctx, x, w, b, active, cs, sparse_dx): ctx.save_for_backward(x, w, active) ctx.has_bias = b is not None ctx.sparse_dx = sparse_dx ctx.cs = cs return F.linear(x, w, b) @staticmethod def backward(ctx, gy): x, w, active = ctx.saved_tensors cs = ctx.cs xf = x.reshape(-1, x.shape[-1]) gf = gy.reshape(-1, gy.shape[-1]) gw = torch.zeros_like(w) gb = torch.zeros(w.shape[0], device=w.device, dtype=w.dtype) if ctx.has_bias else None gx = torch.zeros_like(xf) if ctx.sparse_dx else gf @ w for c in active.tolist(): s, e = c * cs, (c+1) * cs sl = gf[:, s:e] gw[s:e] = sl.t() @ xf if gb is not None: gb[s:e] = sl.sum(0) if ctx.sparse_dx: gx += sl @ w[s:e] return gx.reshape(x.shape), gw, gb, None, None, None # ═══════════════════════════════════════════════════════════════ # MODEL # ═══════════════════════════════════════════════════════════════ class SparseLinear(nn.Linear): def __init__(self, inf, outf, bias=True): super().__init__(inf, outf, bias=bias) self.sparse_enabled = False self.sparse_dx = False self.active_chunks = None self.chunk_size = 64 self.backend = "triton" # "triton" or "torch" def forward(self, x): if not self.sparse_enabled or self.active_chunks is None: return F.linear(x, self.weight, self.bias) fn = TritonSparseLinearFn if (self.backend == "triton" and HAS_TRITON) else PyLoopSparseLinearFn return fn.apply(x, self.weight, self.bias, self.active_chunks, self.chunk_size, self.sparse_dx) class Attn(nn.Module): def __init__(self, d, nh, bs, do): super().__init__() self.nh, self.hd = nh, d // nh self.c_attn = SparseLinear(d, 3*d) self.c_proj = SparseLinear(d, d) self.drop = nn.Dropout(do) self.register_buffer("mask", torch.tril(torch.ones(bs,bs)).view(1,1,bs,bs)) def forward(self, x): B,T,C = x.shape q,k,v = self.c_attn(x).split(C, 2) q = q.view(B,T,self.nh,self.hd).transpose(1,2) k = k.view(B,T,self.nh,self.hd).transpose(1,2) v = v.view(B,T,self.nh,self.hd).transpose(1,2) a = (q @ k.transpose(-2,-1)) / math.sqrt(self.hd) a = a.masked_fill(self.mask[:,:,:T,:T]==0, float("-inf")) a = self.drop(F.softmax(a, dim=-1)) return self.c_proj((a @ v).transpose(1,2).contiguous().view(B,T,C)) class FFN(nn.Module): def __init__(self, d, do, ffn_mult=4): super().__init__() self.c_fc = SparseLinear(d, ffn_mult * d) self.c_proj = SparseLinear(ffn_mult * d, d) self.drop = nn.Dropout(do) def forward(self, x): return self.drop(self.c_proj(F.gelu(self.c_fc(x)))) class Block(nn.Module): def __init__(self, d, nh, bs, do, ffn_mult=4): super().__init__() self.ln1 = nn.LayerNorm(d); self.attn = Attn(d, nh, bs, do) self.ln2 = nn.LayerNorm(d); self.mlp = FFN(d, do, ffn_mult) def forward(self, x): x = x + self.attn(self.ln1(x)) return x + self.mlp(self.ln2(x)) class GPT(nn.Module): def __init__(self, V, bs, nl, nh, d, do, ffn_mult=4): super().__init__() self.te = nn.Embedding(V, d); self.pe = nn.Embedding(bs, d) self.blocks = nn.Sequential(*[Block(d, nh, bs, do, ffn_mult) for _ in range(nl)]) self.ln = nn.LayerNorm(d); self.head = nn.Linear(d, V) def forward(self, idx, tgt=None): B,T = idx.shape x = self.te(idx) + self.pe(torch.arange(T, device=idx.device))[None] lo = self.head(self.ln(self.blocks(x))) loss = F.cross_entropy(lo.view(-1, lo.size(-1)), tgt.view(-1)) if tgt is not None else None return lo, loss def nparams(self): return sum(p.numel() for p in self.parameters()) def get_sparse_linears(m): return [x for x in m.modules() if isinstance(x, SparseLinear)] # ═══════════════════════════════════════════════════════════════ # DATA # ═══════════════════════════════════════════════════════════════ class Corpus: """Uses tiktoken GPT-2 BPE on Tiny Shakespeare if available, else char-level synthetic.""" _inst = None @classmethod def get(cls, bs, dev): if cls._inst is None or cls._inst.block_size != bs: cls._inst = cls(bs, dev) return cls._inst def __init__(self, block_size, device): self.block_size, self.device = block_size, device import urllib.request p = "input.txt" if not os.path.exists(p): urllib.request.urlretrieve("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", p) text = open(p).read() if HAS_TIKTOKEN: enc = tiktoken.get_encoding("gpt2") tokens = enc.encode(text) self.vocab_size = enc.n_vocab else: chars = sorted(set(text)) stoi = {c:i for i,c in enumerate(chars)} tokens = [stoi[c] for c in text] self.vocab_size = len(chars) data = torch.tensor(tokens, dtype=torch.long) si = int(0.9 * len(data)) self.train_data, self.val_data = data[:si], data[si:] print(f"Corpus: V={self.vocab_size}, train={len(self.train_data):,}, val={len(self.val_data):,}") def get_batch(self, split, bs, gen=None): d = self.train_data if split == "train" else self.val_data ix = torch.randint(len(d)-self.block_size-1, (bs,), generator=gen) x = torch.stack([d[i:i+self.block_size] for i in ix]) y = torch.stack([d[i+1:i+self.block_size+1] for i in ix]) return x.to(self.device), y.to(self.device) def make_gen(s): g = torch.Generator(device="cpu"); g.manual_seed(s); return g # ═══════════════════════════════════════════════════════════════ # SCHEDULER (from v18, with KNN) # ═══════════════════════════════════════════════════════════════ class ChunkScheduler: def __init__(self, model, policy, frac, cs, dev, beta=0.95, knn_k=3, sim_hist=128, min_sim_hist=8): self.policy, self.frac, self.cs, self.dev = policy, frac, cs, dev self.beta, self.knn_k = beta, knn_k self.sim_hist, self.min_sim_hist = sim_hist, min_sim_hist self.linears = get_sparse_linears(model) self.m2ids, self.m2loc = {}, {} off = 0 for m in self.linears: m.chunk_size = cs nc = m.out_features // cs assert m.out_features % cs == 0 self.m2ids[m] = torch.arange(off, off+nc, device=dev) self.m2loc[m] = torch.arange(nc, device=dev) off += nc self.nc = off self.ema = torch.zeros(self.nc, device=dev) self.active = torch.zeros(self.nc, dtype=torch.bool, device=dev) self.mass_history = [] self.similarity = None self.scores = torch.zeros(self.nc, device=dev) def get_frac(self, step, wu, an): if step < wu: return 1.0 if an > 0 and step < wu + an: p = (step - wu) / an return self.frac + (1-self.frac) * 0.5 * (1 + math.cos(math.pi * p)) return self.frac def choose(self, step, wu, an): f = self.get_frac(step, wu, an) if f >= 0.999: self.active.fill_(True) self._install(); return k = max(1, int(f * self.nc)) self.active.fill_(False) if self.policy == "random": idx = torch.randperm(self.nc, device=self.dev)[:k] elif self.policy == "ema": idx = torch.topk(self.ema + 1e-9*torch.rand_like(self.ema), k=k).indices elif self.policy == "knn": base = self.scores if self.scores.sum() > 1e-12 else self.ema idx = torch.topk(base + 1e-9*torch.rand_like(base), k=k).indices else: raise ValueError(self.policy) self.active[idx] = True self._install() def _install(self): for m, gids in self.m2ids.items(): m.active_chunks = self.m2loc[m][self.active[gids]] @torch.no_grad() def update(self, step, wu): cur = torch.zeros_like(self.ema) for m, ids in self.m2ids.items(): if m.weight.grad is None: continue s = m.weight.grad.square().view(len(ids), self.cs, -1).sum((1,2)) if m.bias is not None and m.bias.grad is not None: s += m.bias.grad.square().view(len(ids), self.cs).sum(1) cur[ids] = torch.sqrt(s + 1e-30) obs = self.active new = obs & (self.ema == 0) old = obs & ~new self.ema[new] = cur[new] self.ema[old] = self.beta*self.ema[old] + (1-self.beta)*cur[old] # KNN similarity building during warmup if step < wu: self.mass_history.append(cur.clone()) if len(self.mass_history) > self.sim_hist: self.mass_history = self.mass_history[-self.sim_hist:] if len(self.mass_history) >= self.min_sim_hist: self.similarity = self._build_sim() if self.policy == "knn": self.scores = self._knn_scores(self.active, cur) else: self.scores = self.ema.clone() return cur def _build_sim(self): H = torch.stack(self.mass_history) H = (H - H.mean(0, keepdim=True)) / (H.std(0, keepdim=True) + 1e-6) S = torch.clamp((H.T @ H) / max(1, H.shape[0]-1), min=0) S.fill_diagonal_(0) ok = torch.zeros_like(S, dtype=torch.bool) for _, ids in self.m2ids.items(): ok[ids[:,None], ids[None,:]] = True return torch.where(ok, S, torch.zeros_like(S)) def _knn_scores(self, active_mask, cur): if self.similarity is None: return self.ema.clone() sc = self.ema.clone() sc[active_mask] = cur[active_mask] aidx = active_mask.nonzero(as_tuple=False).flatten() iidx = (~active_mask).nonzero(as_tuple=False).flatten() if aidx.numel() == 0: return sc S = self.similarity for i in iidx.tolist(): w = S[i, aidx] if w.sum() <= 1e-12: continue kk = min(self.knn_k, w.numel()) top = torch.topk(w, k=kk) sc[i] = (top.values * cur[aidx[top.indices]]).sum() / (top.values.sum() + 1e-12) return sc @torch.no_grad() def oracle_scores(self): """Compute dense gradient magnitudes per chunk (requires dense grads already computed).""" sc = torch.zeros(self.nc, device=self.dev) for m, ids in self.m2ids.items(): if m.weight.grad is None: continue s = m.weight.grad.square().view(len(ids), self.cs, -1).sum((1,2)) if m.bias is not None and m.bias.grad is not None: s += m.bias.grad.square().view(len(ids), self.cs).sum(1) sc[ids] = torch.sqrt(s + 1e-30) return sc def measure_overlap(self, k): """Jaccard and recall of current active vs oracle top-k.""" oracle = set(torch.topk(self.oracle_scores(), k=k).indices.tolist()) pred = set(self.active.nonzero(as_tuple=True)[0].tolist()) if not oracle or not pred: return 0., 0. inter = oracle & pred return len(inter)/len(oracle|pred), len(inter)/len(oracle) # ═══════════════════════════════════════════════════════════════ # CHUNKED ADAM WITH PHANTOM/FROZEN MODES # ═══════════════════════════════════════════════════════════════ class ChunkedAdam: """ Adam with two modes for inactive chunks: phantom: standard — m,v decay even on zero grad (default, original behavior) frozen: m,v state completely frozen for inactive chunks """ def __init__(self, model, lr=3e-4, cs=64, momentum_mode="phantom"): self.model, self.lr, self.cs = model, lr, cs self.momentum_mode = momentum_mode # "phantom" or "frozen" self.state = {} self.p2m = {} for m in get_sparse_linears(model): if m.weight is not None: self.p2m[m.weight] = m if m.bias is not None: self.p2m[m.bias] = m def zero_grad(self): for p in self.model.parameters(): p.grad = None @torch.no_grad() def step(self): for p in self.model.parameters(): if p.grad is None: continue if p not in self.state: self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)} m, v = self.state[p]["m"], self.state[p]["v"] sm = self.p2m.get(p) ac = getattr(sm, 'active_chunks', None) if sm else None if ac is None: # Dense parameter (LN, embeddings, lm_head) — always full update m.mul_(0.9).add_(p.grad, alpha=0.1) v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001) p.sub_(m / (torch.sqrt(v) + 1e-8), alpha=self.lr) else: if self.momentum_mode == "phantom": # PHANTOM: update ALL chunks' moments, but only active get real gradients. # Inactive chunks see grad=0, so m decays and v decays. # This is the original behavior. m.mul_(0.9).add_(p.grad, alpha=0.1) v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001) # But only update weights for active chunks for c in ac.tolist(): s, e = c*self.cs, (c+1)*self.cs p.data[s:e].sub_(m[s:e] / (torch.sqrt(v[s:e]) + 1e-8), alpha=self.lr) elif self.momentum_mode == "frozen": # FROZEN: only touch m,v,p for active chunks. Inactive state is untouched. for c in ac.tolist(): s, e = c*self.cs, (c+1)*self.cs g = p.grad[s:e] m[s:e].mul_(0.9).add_(g, alpha=0.1) v[s:e].mul_(0.999).addcmul_(g, g, value=0.001) p.data[s:e].sub_(m[s:e] / (torch.sqrt(v[s:e]) + 1e-8), alpha=self.lr) # ═══════════════════════════════════════════════════════════════ # EVALUATION # ═══════════════════════════════════════════════════════════════ @torch.no_grad() def evaluate(model, corpus, bs, n=20, seed=9999): model.eval() losses = [] for i in range(n): _, l = model(*corpus.get_batch("val", bs, make_gen(seed+i))) losses.append(l.item()) model.train() avg = sum(losses)/len(losses) return avg, math.exp(min(avg, 20)) # ═══════════════════════════════════════════════════════════════ # SINGLE TRAINING RUN # ═══════════════════════════════════════════════════════════════ def run(policy, bwd_mode, steps, bs, block_size, nl, nh, d, cs, active_frac, wu, an, lr, device, seed, backend="triton", momentum_mode="phantom", ffn_mult=4, measure_oracle=False, oracle_interval=50): """Run one training config. Returns dict of results.""" torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) random.seed(seed) corpus = Corpus.get(block_size, device) model = GPT(corpus.vocab_size, block_size, nl, nh, d, 0.1, ffn_mult).to(device) for m in get_sparse_linears(model): m.chunk_size = cs m.backend = backend is_dense = (policy == "dense") sched = None if is_dense else ChunkScheduler(model, policy, active_frac, cs, device) opt = ChunkedAdam(model, lr=lr, cs=cs, momentum_mode=momentum_mode) np_ = model.nparams() overlaps = [] torch.cuda.synchronize() if device == "cuda" else None t0 = time.perf_counter() for step in range(steps): x, y = corpus.get_batch("train", bs, make_gen(step)) if is_dense: for m in get_sparse_linears(model): m.sparse_enabled = False; m.active_chunks = None else: sched.choose(step, wu, an) for m in get_sparse_linears(model): m.sparse_enabled = True m.sparse_dx = (bwd_mode == "sparse_dX") opt.zero_grad() _, loss = model(x, y) loss.backward() if sched: sched.update(step, wu) # Oracle overlap measurement if measure_oracle and step % oracle_interval == 0 and step >= wu + an: saved = {p: p.grad.clone() for p in model.parameters() if p.grad is not None} for m in get_sparse_linears(model): m.sparse_enabled = False for p in model.parameters(): p.grad = None _, lo = model(x, y); lo.backward() k = max(1, int(active_frac * sched.nc)) j, r = sched.measure_overlap(k) overlaps.append((step, j, r)) for p in model.parameters(): if p in saved: p.grad = saved[p] for m in get_sparse_linears(model): m.sparse_enabled = True opt.step() if step % 200 == 0: print(f" step {step}/{steps} loss={loss.item():.4f}") torch.cuda.synchronize() if device == "cuda" else None wall = time.perf_counter() - t0 for m in get_sparse_linears(model): m.sparse_enabled = False vl, vp = evaluate(model, corpus, bs, n=30) del model; torch.cuda.empty_cache() if device == "cuda" else None return { "val_loss": vl, "val_ppl": vp, "wall_time": wall, "ms_per_step": 1000*wall/steps, "n_params": np_, "train_loss_final": loss.item(), "overlaps": overlaps, } def run_seeds(cfg, seeds): results = [] for s in seeds: cfg["seed"] = s results.append(run(**cfg)) vls = [r["val_loss"] for r in results] ml = sum(vls)/len(vls) sl = (sum((x-ml)**2 for x in vls)/max(1,len(vls)-1))**0.5 return {"mean_loss": ml, "std_loss": sl, "results": results, "mean_ms": sum(r["ms_per_step"] for r in results)/len(results)} # ═══════════════════════════════════════════════════════════════ # EXPERIMENT 1: PHANTOM MOMENTUM ABLATION # ═══════════════════════════════════════════════════════════════ def exp_phantom_momentum(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend): print("\n" + "="*80) print("EXPERIMENT 1: Phantom Momentum Ablation") print("="*80) base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size, nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an, lr=lr, device=device, backend=backend) configs = [ ("dense", "dense", "phantom"), ("ema+phantom", "ema", "phantom"), ("ema+frozen", "ema", "frozen"), ("knn+phantom", "knn", "phantom"), ("knn+frozen", "knn", "frozen"), ("random+phantom", "random", "phantom"), ("random+frozen", "random", "frozen"), ] results = {} for name, policy, mm in configs: print(f"\n--- {name} ---") cfg = {**base, "policy": policy, "momentum_mode": mm} results[name] = run_seeds(cfg, seeds) print(f"\n{'Method':<22} | {'Val Loss':>18} | {'ms/step':>10}") print("-"*55) for name, _, _ in configs: r = results[name] print(f"{name:<22} | {r['mean_loss']:.4f} ± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}") return results # ═══════════════════════════════════════════════════════════════ # EXPERIMENT 2: COMPUTE-MATCHED BASELINES # ═══════════════════════════════════════════════════════════════ def exp_compute_matched(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend): print("\n" + "="*80) print("EXPERIMENT 2: Compute-Matched Baselines") print("="*80) base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size, nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an, lr=lr, device=device, backend=backend, momentum_mode="phantom") # 1. Sparse reference print("\n--- Sparse (EMA, reference) ---") sparse_r = run_seeds({**base, "policy": "ema"}, seeds) # 2. Dense at same steps print("\n--- Dense (same steps) ---") dense_same = run_seeds({**base, "policy": "dense"}, seeds) # 3. Dense at compute-matched steps # Sparse does ~70% of dense FLOPs (fwd dense + dX dense + dW at 10%) ratio = (1.0 + 1.0 + af) / 3.0 matched_steps = int(steps * ratio) print(f"\n--- Dense (compute-matched, {matched_steps} steps) ---") dense_matched = run_seeds({**base, "policy": "dense", "steps": matched_steps}, seeds) # 4. Natively smaller dense model: FFN multiplier = 4 * af = 0.4 (rounded) # This gives a model with ~10% of the FFN capacity small_ffn_mult = max(1, round(4 * af)) # 4*0.1 = 0.4, round to 1 print(f"\n--- Small dense (ffn_mult={small_ffn_mult}, capacity-matched) ---") dense_small = run_seeds({**base, "policy": "dense", "ffn_mult": small_ffn_mult}, seeds) results = { "sparse_ema": sparse_r, "dense_same_steps": dense_same, f"dense_matched_{matched_steps}steps": dense_matched, f"dense_small_ffn{small_ffn_mult}": dense_small, } print(f"\n{'Method':<35} | {'Steps':>6} | {'Params':>8} | {'Val Loss':>18} | {'ms/step':>10}") print("-"*90) for name, r in results.items(): np_ = r["results"][0]["n_params"] st = r["results"][0].get("steps", steps) if "steps" in name else steps # read actual steps from config — approximate print(f"{name:<35} | {st if 'matched' not in name else matched_steps:>6} | {np_/1e6:>7.1f}M | {r['mean_loss']:.4f} ± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}") return results # ═══════════════════════════════════════════════════════════════ # EXPERIMENT 3: PREDICTOR ACCURACY (EMA vs KNN vs Oracle) # ═══════════════════════════════════════════════════════════════ def exp_predictor_accuracy(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend): print("\n" + "="*80) print("EXPERIMENT 3: Predictor Accuracy (EMA vs KNN vs Oracle)") print("="*80) base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size, nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an, lr=lr, device=device, backend=backend, momentum_mode="phantom", measure_oracle=True, oracle_interval=25) results = {} for policy in ["ema", "knn", "random"]: print(f"\n--- {policy} ---") results[policy] = run_seeds({**base, "policy": policy}, seeds) # Aggregate overlaps for policy in ["ema", "knn", "random"]: print(f"\n{policy.upper()} predictor overlap:") print(f" {'Step':>6} | {'Jaccard':>10} | {'Recall':>10}") sd = defaultdict(lambda: {"j": [], "r": []}) for res in results[policy]["results"]: for s, j, r in res["overlaps"]: sd[s]["j"].append(j); sd[s]["r"].append(r) for s in sorted(sd): mj = sum(sd[s]["j"])/len(sd[s]["j"]) mr = sum(sd[s]["r"])/len(sd[s]["r"]) print(f" {s:>6} | {mj:>10.4f} | {mr:>10.4f}") print(f"\n{'Policy':<10} | {'Val Loss':>18} | {'ms/step':>10}") print("-"*45) for p in ["ema", "knn", "random"]: r = results[p] print(f"{p:<10} | {r['mean_loss']:.4f} ± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}") return results # ═══════════════════════════════════════════════════════════════ # MAIN # ═══════════════════════════════════════════════════════════════ ALL_EXPS = { "phantom_momentum": exp_phantom_momentum, "compute_matched": exp_compute_matched, "predictor_accuracy": exp_predictor_accuracy, } def main(): p = argparse.ArgumentParser() p.add_argument("--experiment", default="all", choices=list(ALL_EXPS)+["all"]) p.add_argument("--device", default="cuda") p.add_argument("--steps", type=int, default=1000) p.add_argument("--seeds", default="42,123,456") p.add_argument("--n_embd", type=int, default=1024) p.add_argument("--n_layer", type=int, default=4) p.add_argument("--n_head", type=int, default=8) p.add_argument("--batch_size", type=int, default=8) p.add_argument("--block_size", type=int, default=256) p.add_argument("--chunk_size", type=int, default=64) p.add_argument("--active_fraction", type=float, default=0.10) p.add_argument("--warmup_steps", type=int, default=50) p.add_argument("--anneal_steps", type=int, default=200) p.add_argument("--lr", type=float, default=3e-4) p.add_argument("--backend", default="triton", choices=["triton", "torch"]) p.add_argument("--output_dir", default="results") args = p.parse_args() seeds = [int(s) for s in args.seeds.split(",")] os.makedirs(args.output_dir, exist_ok=True) if args.device == "cuda" and torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name()} | VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB") print(f"Config: d={args.n_embd} nl={args.n_layer} nh={args.n_head} steps={args.steps} seeds={seeds}") print(f" cs={args.chunk_size} af={args.active_fraction} backend={args.backend}") shared = dict(device=args.device, steps=args.steps, seeds=seeds, d=args.n_embd, nl=args.n_layer, nh=args.n_head, bs=args.batch_size, block_size=args.block_size, cs=args.chunk_size, af=args.active_fraction, wu=args.warmup_steps, an=args.anneal_steps, lr=args.lr, backend=args.backend) exps = ALL_EXPS if args.experiment == "all" else {args.experiment: ALL_EXPS[args.experiment]} t0 = time.time() for name, fn in exps.items(): print(f"\n{'#'*80}\n# {name} ({(time.time()-t0)/60:.1f}m elapsed)\n{'#'*80}") sys.stdout.flush() result = fn(**shared) def ser(o): if isinstance(o, dict): return {str(k): ser(v) for k,v in o.items()} if isinstance(o, list): return [ser(x) for x in o] return o with open(os.path.join(args.output_dir, f"{name}.json"), "w") as f: json.dump(ser(result), f, indent=2, default=str) print(f"✓ {name} saved to {args.output_dir}/{name}.json") print(f"\n{'='*80}\nALL COMPLETE in {(time.time()-t0)/60:.1f} minutes\n{'='*80}") if __name__ == "__main__": main()