#!/usr/bin/env python3
"""
Sparse Transformer: Definitive Ablation Suite
Builds on v18_fast_knn_triton.py. Addresses all three structural gaps
identified in the critique:
1. PHANTOM MOMENTUM ABLATION
- "phantom": standard Adam β€” inactive chunks' moments decay on zero grad (default)
- "frozen": inactive chunks' Adam state (m, v) is completely frozen
Compare across all schedulers to isolate whether convergence is driven
by the chunking algorithm or by phantom momentum acting as regularization.
2. COMPUTE-MATCHED BASELINES
- Dense at same steps (standard comparison)
- Dense at fewer steps matching sparse FLOPs
- Natively smaller dense model matching sparse active capacity
3. UNIFIED HARDWARE
Everything on CUDA (A10G). Single hardware stack.
Plus: KNN vs EMA vs Random vs Oracle predictor comparison with proper
oracle overlap measurement.
Run:
python ablations.py --device cuda --steps 1000 --n_embd 1024 --experiment all
python ablations.py --device cuda --experiment phantom_momentum
python ablations.py --device cuda --experiment compute_matched
python ablations.py --device cuda --experiment predictor_accuracy
"""
from __future__ import annotations
import argparse
import json
import math
import os
import random
import sys
import time
from collections import defaultdict
from typing import Dict, List, Literal, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
try:
import tiktoken
HAS_TIKTOKEN = True
except ImportError:
HAS_TIKTOKEN = False
# ═══════════════════════════════════════════════════════════════
# TRITON KERNELS (from v18_triton, no autotune, block_ptr)
# ═══════════════════════════════════════════════════════════════
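# Per active output-chunk c (rows s = c*CS .. e = s+CS of W), the backward is
#   dW[s:e] = dY[:, s:e].T @ X          (dW/db kernel, tiled over M and d_in)
#   db[s:e] = dY[:, s:e].sum(0)
#   dX      = sum_c dY[:, s:e] @ W[s:e] (dX kernel, used only in sparse_dX mode)
# so inactive chunks contribute no dW/db work at all.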
if HAS_TRITON:
@triton.jit
def _sparse_bwd_dW_db_kernel(
X_ptr, dY_ptr, dW_ptr, dB_ptr, chunk_ids_ptr,
M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr,
num_active: tl.constexpr,
stride_xm: tl.constexpr, stride_xk: tl.constexpr,
stride_dym: tl.constexpr, stride_dyn: tl.constexpr,
stride_dwn: tl.constexpr, stride_dwk: tl.constexpr,
HAS_BIAS: tl.constexpr,
CS: tl.constexpr, BK: tl.constexpr, BM: tl.constexpr,
):
cli = tl.program_id(0)
kbi = tl.program_id(1)
cidx = tl.load(chunk_ids_ptr + cli)
cs0 = cidx * CS
ko = kbi * BK
dy_bp = tl.make_block_ptr(dY_ptr, (d_out, M), (stride_dyn, stride_dym),
(cs0, 0), (CS, BM), (1, 0))
x_bp = tl.make_block_ptr(X_ptr, (M, d_in), (stride_xm, stride_xk),
(0, ko), (BM, BK), (1, 0))
acc = tl.zeros((CS, BK), dtype=tl.float32)
do_bias = HAS_BIAS and (kbi == 0)
acc_b = tl.zeros((CS,), dtype=tl.float32)
for _ in range(0, M, BM):
dy_t = tl.load(dy_bp, boundary_check=(0, 1))
x = tl.load(x_bp, boundary_check=(0, 1))
acc = tl.dot(dy_t, x, acc=acc)
if do_bias:
acc_b += tl.sum(dy_t, axis=1)
dy_bp = tl.advance(dy_bp, (0, BM))
x_bp = tl.advance(x_bp, (BM, 0))
dw_bp = tl.make_block_ptr(dW_ptr, (d_out, d_in), (stride_dwn, stride_dwk),
(cs0, ko), (CS, BK), (1, 0))
tl.store(dw_bp, acc.to(dW_ptr.dtype.element_ty), boundary_check=(0, 1))
if do_bias:
rn = cs0 + tl.arange(0, CS)
tl.store(dB_ptr + rn, acc_b.to(dB_ptr.dtype.element_ty), mask=rn < d_out)
@triton.jit
def _sparse_bwd_dX_kernel(
dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr,
M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr,
num_active: tl.constexpr,
stride_dym: tl.constexpr, stride_dyn: tl.constexpr,
stride_wn: tl.constexpr, stride_wk: tl.constexpr,
stride_dxm: tl.constexpr, stride_dxk: tl.constexpr,
CS: tl.constexpr, BM: tl.constexpr, BK: tl.constexpr,
):
pm = tl.program_id(0)
pk = tl.program_id(1)
mo = pm * BM
ko = pk * BK
acc = tl.zeros((BM, BK), dtype=tl.float32)
for i in range(0, num_active):
cidx = tl.load(chunk_ids_ptr + i)
cs0 = cidx * CS
dy_bp = tl.make_block_ptr(dY_ptr, (M, d_out), (stride_dym, stride_dyn),
(mo, cs0), (BM, CS), (1, 0))
w_bp = tl.make_block_ptr(W_ptr, (d_out, d_in), (stride_wn, stride_wk),
(cs0, ko), (CS, BK), (1, 0))
dy = tl.load(dy_bp, boundary_check=(0, 1))
w = tl.load(w_bp, boundary_check=(0, 1))
acc = tl.dot(dy, w, acc=acc)
dx_bp = tl.make_block_ptr(dX_ptr, (M, d_in), (stride_dxm, stride_dxk),
(mo, ko), (BM, BK), (1, 0))
tl.store(dx_bp, acc.to(dX_ptr.dtype.element_ty), boundary_check=(0, 1))
def triton_bwd_dW_db(xf, gyf, active, cs, d_out, has_bias):
M, d_in = xf.shape
na = active.numel()
dW = torch.zeros(d_out, d_in, device=xf.device, dtype=xf.dtype)
dB = torch.zeros(d_out, device=xf.device, dtype=xf.dtype) if has_bias else None
if na == 0: return dW, dB
cids = active.to(torch.int32).contiguous()
BK, BM = 64, 64
_sparse_bwd_dW_db_kernel[(na, triton.cdiv(d_in, BK))](
xf, gyf, dW, dB if has_bias else dW, cids,
M, d_in, d_out, na,
xf.stride(0), xf.stride(1), gyf.stride(0), gyf.stride(1),
dW.stride(0), dW.stride(1),
HAS_BIAS=has_bias, CS=cs, BK=BK, BM=BM, num_warps=4)
return dW, dB
def triton_bwd_dX(gyf, w, active, cs, M, d_in):
na = active.numel()
d_out = gyf.shape[1]
dX = torch.zeros(M, d_in, device=gyf.device, dtype=gyf.dtype)
if na == 0: return dX
cids = active.to(torch.int32).contiguous()
BM, BK = 64, 64
_sparse_bwd_dX_kernel[(triton.cdiv(M, BM), triton.cdiv(d_in, BK))](
gyf, w, dX, cids,
M, d_in, d_out, na,
gyf.stride(0), gyf.stride(1), w.stride(0), w.stride(1),
dX.stride(0), dX.stride(1),
CS=cs, BM=BM, BK=BK, num_warps=4)
return dX
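# Optional sanity check (a sketch added here; `_selftest_triton_dw` is not part
# of the original suite and is never called by it): verifies the Triton dW/db
# wrapper against plain per-chunk PyTorch matmuls on random data.
def _selftest_triton_dw(M=128, d_in=64, d_out=256, cs=64):
    if not (HAS_TRITON and torch.cuda.is_available()):
        return  # requires the Triton path and a GPU
    x = torch.randn(M, d_in, device="cuda")
    gy = torch.randn(M, d_out, device="cuda")
    active = torch.tensor([0, 2], device="cuda")  # 2 of the d_out//cs = 4 chunks
    dW, dB = triton_bwd_dW_db(x, gy, active, cs, d_out, has_bias=True)
    for c in active.tolist():
        s, e = c * cs, (c + 1) * cs
        assert torch.allclose(dW[s:e], gy[:, s:e].t() @ x, atol=1e-2)
        assert torch.allclose(dB[s:e], gy[:, s:e].sum(0), atol=1e-2)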
# ═══════════════════════════════════════════════════════════════
# AUTOGRAD
# ═══════════════════════════════════════════════════════════════
class TritonSparseLinearFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x, w, b, active, cs, sparse_dx):
ctx.save_for_backward(x, w, active)
ctx.has_bias = b is not None
ctx.sparse_dx = sparse_dx
ctx.cs = cs
return F.linear(x, w, b)
@staticmethod
def backward(ctx, gy):
x, w, active = ctx.saved_tensors
cs = ctx.cs
do, di = w.shape
xf = x.reshape(-1, di).contiguous()
gf = gy.reshape(-1, do).contiguous()
M = xf.shape[0]
gw, gb = triton_bwd_dW_db(xf, gf, active, cs, do, ctx.has_bias)
gx = triton_bwd_dX(gf, w.contiguous(), active, cs, M, di) if ctx.sparse_dx else gf @ w
return gx.reshape(x.shape), gw, gb, None, None, None
class PyLoopSparseLinearFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x, w, b, active, cs, sparse_dx):
ctx.save_for_backward(x, w, active)
ctx.has_bias = b is not None
ctx.sparse_dx = sparse_dx
ctx.cs = cs
return F.linear(x, w, b)
@staticmethod
def backward(ctx, gy):
x, w, active = ctx.saved_tensors
cs = ctx.cs
xf = x.reshape(-1, x.shape[-1])
gf = gy.reshape(-1, gy.shape[-1])
gw = torch.zeros_like(w)
gb = torch.zeros(w.shape[0], device=w.device, dtype=w.dtype) if ctx.has_bias else None
gx = torch.zeros_like(xf) if ctx.sparse_dx else gf @ w
for c in active.tolist():
s, e = c * cs, (c+1) * cs
sl = gf[:, s:e]
gw[s:e] = sl.t() @ xf
if gb is not None: gb[s:e] = sl.sum(0)
if ctx.sparse_dx: gx += sl @ w[s:e]
return gx.reshape(x.shape), gw, gb, None, None, None
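# Optional CPU sanity check (our sketch; `_selftest_pyloop` is not called by the
# suite): with every chunk active and sparse_dx=True, the loop backend must
# reproduce dense F.linear gradients.
def _selftest_pyloop(M=8, d_in=32, d_out=64, cs=16):
    x = torch.randn(M, d_in, requires_grad=True)
    w = torch.randn(d_out, d_in, requires_grad=True)
    b = torch.randn(d_out, requires_grad=True)
    active = torch.arange(d_out // cs)  # all 4 chunks active
    PyLoopSparseLinearFn.apply(x, w, b, active, cs, True).sum().backward()
    xd, wd, bd = (t.detach().clone().requires_grad_(True) for t in (x, w, b))
    F.linear(xd, wd, bd).sum().backward()
    for sparse_p, dense_p in ((x, xd), (w, wd), (b, bd)):
        assert torch.allclose(sparse_p.grad, dense_p.grad, atol=1e-4)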
# ═══════════════════════════════════════════════════════════════
# MODEL
# ═══════════════════════════════════════════════════════════════
class SparseLinear(nn.Linear):
def __init__(self, inf, outf, bias=True):
super().__init__(inf, outf, bias=bias)
self.sparse_enabled = False
self.sparse_dx = False
self.active_chunks = None
self.chunk_size = 64
self.backend = "triton" # "triton" or "torch"
def forward(self, x):
if not self.sparse_enabled or self.active_chunks is None:
return F.linear(x, self.weight, self.bias)
fn = TritonSparseLinearFn if (self.backend == "triton" and HAS_TRITON) else PyLoopSparseLinearFn
return fn.apply(x, self.weight, self.bias, self.active_chunks, self.chunk_size, self.sparse_dx)
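# Usage sketch (illustrative values, not executed here): the scheduler drives
# these flags from outside, e.g.
#   lin = SparseLinear(1024, 4096); lin.chunk_size = 64
#   lin.active_chunks = torch.tensor([0, 5, 9])  # rows 0:64, 320:384, 576:640
#   lin.sparse_enabled = True
# The forward pass stays fully dense; sparsity only restricts which output
# chunks receive weight/bias gradients in the backward pass.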
class Attn(nn.Module):
def __init__(self, d, nh, bs, do):
super().__init__()
self.nh, self.hd = nh, d // nh
self.c_attn = SparseLinear(d, 3*d)
self.c_proj = SparseLinear(d, d)
self.drop = nn.Dropout(do)
self.register_buffer("mask", torch.tril(torch.ones(bs,bs)).view(1,1,bs,bs))
def forward(self, x):
B,T,C = x.shape
q,k,v = self.c_attn(x).split(C, 2)
q = q.view(B,T,self.nh,self.hd).transpose(1,2)
k = k.view(B,T,self.nh,self.hd).transpose(1,2)
v = v.view(B,T,self.nh,self.hd).transpose(1,2)
a = (q @ k.transpose(-2,-1)) / math.sqrt(self.hd)
a = a.masked_fill(self.mask[:,:,:T,:T]==0, float("-inf"))
a = self.drop(F.softmax(a, dim=-1))
return self.c_proj((a @ v).transpose(1,2).contiguous().view(B,T,C))
class FFN(nn.Module):
def __init__(self, d, do, ffn_mult=4):
super().__init__()
self.c_fc = SparseLinear(d, ffn_mult * d)
self.c_proj = SparseLinear(ffn_mult * d, d)
self.drop = nn.Dropout(do)
def forward(self, x):
return self.drop(self.c_proj(F.gelu(self.c_fc(x))))
class Block(nn.Module):
def __init__(self, d, nh, bs, do, ffn_mult=4):
super().__init__()
self.ln1 = nn.LayerNorm(d); self.attn = Attn(d, nh, bs, do)
self.ln2 = nn.LayerNorm(d); self.mlp = FFN(d, do, ffn_mult)
def forward(self, x):
x = x + self.attn(self.ln1(x))
return x + self.mlp(self.ln2(x))
class GPT(nn.Module):
def __init__(self, V, bs, nl, nh, d, do, ffn_mult=4):
super().__init__()
self.te = nn.Embedding(V, d); self.pe = nn.Embedding(bs, d)
self.blocks = nn.Sequential(*[Block(d, nh, bs, do, ffn_mult) for _ in range(nl)])
self.ln = nn.LayerNorm(d); self.head = nn.Linear(d, V)
def forward(self, idx, tgt=None):
B,T = idx.shape
x = self.te(idx) + self.pe(torch.arange(T, device=idx.device))[None]
lo = self.head(self.ln(self.blocks(x)))
loss = F.cross_entropy(lo.view(-1, lo.size(-1)), tgt.view(-1)) if tgt is not None else None
return lo, loss
def nparams(self): return sum(p.numel() for p in self.parameters())
def get_sparse_linears(m): return [x for x in m.modules() if isinstance(x, SparseLinear)]
# ═══════════════════════════════════════════════════════════════
# DATA
# ═══════════════════════════════════════════════════════════════
class Corpus:
"""Uses tiktoken GPT-2 BPE on Tiny Shakespeare if available, else char-level synthetic."""
_inst = None
@classmethod
def get(cls, bs, dev):
if cls._inst is None or cls._inst.block_size != bs:
cls._inst = cls(bs, dev)
return cls._inst
def __init__(self, block_size, device):
self.block_size, self.device = block_size, device
import urllib.request
p = "input.txt"
if not os.path.exists(p):
urllib.request.urlretrieve("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", p)
text = open(p).read()
if HAS_TIKTOKEN:
enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode(text)
self.vocab_size = enc.n_vocab
else:
chars = sorted(set(text))
stoi = {c:i for i,c in enumerate(chars)}
tokens = [stoi[c] for c in text]
self.vocab_size = len(chars)
data = torch.tensor(tokens, dtype=torch.long)
si = int(0.9 * len(data))
self.train_data, self.val_data = data[:si], data[si:]
print(f"Corpus: V={self.vocab_size}, train={len(self.train_data):,}, val={len(self.val_data):,}")
def get_batch(self, split, bs, gen=None):
d = self.train_data if split == "train" else self.val_data
ix = torch.randint(len(d)-self.block_size-1, (bs,), generator=gen)
x = torch.stack([d[i:i+self.block_size] for i in ix])
y = torch.stack([d[i+1:i+self.block_size+1] for i in ix])
return x.to(self.device), y.to(self.device)
def make_gen(s):
g = torch.Generator(device="cpu"); g.manual_seed(s); return g
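# Note: run() draws each training batch with a fresh make_gen(step), so every
# config sees the identical batch sequence; runs differ through model init,
# dropout, and the optimizer, never through data order.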
# ═══════════════════════════════════════════════════════════════
# SCHEDULER (from v18, with KNN)
# ═══════════════════════════════════════════════════════════════
class ChunkScheduler:
def __init__(self, model, policy, frac, cs, dev, beta=0.95, knn_k=3,
sim_hist=128, min_sim_hist=8):
self.policy, self.frac, self.cs, self.dev = policy, frac, cs, dev
self.beta, self.knn_k = beta, knn_k
self.sim_hist, self.min_sim_hist = sim_hist, min_sim_hist
self.linears = get_sparse_linears(model)
self.m2ids, self.m2loc = {}, {}
off = 0
for m in self.linears:
m.chunk_size = cs
nc = m.out_features // cs
assert m.out_features % cs == 0
self.m2ids[m] = torch.arange(off, off+nc, device=dev)
self.m2loc[m] = torch.arange(nc, device=dev)
off += nc
self.nc = off
self.ema = torch.zeros(self.nc, device=dev)
self.active = torch.zeros(self.nc, dtype=torch.bool, device=dev)
self.mass_history = []
self.similarity = None
self.scores = torch.zeros(self.nc, device=dev)
def get_frac(self, step, wu, an):
if step < wu: return 1.0
if an > 0 and step < wu + an:
p = (step - wu) / an
return self.frac + (1-self.frac) * 0.5 * (1 + math.cos(math.pi * p))
return self.frac
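    # Worked example for wu=50, an=200, frac=0.10: steps 0-49 run dense (f=1.0);
    # at step 150, p=0.5 and f = 0.1 + 0.9*0.5*(1+cos(pi/2)) = 0.55; from
    # step 250 onward f stays at 0.10.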
def choose(self, step, wu, an):
f = self.get_frac(step, wu, an)
if f >= 0.999:
self.active.fill_(True)
self._install(); return
k = max(1, int(f * self.nc))
self.active.fill_(False)
if self.policy == "random":
idx = torch.randperm(self.nc, device=self.dev)[:k]
elif self.policy == "ema":
idx = torch.topk(self.ema + 1e-9*torch.rand_like(self.ema), k=k).indices
elif self.policy == "knn":
base = self.scores if self.scores.sum() > 1e-12 else self.ema
idx = torch.topk(base + 1e-9*torch.rand_like(base), k=k).indices
else:
raise ValueError(self.policy)
self.active[idx] = True
self._install()
def _install(self):
for m, gids in self.m2ids.items():
m.active_chunks = self.m2loc[m][self.active[gids]]
@torch.no_grad()
def update(self, step, wu):
cur = torch.zeros_like(self.ema)
for m, ids in self.m2ids.items():
if m.weight.grad is None: continue
s = m.weight.grad.square().view(len(ids), self.cs, -1).sum((1,2))
if m.bias is not None and m.bias.grad is not None:
s += m.bias.grad.square().view(len(ids), self.cs).sum(1)
cur[ids] = torch.sqrt(s + 1e-30)
obs = self.active
new = obs & (self.ema == 0)
old = obs & ~new
self.ema[new] = cur[new]
self.ema[old] = self.beta*self.ema[old] + (1-self.beta)*cur[old]
# KNN similarity building during warmup
if step < wu:
self.mass_history.append(cur.clone())
if len(self.mass_history) > self.sim_hist:
self.mass_history = self.mass_history[-self.sim_hist:]
if len(self.mass_history) >= self.min_sim_hist:
self.similarity = self._build_sim()
if self.policy == "knn":
self.scores = self._knn_scores(self.active, cur)
else:
self.scores = self.ema.clone()
return cur
def _build_sim(self):
H = torch.stack(self.mass_history)
H = (H - H.mean(0, keepdim=True)) / (H.std(0, keepdim=True) + 1e-6)
S = torch.clamp((H.T @ H) / max(1, H.shape[0]-1), min=0)
S.fill_diagonal_(0)
ok = torch.zeros_like(S, dtype=torch.bool)
for _, ids in self.m2ids.items():
ok[ids[:,None], ids[None,:]] = True
return torch.where(ok, S, torch.zeros_like(S))
def _knn_scores(self, active_mask, cur):
if self.similarity is None: return self.ema.clone()
sc = self.ema.clone()
sc[active_mask] = cur[active_mask]
aidx = active_mask.nonzero(as_tuple=False).flatten()
iidx = (~active_mask).nonzero(as_tuple=False).flatten()
if aidx.numel() == 0: return sc
S = self.similarity
for i in iidx.tolist():
w = S[i, aidx]
if w.sum() <= 1e-12: continue
kk = min(self.knn_k, w.numel())
top = torch.topk(w, k=kk)
sc[i] = (top.values * cur[aidx[top.indices]]).sum() / (top.values.sum() + 1e-12)
return sc
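    # KNN scoring example: if inactive chunk i has warmup-correlation weights
    # (0.8, 0.2) to its two nearest active chunks, whose current gradient masses
    # are 3.0 and 1.0, then sc[i] = (0.8*3.0 + 0.2*1.0) / (0.8 + 0.2) = 2.6.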
@torch.no_grad()
def oracle_scores(self):
"""Compute dense gradient magnitudes per chunk (requires dense grads already computed)."""
sc = torch.zeros(self.nc, device=self.dev)
for m, ids in self.m2ids.items():
if m.weight.grad is None: continue
s = m.weight.grad.square().view(len(ids), self.cs, -1).sum((1,2))
if m.bias is not None and m.bias.grad is not None:
s += m.bias.grad.square().view(len(ids), self.cs).sum(1)
sc[ids] = torch.sqrt(s + 1e-30)
return sc
def measure_overlap(self, k):
"""Jaccard and recall of current active vs oracle top-k."""
oracle = set(torch.topk(self.oracle_scores(), k=k).indices.tolist())
pred = set(self.active.nonzero(as_tuple=True)[0].tolist())
if not oracle or not pred: return 0., 0.
inter = oracle & pred
return len(inter)/len(oracle|pred), len(inter)/len(oracle)
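# Overlap metrics example: oracle top-k = {1,2,3,4}, predicted = {2,3,4,5} ->
# intersection 3, union 5, so Jaccard = 3/5 = 0.60 and recall = 3/4 = 0.75.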
# ═══════════════════════════════════════════════════════════════
# CHUNKED ADAM WITH PHANTOM/FROZEN MODES
# ═══════════════════════════════════════════════════════════════
class ChunkedAdam:
"""
Adam with two modes for inactive chunks:
      phantom: standard Adam; m,v decay even on zero grad (default, original behavior)
frozen: m,v state completely frozen for inactive chunks
"""
def __init__(self, model, lr=3e-4, cs=64, momentum_mode="phantom"):
self.model, self.lr, self.cs = model, lr, cs
self.momentum_mode = momentum_mode # "phantom" or "frozen"
self.state = {}
self.p2m = {}
for m in get_sparse_linears(model):
if m.weight is not None: self.p2m[m.weight] = m
if m.bias is not None: self.p2m[m.bias] = m
def zero_grad(self):
for p in self.model.parameters(): p.grad = None
@torch.no_grad()
def step(self):
for p in self.model.parameters():
if p.grad is None: continue
if p not in self.state:
self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
m, v = self.state[p]["m"], self.state[p]["v"]
sm = self.p2m.get(p)
ac = getattr(sm, 'active_chunks', None) if sm else None
if ac is None:
                # Dense parameter (LN, embeddings, lm_head); always gets the full update
m.mul_(0.9).add_(p.grad, alpha=0.1)
v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
p.sub_(m / (torch.sqrt(v) + 1e-8), alpha=self.lr)
else:
if self.momentum_mode == "phantom":
# PHANTOM: update ALL chunks' moments, but only active get real gradients.
# Inactive chunks see grad=0, so m decays and v decays.
# This is the original behavior.
m.mul_(0.9).add_(p.grad, alpha=0.1)
v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
# But only update weights for active chunks
for c in ac.tolist():
s, e = c*self.cs, (c+1)*self.cs
p.data[s:e].sub_(m[s:e] / (torch.sqrt(v[s:e]) + 1e-8), alpha=self.lr)
elif self.momentum_mode == "frozen":
# FROZEN: only touch m,v,p for active chunks. Inactive state is untouched.
for c in ac.tolist():
s, e = c*self.cs, (c+1)*self.cs
g = p.grad[s:e]
m[s:e].mul_(0.9).add_(g, alpha=0.1)
v[s:e].mul_(0.999).addcmul_(g, g, value=0.001)
p.data[s:e].sub_(m[s:e] / (torch.sqrt(v[s:e]) + 1e-8), alpha=self.lr)
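# Scale of the phantom effect: a chunk inactive for 100 steps has its first
# moment shrunk by 0.9^100 ~ 2.7e-5 (effectively reset) while its second moment
# retains 0.999^100 ~ 90% of its value; "frozen" keeps both unchanged, so the
# two modes differ most right after a chunk is reactivated.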
# ═══════════════════════════════════════════════════════════════
# EVALUATION
# ═══════════════════════════════════════════════════════════════
@torch.no_grad()
def evaluate(model, corpus, bs, n=20, seed=9999):
model.eval()
losses = []
for i in range(n):
_, l = model(*corpus.get_batch("val", bs, make_gen(seed+i)))
losses.append(l.item())
model.train()
avg = sum(losses)/len(losses)
return avg, math.exp(min(avg, 20))
# ═══════════════════════════════════════════════════════════════
# SINGLE TRAINING RUN
# ═══════════════════════════════════════════════════════════════
def run(policy, bwd_mode, steps, bs, block_size, nl, nh, d, cs,
active_frac, wu, an, lr, device, seed, backend="triton",
momentum_mode="phantom", ffn_mult=4,
measure_oracle=False, oracle_interval=50):
"""Run one training config. Returns dict of results."""
torch.manual_seed(seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
random.seed(seed)
corpus = Corpus.get(block_size, device)
model = GPT(corpus.vocab_size, block_size, nl, nh, d, 0.1, ffn_mult).to(device)
for m in get_sparse_linears(model):
m.chunk_size = cs
m.backend = backend
is_dense = (policy == "dense")
sched = None if is_dense else ChunkScheduler(model, policy, active_frac, cs, device)
opt = ChunkedAdam(model, lr=lr, cs=cs, momentum_mode=momentum_mode)
np_ = model.nparams()
overlaps = []
    if device == "cuda": torch.cuda.synchronize()
t0 = time.perf_counter()
for step in range(steps):
x, y = corpus.get_batch("train", bs, make_gen(step))
if is_dense:
for m in get_sparse_linears(model):
m.sparse_enabled = False; m.active_chunks = None
else:
sched.choose(step, wu, an)
for m in get_sparse_linears(model):
m.sparse_enabled = True
m.sparse_dx = (bwd_mode == "sparse_dX")
opt.zero_grad()
_, loss = model(x, y)
loss.backward()
if sched:
sched.update(step, wu)
# Oracle overlap measurement
if measure_oracle and step % oracle_interval == 0 and step >= wu + an:
saved = {p: p.grad.clone() for p in model.parameters() if p.grad is not None}
for m in get_sparse_linears(model): m.sparse_enabled = False
for p in model.parameters(): p.grad = None
_, lo = model(x, y); lo.backward()
k = max(1, int(active_frac * sched.nc))
j, r = sched.measure_overlap(k)
overlaps.append((step, j, r))
for p in model.parameters():
if p in saved: p.grad = saved[p]
for m in get_sparse_linears(model): m.sparse_enabled = True
opt.step()
if step % 200 == 0:
print(f" step {step}/{steps} loss={loss.item():.4f}")
    if device == "cuda": torch.cuda.synchronize()
wall = time.perf_counter() - t0
for m in get_sparse_linears(model): m.sparse_enabled = False
vl, vp = evaluate(model, corpus, bs, n=30)
    del model
    if device == "cuda": torch.cuda.empty_cache()
return {
"val_loss": vl, "val_ppl": vp, "wall_time": wall,
"ms_per_step": 1000*wall/steps, "n_params": np_,
"train_loss_final": loss.item(), "overlaps": overlaps,
}
def run_seeds(cfg, seeds):
results = []
for s in seeds:
cfg["seed"] = s
results.append(run(**cfg))
vls = [r["val_loss"] for r in results]
ml = sum(vls)/len(vls)
sl = (sum((x-ml)**2 for x in vls)/max(1,len(vls)-1))**0.5
return {"mean_loss": ml, "std_loss": sl, "results": results,
"mean_ms": sum(r["ms_per_step"] for r in results)/len(results)}
# ═══════════════════════════════════════════════════════════════
# EXPERIMENT 1: PHANTOM MOMENTUM ABLATION
# ═══════════════════════════════════════════════════════════════
def exp_phantom_momentum(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend):
print("\n" + "="*80)
print("EXPERIMENT 1: Phantom Momentum Ablation")
print("="*80)
base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size,
nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an,
lr=lr, device=device, backend=backend)
configs = [
("dense", "dense", "phantom"),
("ema+phantom", "ema", "phantom"),
("ema+frozen", "ema", "frozen"),
("knn+phantom", "knn", "phantom"),
("knn+frozen", "knn", "frozen"),
("random+phantom", "random", "phantom"),
("random+frozen", "random", "frozen"),
]
results = {}
for name, policy, mm in configs:
print(f"\n--- {name} ---")
cfg = {**base, "policy": policy, "momentum_mode": mm}
results[name] = run_seeds(cfg, seeds)
print(f"\n{'Method':<22} | {'Val Loss':>18} | {'ms/step':>10}")
print("-"*55)
for name, _, _ in configs:
r = results[name]
print(f"{name:<22} | {r['mean_loss']:.4f} Β± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}")
return results
# ═══════════════════════════════════════════════════════════════
# EXPERIMENT 2: COMPUTE-MATCHED BASELINES
# ═══════════════════════════════════════════════════════════════
def exp_compute_matched(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend):
print("\n" + "="*80)
print("EXPERIMENT 2: Compute-Matched Baselines")
print("="*80)
base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size,
nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an,
lr=lr, device=device, backend=backend, momentum_mode="phantom")
# 1. Sparse reference
print("\n--- Sparse (EMA, reference) ---")
sparse_r = run_seeds({**base, "policy": "ema"}, seeds)
# 2. Dense at same steps
print("\n--- Dense (same steps) ---")
dense_same = run_seeds({**base, "policy": "dense"}, seeds)
# 3. Dense at compute-matched steps
    # Per linear layer in full_dX mode, forward and dX stay dense and only dW
    # runs at the active fraction, so sparse costs ~(1+1+af)/3 of the dense
    # GEMM FLOPs (~70% at af=0.10; attention and embeddings are not counted).
ratio = (1.0 + 1.0 + af) / 3.0
matched_steps = int(steps * ratio)
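    # e.g. af=0.10: ratio = 2.1/3 = 0.70, so steps=1000 gives matched_steps=700.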
print(f"\n--- Dense (compute-matched, {matched_steps} steps) ---")
dense_matched = run_seeds({**base, "policy": "dense", "steps": matched_steps}, seeds)
    # 4. Natively smaller dense model: ffn_mult = max(1, round(4*af)) clamps to 1
    #    at af=0.10, i.e. a 4x smaller FFN (25% of dense FFN capacity; the closest
    #    integer multiplier to the 10% active-capacity target).
    small_ffn_mult = max(1, round(4 * af))
print(f"\n--- Small dense (ffn_mult={small_ffn_mult}, capacity-matched) ---")
dense_small = run_seeds({**base, "policy": "dense", "ffn_mult": small_ffn_mult}, seeds)
results = {
"sparse_ema": sparse_r,
"dense_same_steps": dense_same,
f"dense_matched_{matched_steps}steps": dense_matched,
f"dense_small_ffn{small_ffn_mult}": dense_small,
}
print(f"\n{'Method':<35} | {'Steps':>6} | {'Params':>8} | {'Val Loss':>18} | {'ms/step':>10}")
print("-"*90)
    for name, r in results.items():
        np_ = r["results"][0]["n_params"]
        st = matched_steps if "matched" in name else steps
        print(f"{name:<35} | {st:>6} | {np_/1e6:>7.1f}M | {r['mean_loss']:.4f} ± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}")
return results
# ═══════════════════════════════════════════════════════════════
# EXPERIMENT 3: PREDICTOR ACCURACY (EMA vs KNN vs Oracle)
# ═══════════════════════════════════════════════════════════════
def exp_predictor_accuracy(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend):
print("\n" + "="*80)
print("EXPERIMENT 3: Predictor Accuracy (EMA vs KNN vs Oracle)")
print("="*80)
base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size,
nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an,
lr=lr, device=device, backend=backend, momentum_mode="phantom",
measure_oracle=True, oracle_interval=25)
results = {}
for policy in ["ema", "knn", "random"]:
print(f"\n--- {policy} ---")
results[policy] = run_seeds({**base, "policy": policy}, seeds)
# Aggregate overlaps
for policy in ["ema", "knn", "random"]:
print(f"\n{policy.upper()} predictor overlap:")
print(f" {'Step':>6} | {'Jaccard':>10} | {'Recall':>10}")
sd = defaultdict(lambda: {"j": [], "r": []})
for res in results[policy]["results"]:
for s, j, r in res["overlaps"]:
sd[s]["j"].append(j); sd[s]["r"].append(r)
for s in sorted(sd):
mj = sum(sd[s]["j"])/len(sd[s]["j"])
mr = sum(sd[s]["r"])/len(sd[s]["r"])
print(f" {s:>6} | {mj:>10.4f} | {mr:>10.4f}")
print(f"\n{'Policy':<10} | {'Val Loss':>18} | {'ms/step':>10}")
print("-"*45)
for p in ["ema", "knn", "random"]:
r = results[p]
print(f"{p:<10} | {r['mean_loss']:.4f} Β± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}")
return results
# ═══════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════
ALL_EXPS = {
"phantom_momentum": exp_phantom_momentum,
"compute_matched": exp_compute_matched,
"predictor_accuracy": exp_predictor_accuracy,
}
def main():
p = argparse.ArgumentParser()
p.add_argument("--experiment", default="all", choices=list(ALL_EXPS)+["all"])
p.add_argument("--device", default="cuda")
p.add_argument("--steps", type=int, default=1000)
p.add_argument("--seeds", default="42,123,456")
p.add_argument("--n_embd", type=int, default=1024)
p.add_argument("--n_layer", type=int, default=4)
p.add_argument("--n_head", type=int, default=8)
p.add_argument("--batch_size", type=int, default=8)
p.add_argument("--block_size", type=int, default=256)
p.add_argument("--chunk_size", type=int, default=64)
p.add_argument("--active_fraction", type=float, default=0.10)
p.add_argument("--warmup_steps", type=int, default=50)
p.add_argument("--anneal_steps", type=int, default=200)
p.add_argument("--lr", type=float, default=3e-4)
p.add_argument("--backend", default="triton", choices=["triton", "torch"])
p.add_argument("--output_dir", default="results")
args = p.parse_args()
seeds = [int(s) for s in args.seeds.split(",")]
os.makedirs(args.output_dir, exist_ok=True)
if args.device == "cuda" and torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name()} | VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
print(f"Config: d={args.n_embd} nl={args.n_layer} nh={args.n_head} steps={args.steps} seeds={seeds}")
print(f" cs={args.chunk_size} af={args.active_fraction} backend={args.backend}")
shared = dict(device=args.device, steps=args.steps, seeds=seeds,
d=args.n_embd, nl=args.n_layer, nh=args.n_head,
bs=args.batch_size, block_size=args.block_size,
cs=args.chunk_size, af=args.active_fraction,
wu=args.warmup_steps, an=args.anneal_steps,
lr=args.lr, backend=args.backend)
exps = ALL_EXPS if args.experiment == "all" else {args.experiment: ALL_EXPS[args.experiment]}
t0 = time.time()
for name, fn in exps.items():
print(f"\n{'#'*80}\n# {name} ({(time.time()-t0)/60:.1f}m elapsed)\n{'#'*80}")
sys.stdout.flush()
result = fn(**shared)
def ser(o):
if isinstance(o, dict): return {str(k): ser(v) for k,v in o.items()}
if isinstance(o, list): return [ser(x) for x in o]
return o
with open(os.path.join(args.output_dir, f"{name}.json"), "w") as f:
json.dump(ser(result), f, indent=2, default=str)
print(f"βœ“ {name} saved to {args.output_dir}/{name}.json")
print(f"\n{'='*80}\nALL COMPLETE in {(time.time()-t0)/60:.1f} minutes\n{'='*80}")
if __name__ == "__main__":
main()