# Commit bc1b8eb by theapemachine:
# Add sparse transformer v19 with Triton-backed KNN scheduler and various backward
# modes. Includes utilities for synthetic data generation and model training.
# Implements chunked sparse updates and integrates with existing sparse linear layers.
"""
Sparse Transformer: Real-World Benchmark on Tiny Shakespeare using GPT-2 BPE.
This script scales the architecture to a 6-layer, 512-dim GPT and trains on
real natural language. It applies our Hardware-Sympathetic Chunked Sparse
backward pass, Cosine Annealing, and Chunked Adam optimizer.
Run:
python3 sparse_transformer_shakespeare.py --device mps --benchmark_sync
"""
import argparse
import math
import os
import random
import time
import urllib.request
from typing import Dict, List, Literal, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    import tiktoken
except ImportError as err:
    # Chain the original error so the real cause (e.g. a broken install) stays visible.
    raise ImportError("Please install tiktoken: pip install tiktoken") from err
# Limit PyTorch's CPU intra-op parallelism to a single thread
# (presumably to keep CPU-side timings stable across runs).
torch.set_num_threads(1)
def sync_device(device: str) -> None:
    """Wait for all queued work on ``device`` to finish.

    Only ``"cuda"`` (when CUDA is available) and ``"mps"`` (when ``torch.mps``
    exists) trigger a synchronize; any other device string is a no-op.
    """
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return
    if device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
# Names of the chunk-selection policies.
Policy = Literal[
    "predicted_magnitude",
    "oracle_current",
    "random",
]
# Names of the backward-pass variants compared by the benchmark.
BackwardMode = Literal[
    "dense_baseline",
    "sparse_dW_full_dX",
    "sparse_dW_sparse_dX",
]
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and then PyTorch's global RNG with ``seed``."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a fresh CPU ``torch.Generator`` seeded with ``seed``."""
    # Generator.manual_seed returns the generator itself, so this chains.
    return torch.Generator(device="cpu").manual_seed(seed)
# -----------------------------
# Real-World Data Pipeline
# -----------------------------
class ShakespeareCorpus:
    """Tiny Shakespeare tokenised with GPT-2 BPE and split 90/10 into train/val.

    On construction, downloads ``input.txt`` into the current working directory
    if it is missing, tokenises it with tiktoken's ``"gpt2"`` encoding, and keeps
    the first 90% of tokens for training and the remaining 10% for validation.

    Args:
        block_size: context length of each sampled sequence.
        device: device that batches returned by :meth:`get_batch` are moved to.
    """

    def __init__(self, block_size: int, device: str):
        self.block_size = block_size
        self.device = device
        # 1. Download Tiny Shakespeare if not exists
        data_path = "input.txt"
        if not os.path.exists(data_path):
            print("Downloading Tiny Shakespeare...")
            url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
            # Download under a temporary name and rename only on success, so an
            # interrupted download never leaves a truncated input.txt behind that
            # later runs would silently reuse.
            tmp_path = data_path + ".part"
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, data_path)
        # 2. Tokenize using GPT-2 BPE
        print("Tokenizing data...")
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.read()
        enc = tiktoken.get_encoding("gpt2")
        tokens = enc.encode(text)
        self.vocab_size = enc.n_vocab
        # 3. Split 90/10 Train/Val
        data = torch.tensor(tokens, dtype=torch.long)
        split_idx = int(0.9 * len(data))
        self.train_data = data[:split_idx]
        self.val_data = data[split_idx:]
        print(f"Dataset loaded. Vocab size: {self.vocab_size:,}. Train tokens: {len(self.train_data):,}")

    def get_batch(self, split: str, batch_size: int, generator: Optional[torch.Generator] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` random windows from ``split``.

        Args:
            split: ``"train"`` selects the training tokens; any other value
                selects the validation tokens.
            batch_size: number of sequences to sample.
            generator: optional CPU generator for reproducible sampling.

        Returns:
            ``(x, y)``, each of shape ``(batch_size, block_size)`` on
            ``self.device``, where ``y`` is ``x`` shifted one token ahead.

        Raises:
            ValueError: if the split holds fewer than ``block_size + 1`` tokens.
        """
        data = self.train_data if split == "train" else self.val_data
        # A start offset i is valid while y = data[i+1 : i+block_size+1] stays in
        # range, i.e. i <= len(data) - block_size - 1, giving len - block_size starts.
        n_starts = len(data) - self.block_size
        if n_starts < 1:
            raise ValueError(
                f"split {split!r} has {len(data)} tokens; need at least block_size + 1 = {self.block_size + 1}"
            )
        ix = torch.randint(n_starts, (batch_size,), generator=generator)
        # Gather all windows with one advanced-indexing op: row b holds ix[b] + [0, block_size).
        window = ix[:, None] + torch.arange(self.block_size)
        x = data[window]
        y = data[window + 1]
        return x.to(self.device), y.to(self.device)
# -----------------------------
# Chunked Sparse Autograd
# -----------------------------
class ChunkedMaskedLinear(torch.autograd.Function):
    """Dense linear forward with a chunk-sparse backward over output rows.

    Forward is exactly ``F.linear(x, weight, bias)``. Backward views the output
    features as consecutive chunks of ``chunk_size`` rows and only computes
    gradients for the chunks listed in ``active_chunks``:

    * ``grad_weight`` and ``grad_bias`` are zero outside the active rows.
    * ``grad_x`` is the full dense ``grad_y @ weight`` when ``sparse_dx`` is
      False, or only the contribution of the active rows when it is True (an
      approximation of the true input gradient).

    Each gradient is only computed when autograd needs it
    (``ctx.needs_input_grad``); e.g. an input that does not require grad skips
    the dense ``grad_y @ weight`` matmul entirely.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_chunks: torch.Tensor,
        chunk_size: int,
        sparse_dx: bool,
    ) -> torch.Tensor:
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = sparse_dx
        ctx.chunk_size = chunk_size
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        x, weight, active_chunks = ctx.saved_tensors
        chunk_size = ctx.chunk_size
        need_dx, need_dw, need_db = ctx.needs_input_grad[:3]
        need_db = need_db and ctx.has_bias
        accumulate_dx = need_dx and ctx.sparse_dx

        # Flatten all leading dims so every product below is a plain 2-D matmul.
        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])

        grad_w = torch.zeros_like(weight) if need_dw else None
        grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if need_db else None
        if not need_dx:
            grad_x_flat = None
        elif ctx.sparse_dx:
            grad_x_flat = torch.zeros_like(x_flat)
        else:
            grad_x_flat = gy_flat @ weight

        if need_dw or need_db or accumulate_dx:
            # .tolist() copies the chunk ids to the host (a device sync on GPU/MPS).
            for c_idx in active_chunks.tolist():
                start = c_idx * chunk_size
                end = start + chunk_size
                # Column slice of grad_y: a strided view fed straight into the matmuls.
                gy_slice = gy_flat[:, start:end]
                if need_dw:
                    grad_w[start:end] = gy_slice.t() @ x_flat
                if need_db:
                    grad_b[start:end] = gy_slice.sum(dim=0)
                if accumulate_dx:
                    grad_x_flat += gy_slice @ weight[start:end]

        grad_x = None if grad_x_flat is None else grad_x_flat.reshape(x.shape)
        # No gradients for active_chunks, chunk_size, sparse_dx.
        return grad_x, grad_w, grad_b, None, None, None
class SparseLinear(nn.Linear):
    """``nn.Linear`` that can route its backward through ``ChunkedMaskedLinear``.

    While ``sparse_enabled`` is False or ``active_chunks`` is None it behaves
    exactly like ``nn.Linear``. Otherwise the forward is still dense, but the
    backward only produces weight/bias gradients for the output-row chunks in
    ``active_chunks``. It also restricts ``grad_x`` to those rows when
    ``sparse_dx`` is True.

    Args:
        in_features: input feature count.
        out_features: output feature count.
        bias: whether to learn an additive bias.
        chunk_size: rows per gradient chunk. This is an explicit attribute
            (callers may still reassign ``m.chunk_size`` after construction).
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, chunk_size: int = 64):
        super().__init__(in_features, out_features, bias=bias)
        self.sparse_enabled = False
        self.sparse_dx = False
        self.chunk_size = chunk_size
        # Local chunk indices selected for the next backward; None disables sparsity.
        self.active_chunks: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.sparse_enabled or self.active_chunks is None:
            return F.linear(x, self.weight, self.bias)
        return ChunkedMaskedLinear.apply(
            x, self.weight, self.bias, self.active_chunks, self.chunk_size, self.sparse_dx
        )
# -----------------------------
# GPT Architecture
# -----------------------------
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention whose projections are SparseLinear layers.

    Args:
        n_embd: model width; must be divisible by ``n_head``.
        n_head: number of attention heads.
        block_size: largest sequence length covered by the causal mask buffer.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Keep construction order (c_attn before c_proj) so seeded init is unchanged.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular ones: 1 where a query may attend, 0 for future keys.
        lower = torch.ones(block_size, block_size).tril()
        self.register_buffer("mask", lower.view(1, 1, block_size, block_size))

    def _split_heads(self, t: torch.Tensor, B: int, T: int) -> torch.Tensor:
        # (B, T, C) -> (B, n_head, T, head_dim)
        return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = (self._split_heads(part, B, T) for part in self.c_attn(x).split(C, dim=2))
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        future = self.mask[:, :, :T, :T] == 0
        weights = self.dropout(F.softmax(scores.masked_fill(future, float("-inf")), dim=-1))
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(merged)
class FeedForward(nn.Module):
    """Position-wise MLP: SparseLinear(C -> 4C), GELU, SparseLinear(4C -> C), dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden = 4 * n_embd
        self.c_fc = SparseLinear(n_embd, hidden)
        self.c_proj = SparseLinear(hidden, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(activated))
class Block(nn.Module):
    """Pre-LayerNorm transformer block: ``x + attn(ln1(x))`` then ``x + mlp(ln2(x))``."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual sub-layers applied in order: attention first, then the MLP.
        for norm, sublayer in ((self.ln1, self.attn), (self.ln2, self.mlp)):
            x = x + sublayer(norm(x))
        return x
class GPT(nn.Module):
    """Decoder-only GPT: token + learned position embeddings, ``n_layer`` Blocks, LM head.

    Args:
        vocab_size: number of token ids.
        block_size: maximum sequence length (size of the position table).
        n_layer: number of transformer blocks.
        n_head: attention heads per block.
        n_embd: model width.
        dropout: dropout probability passed to every Block.
    """

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        # The LM head stays dense: cross-entropy needs logits over the full vocabulary.
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Compute logits (and the mean cross-entropy loss when ``targets`` is given).

        Args:
            idx: integer token ids of shape ``(B, T)`` with ``T <= block_size``.
            targets: optional next-token ids with the same shape as ``idx``.

        Returns:
            ``(logits, loss)``. ``logits`` has shape ``(B, T, vocab_size)``.
            ``loss`` is None when ``targets`` is None.

        Raises:
            ValueError: if ``T`` exceeds ``block_size``.
        """
        _, T = idx.shape
        if T > self.block_size:
            # Fail with a clear message instead of an out-of-range position-embedding lookup.
            raise ValueError(f"sequence length {T} exceeds block_size {self.block_size}")
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.ln_f(self.blocks(x))
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted too.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss
def get_sparse_linears(model):
    """Return every SparseLinear submodule of ``model``, in ``model.modules()`` order."""
    return list(filter(lambda module: isinstance(module, SparseLinear), model.modules()))
# -----------------------------
# Chunk Masker with Annealing
# -----------------------------
class ChunkMasker:
def __init__(self, model: nn.Module, policy: Policy, target_fraction: float, chunk_size: int, device: str):
self.policy = policy
self.target_fraction = target_fraction
self.chunk_size = chunk_size
self.device = device
self.linears = get_sparse_linears(model)
self.module_to_chunk_ids = {}
offset = 0
for m in self.linears:
assert m.out_features % chunk_size == 0, f"out_features {m.out_features} not divisible by chunk size {chunk_size}"
n_chunks = m.out_features // chunk_size
self.module_to_chunk_ids[m] = torch.arange(offset, offset + n_chunks, device=device)
offset += n_chunks
self.n_chunks = offset
self.predicted_mass = torch.zeros(self.n_chunks, device=device)
self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)
def choose_active(self, step: int, warmup_steps: int, anneal_steps: int):
if step < warmup_steps:
current_fraction = 1.0
elif step < warmup_steps + anneal_steps:
progress = (step - warmup_steps) / anneal_steps
cosine_mult = 0.5 * (1.0 + math.cos(math.pi * progress))
current_fraction = self.target_fraction + (1.0 - self.target_fraction) * cosine_mult
else:
current_fraction = self.target_fraction
if current_fraction >= 0.999:
self.active_chunks.fill_(True)
for m, ids in self.module_to_chunk_ids.items():
m.active_chunks = torch.arange(len(ids), device=self.device)
return
k = max(1, int(current_fraction * self.n_chunks))
self.active_chunks.fill_(False)
if self.policy == "random":
self.active_chunks[torch.randperm(self.n_chunks, device=self.device)[:k]] = True
elif self.policy == "predicted_magnitude":
scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
self.active_chunks[torch.topk(scores, k=k).indices] = True
for m, ids in self.module_to_chunk_ids.items():
global_active = self.active_chunks[ids]
local_ids = torch.arange(len(ids), device=self.device)
m.active_chunks = local_ids[global_active]
@torch.no_grad()
def update_predictor(self, mass_beta=0.95):
current_mass = torch.zeros_like(self.predicted_mass)
for m, ids in self.module_to_chunk_ids.items():
if m.weight.grad is None: continue
w_sq = m.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2))
if m.bias is not None and m.bias.grad is not None:
w_sq += m.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1)
current_mass[ids] = torch.sqrt(w_sq + 1e-30)
observed = self.active_chunks
self.predicted_mass[observed] = mass_beta * self.predicted_mass[observed] + (1.0 - mass_beta) * current_mass[observed]
# -----------------------------
# Chunked Adam
# -----------------------------
class ChunkedAdam:
    """Adam that updates only the active row-chunks of SparseLinear parameters.

    Two update modes apply:

    * Parameters owned by a SparseLinear whose ``active_chunks`` is set are
      updated chunk by chunk. Only rows ``[c*chunk_size, (c+1)*chunk_size)``
      for ``c`` in ``active_chunks`` have their values and moment estimates
      touched; inactive chunks are left exactly as they were.
    * All other parameters receive a standard dense Adam update.

    Bias correction follows standard Adam. It uses a separate step counter per
    chunk, so each chunk is corrected by the number of times it has actually
    been updated.

    Args:
        model: module whose ``parameters()`` are optimised.
        lr: learning rate.
        chunk_size: rows per chunk; should match the SparseLinear chunk size.
        betas: ``(beta1, beta2)`` decay rates of the first/second moment estimates.
        eps: term added to the denominator for numerical stability.
    """

    def __init__(self, model, lr=5e-4, chunk_size=64, betas=(0.9, 0.999), eps=1e-8):
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        self.betas = betas
        self.eps = eps
        self.state = {}
        self.param_to_sparse_module = {}
        for m in get_sparse_linears(model):
            if m.weight is not None:
                self.param_to_sparse_module[m.weight] = m
            if m.bias is not None:
                self.param_to_sparse_module[m.bias] = m

    def zero_grad(self):
        """Reset every parameter's gradient to ``None``."""
        for p in self.model.parameters():
            p.grad = None

    def _state_for(self, p):
        """Return the lazily created moments ("m", "v") and per-chunk step counters for ``p``."""
        state = self.state.get(p)
        if state is None:
            if p in self.param_to_sparse_module and p.dim() > 0:
                # One counter per row-chunk (ceil division; at least one counter).
                n_counters = max(1, -(-p.shape[0] // self.chunk_size))
            else:
                n_counters = 1
            state = {"m": torch.zeros_like(p), "v": torch.zeros_like(p), "steps": [0] * n_counters}
            self.state[p] = state
        return state

    def _adam_update(self, p, g, m, v, step):
        """Apply one bias-corrected Adam update in place to ``p``, ``m`` and ``v``."""
        beta1, beta2 = self.betas
        m.mul_(beta1).add_(g, alpha=1.0 - beta1)
        v.mul_(beta2).addcmul_(g, g, value=1.0 - beta2)
        bias_correction1 = 1.0 - beta1 ** step
        bias_correction2 = 1.0 - beta2 ** step
        denom = (v / bias_correction2).sqrt_().add_(self.eps)
        p.addcdiv_(m, denom, value=-self.lr / bias_correction1)

    @torch.no_grad()
    def step(self):
        """Update every parameter that has a gradient (dense or per active chunk)."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            state = self._state_for(p)
            exp_avg, exp_avg_sq, steps = state["m"], state["v"], state["steps"]
            sparse_module = self.param_to_sparse_module.get(p)
            active_chunks = getattr(sparse_module, "active_chunks", None) if sparse_module is not None else None
            if active_chunks is None:
                if len(set(steps)) == 1:
                    # Dense update: every chunk shares one history, so one correction fits all rows.
                    steps[:] = [s + 1 for s in steps]
                    self._adam_update(p, p.grad, exp_avg, exp_avg_sq, steps[0])
                    continue
                # Chunk histories diverged after earlier sparse steps: update every
                # chunk, each bias-corrected with its own count.
                chunk_ids = range(len(steps))
            else:
                chunk_ids = active_chunks.tolist()
            for c in chunk_ids:
                start = c * self.chunk_size
                end = start + self.chunk_size
                steps[c] += 1
                # Row slices are views, so the in-place update writes through to p and its moments.
                self._adam_update(
                    p[start:end], p.grad[start:end], exp_avg[start:end], exp_avg_sq[start:end], steps[c]
                )
# -----------------------------
# Training
# -----------------------------
def main():
    """Benchmark a dense baseline against two chunk-sparse backward variants.

    Each (policy, backward mode) pair trains a fresh GPT from seed 42. Timing
    restarts once warmup plus annealing is over, or covers the whole run if
    ``--steps`` never reaches that point. One row per run reports wall time,
    milliseconds per measured step, and validation loss.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--n_layer", type=int, default=6)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=50)
    parser.add_argument("--anneal_steps", type=int, default=200)
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument(
        "--eval_iters", type=int, default=1,
        help="number of validation batches averaged for the reported val loss",
    )
    args = parser.parse_args()
    corpus = ShakespeareCorpus(args.block_size, args.device)
    modes = [
        ("dense_baseline", "dense_baseline"),
        ("predicted_magnitude", "sparse_dW_full_dX"),
        ("predicted_magnitude", "sparse_dW_sparse_dX"),
    ]
    print(f"\nModel: {args.n_layer} layers, {args.n_embd} d_model, {args.chunk_size} chunk_size")
    print(f"Batch: {args.batch_size}, Block: {args.block_size}. Target Active: {args.active_fraction*100}%")
    print(f"Annealing: {args.warmup_steps} warmup steps, {args.anneal_steps} anneal steps.\n")
    print(f"{'Run':>20s} | {'Time (s)':>10s} | {'Step (ms)':>10s} | {'Val Loss':>8s}")
    print("-" * 55)
    for policy, bwd_mode in modes:
        set_seed(42)
        model = GPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(args.device)
        sparse_linears = get_sparse_linears(model)
        for m in sparse_linears:
            m.chunk_size = args.chunk_size
        masker = ChunkMasker(model, policy, args.active_fraction, args.chunk_size, args.device) if policy != "dense_baseline" else None
        opt = ChunkedAdam(model, lr=args.lr, chunk_size=args.chunk_size)
        # Sparse routing is fixed for the whole run: configure it once, outside the timed loop.
        use_sparse = masker is not None
        for m in sparse_linears:
            m.sparse_enabled = use_sparse
            m.sparse_dx = use_sparse and bwd_mode == "sparse_dW_sparse_dX"
            if not use_sparse:
                m.active_chunks = None
        if args.benchmark_sync:
            sync_device(args.device)
        t0 = time.perf_counter()
        measured_steps = args.steps
        for step in range(args.steps):
            # Restart the clock once the schedule reaches target_fraction, so the
            # reported time reflects the steady-state active fraction.
            if step == args.warmup_steps + args.anneal_steps:
                if args.benchmark_sync:
                    sync_device(args.device)
                t0 = time.perf_counter()
                measured_steps = args.steps - step
            x, y = corpus.get_batch("train", args.batch_size, generator=make_cpu_generator(step))
            if masker is not None:
                masker.choose_active(step, warmup_steps=args.warmup_steps, anneal_steps=args.anneal_steps)
            opt.zero_grad()
            _, loss = model(x, y)
            loss.backward()
            if masker is not None:
                masker.update_predictor()
            opt.step()
            # Progress every 200 steps (loss.item() forces a device sync).
            if step % 200 == 0:
                print(f" [Progress] {bwd_mode} step {step}/{args.steps} | Loss: {loss.item():.4f}", end="\r")
        if args.benchmark_sync:
            sync_device(args.device)
        t_elapsed = time.perf_counter() - t0
        # Eval loss: mean over eval_iters fixed-seed validation batches (seed 999 + i).
        model.eval()
        val_losses = []
        with torch.no_grad():
            for i in range(max(1, args.eval_iters)):
                val_x, val_y = corpus.get_batch("val", args.batch_size, generator=make_cpu_generator(999 + i))
                _, batch_loss = model(val_x, val_y)
                val_losses.append(batch_loss.item())
        val_loss = sum(val_losses) / len(val_losses)
        # Clear the progress line
        print(" " * 60, end="\r")
        bwd_str = bwd_mode if bwd_mode == "dense_baseline" else ("sparse_full_dX" if "full_dX" in bwd_mode else "sparse_sparse_dX")
        print(f"{bwd_str:>20s} | {t_elapsed:10.2f} | {1000*t_elapsed/max(1, measured_steps):10.2f} | {val_loss:8.4f}")


if __name__ == "__main__":
    main()