""" Sparse Transformer: Real-World Benchmark on Tiny Shakespeare using GPT-2 BPE. This script scales the architecture to a 6-layer, 512-dim GPT and trains on real natural language. It applies our Hardware-Sympathetic Chunked Sparse backward pass, Cosine Annealing, and Chunked Adam optimizer. Run: python3 sparse_transformer_shakespeare.py --device mps --benchmark_sync """ import argparse import math import os import random import time import urllib.request from typing import Dict, List, Literal, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F try: import tiktoken except ImportError: raise ImportError("Please install tiktoken: pip install tiktoken") torch.set_num_threads(1) def sync_device(device: str) -> None: if device == "cuda" and torch.cuda.is_available(): torch.cuda.synchronize() elif device == "mps" and hasattr(torch, "mps"): torch.mps.synchronize() Policy = Literal["predicted_magnitude", "oracle_current", "random"] BackwardMode = Literal["dense_baseline", "sparse_dW_full_dX", "sparse_dW_sparse_dX"] def set_seed(seed: int) -> None: random.seed(seed) torch.manual_seed(seed) def make_cpu_generator(seed: int) -> torch.Generator: gen = torch.Generator(device="cpu") gen.manual_seed(seed) return gen # ----------------------------- # Real-World Data Pipeline # ----------------------------- class ShakespeareCorpus: def __init__(self, block_size: int, device: str): self.block_size = block_size self.device = device # 1. Download Tiny Shakespeare if not exists data_path = "input.txt" if not os.path.exists(data_path): print("Downloading Tiny Shakespeare...") url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" urllib.request.urlretrieve(url, data_path) # 2. Tokenize using GPT-2 BPE print("Tokenizing data...") with open(data_path, "r", encoding="utf-8") as f: text = f.read() enc = tiktoken.get_encoding("gpt2") tokens = enc.encode(text) self.vocab_size = enc.n_vocab # 3. Split 90/10 Train/Val data = torch.tensor(tokens, dtype=torch.long) split_idx = int(0.9 * len(data)) self.train_data = data[:split_idx] self.val_data = data[split_idx:] print(f"Dataset loaded. Vocab size: {self.vocab_size:,}. Train tokens: {len(self.train_data):,}") def get_batch(self, split: str, batch_size: int, generator: Optional[torch.Generator] = None) -> Tuple[torch.Tensor, torch.Tensor]: data = self.train_data if split == "train" else self.val_data ix = torch.randint(len(data) - self.block_size - 1, (batch_size,), generator=generator) x = torch.stack([data[i : i + self.block_size] for i in ix]) y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix]) return x.to(self.device), y.to(self.device) # ----------------------------- # Chunked Sparse Autograd # ----------------------------- class ChunkedMaskedLinear(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], active_chunks: torch.Tensor, chunk_size: int, sparse_dx: bool) -> torch.Tensor: ctx.save_for_backward(x, weight, active_chunks) ctx.has_bias = bias is not None ctx.sparse_dx = sparse_dx ctx.chunk_size = chunk_size return F.linear(x, weight, bias) @staticmethod def backward(ctx, grad_y: torch.Tensor): x, weight, active_chunks = ctx.saved_tensors chunk_size = ctx.chunk_size x_flat = x.reshape(-1, x.shape[-1]) gy_flat = grad_y.reshape(-1, grad_y.shape[-1]) grad_w = torch.zeros_like(weight) grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if ctx.has_bias else None if ctx.sparse_dx: grad_x_flat = torch.zeros_like(x_flat) else: grad_x_flat = gy_flat @ weight # Zero-copy Strided Views feeding directly into Dense Hardware Matmuls for c_idx in active_chunks.tolist(): start = c_idx * chunk_size end = start + chunk_size gy_slice = gy_flat[:, start:end] w_slice = weight[start:end, :] grad_w[start:end, :] = gy_slice.t() @ x_flat if ctx.has_bias: grad_b[start:end] = gy_slice.sum(dim=0) if ctx.sparse_dx: grad_x_flat += gy_slice @ w_slice return grad_x_flat.reshape(x.shape), grad_w, grad_b, None, None, None class SparseLinear(nn.Linear): def __init__(self, in_features: int, out_features: int, bias: bool = True): super().__init__(in_features, out_features, bias=bias) self.sparse_enabled = False self.sparse_dx = False self.active_chunks: Optional[torch.Tensor] = None def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.sparse_enabled or self.active_chunks is None: return F.linear(x, self.weight, self.bias) return ChunkedMaskedLinear.apply(x, self.weight, self.bias, self.active_chunks, getattr(self, 'chunk_size', 64), self.sparse_dx) # ----------------------------- # GPT Architecture # ----------------------------- class CausalSelfAttention(nn.Module): def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float): super().__init__() assert n_embd % n_head == 0 self.n_head = n_head self.head_dim = n_embd // n_head self.c_attn = SparseLinear(n_embd, 3 * n_embd) self.c_proj = SparseLinear(n_embd, n_embd) self.dropout = nn.Dropout(dropout) self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)) def forward(self, x: torch.Tensor) -> torch.Tensor: B, T, C = x.shape qkv = self.c_attn(x) q, k, v = qkv.split(C, dim=2) q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim) att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf")) att = F.softmax(att, dim=-1) att = self.dropout(att) y = att @ v y = y.transpose(1, 2).contiguous().view(B, T, C) return self.c_proj(y) class FeedForward(nn.Module): def __init__(self, n_embd: int, dropout: float): super().__init__() self.c_fc = SparseLinear(n_embd, 4 * n_embd) self.c_proj = SparseLinear(4 * n_embd, n_embd) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.dropout(self.c_proj(F.gelu(self.c_fc(x)))) class Block(nn.Module): def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float): super().__init__() self.ln1 = nn.LayerNorm(n_embd) self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout) self.ln2 = nn.LayerNorm(n_embd) self.mlp = FeedForward(n_embd, dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class GPT(nn.Module): def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float): super().__init__() self.block_size = block_size self.tok_emb = nn.Embedding(vocab_size, n_embd) self.pos_emb = nn.Embedding(block_size, n_embd) self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]) self.ln_f = nn.LayerNorm(n_embd) # LM head is Dense! Needs full output dist for CrossEntropy loss self.lm_head = nn.Linear(n_embd, vocab_size) def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None): B, T = idx.shape pos = torch.arange(T, device=idx.device) x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :] x = self.blocks(x) x = self.ln_f(x) logits = self.lm_head(x) loss = None if targets is not None: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) return logits, loss def get_sparse_linears(model): return[m for m in model.modules() if isinstance(m, SparseLinear)] # ----------------------------- # Chunk Masker with Annealing # ----------------------------- class ChunkMasker: def __init__(self, model: nn.Module, policy: Policy, target_fraction: float, chunk_size: int, device: str): self.policy = policy self.target_fraction = target_fraction self.chunk_size = chunk_size self.device = device self.linears = get_sparse_linears(model) self.module_to_chunk_ids = {} offset = 0 for m in self.linears: assert m.out_features % chunk_size == 0, f"out_features {m.out_features} not divisible by chunk size {chunk_size}" n_chunks = m.out_features // chunk_size self.module_to_chunk_ids[m] = torch.arange(offset, offset + n_chunks, device=device) offset += n_chunks self.n_chunks = offset self.predicted_mass = torch.zeros(self.n_chunks, device=device) self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device) def choose_active(self, step: int, warmup_steps: int, anneal_steps: int): if step < warmup_steps: current_fraction = 1.0 elif step < warmup_steps + anneal_steps: progress = (step - warmup_steps) / anneal_steps cosine_mult = 0.5 * (1.0 + math.cos(math.pi * progress)) current_fraction = self.target_fraction + (1.0 - self.target_fraction) * cosine_mult else: current_fraction = self.target_fraction if current_fraction >= 0.999: self.active_chunks.fill_(True) for m, ids in self.module_to_chunk_ids.items(): m.active_chunks = torch.arange(len(ids), device=self.device) return k = max(1, int(current_fraction * self.n_chunks)) self.active_chunks.fill_(False) if self.policy == "random": self.active_chunks[torch.randperm(self.n_chunks, device=self.device)[:k]] = True elif self.policy == "predicted_magnitude": scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass) self.active_chunks[torch.topk(scores, k=k).indices] = True for m, ids in self.module_to_chunk_ids.items(): global_active = self.active_chunks[ids] local_ids = torch.arange(len(ids), device=self.device) m.active_chunks = local_ids[global_active] @torch.no_grad() def update_predictor(self, mass_beta=0.95): current_mass = torch.zeros_like(self.predicted_mass) for m, ids in self.module_to_chunk_ids.items(): if m.weight.grad is None: continue w_sq = m.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2)) if m.bias is not None and m.bias.grad is not None: w_sq += m.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1) current_mass[ids] = torch.sqrt(w_sq + 1e-30) observed = self.active_chunks self.predicted_mass[observed] = mass_beta * self.predicted_mass[observed] + (1.0 - mass_beta) * current_mass[observed] # ----------------------------- # Chunked Adam # ----------------------------- class ChunkedAdam: def __init__(self, model, lr=5e-4, chunk_size=64): self.model = model self.lr = lr self.chunk_size = chunk_size self.state = {} self.param_to_sparse_module = {} for m in get_sparse_linears(model): if m.weight is not None: self.param_to_sparse_module[m.weight] = m if m.bias is not None: self.param_to_sparse_module[m.bias] = m def zero_grad(self): for p in self.model.parameters(): p.grad = None @torch.no_grad() def step(self): for p in self.model.parameters(): if p.grad is None: continue if p not in self.state: self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)} exp_avg, exp_avg_sq = self.state[p]["m"], self.state[p]["v"] sparse_module = self.param_to_sparse_module.get(p) active_chunks = getattr(sparse_module, 'active_chunks', None) if sparse_module else None if active_chunks is None: # Dense update exp_avg.mul_(0.9).add_(p.grad, alpha=0.1) exp_avg_sq.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001) update = exp_avg / (torch.sqrt(exp_avg_sq) + 1e-8) p.sub_(update, alpha=self.lr) else: # Sparse update for local_c in active_chunks.tolist(): start = local_c * self.chunk_size end = (local_c + 1) * self.chunk_size p_chunk = p[start:end] g_chunk = p.grad[start:end] m_chunk = exp_avg[start:end] v_chunk = exp_avg_sq[start:end] m_chunk.mul_(0.9).add_(g_chunk, alpha=0.1) v_chunk.mul_(0.999).addcmul_(g_chunk, g_chunk, value=0.001) update = m_chunk / (torch.sqrt(v_chunk) + 1e-8) p_chunk.sub_(update, alpha=self.lr) # ----------------------------- # Training # ----------------------------- def main(): parser = argparse.ArgumentParser() parser.add_argument("--steps", type=int, default=1000) parser.add_argument("--batch_size", type=int, default=8) parser.add_argument("--block_size", type=int, default=256) parser.add_argument("--n_layer", type=int, default=6) parser.add_argument("--n_head", type=int, default=8) parser.add_argument("--n_embd", type=int, default=512) parser.add_argument("--chunk_size", type=int, default=64) parser.add_argument("--active_fraction", type=float, default=0.10) parser.add_argument("--warmup_steps", type=int, default=50) parser.add_argument("--anneal_steps", type=int, default=200) parser.add_argument("--device", type=str, default="mps") parser.add_argument("--benchmark_sync", action="store_true") args = parser.parse_args() corpus = ShakespeareCorpus(args.block_size, args.device) modes =[ ("dense_baseline", "dense_baseline"), ("predicted_magnitude", "sparse_dW_full_dX"), ("predicted_magnitude", "sparse_dW_sparse_dX") ] print(f"\nModel: {args.n_layer} layers, {args.n_embd} d_model, {args.chunk_size} chunk_size") print(f"Batch: {args.batch_size}, Block: {args.block_size}. Target Active: {args.active_fraction*100}%") print(f"Annealing: {args.warmup_steps} warmup steps, {args.anneal_steps} anneal steps.\n") print(f"{'Run':>20s} | {'Time (s)':>10s} | {'Step (ms)':>10s} | {'Val Loss':>8s}") print("-" * 55) for policy, bwd_mode in modes: set_seed(42) model = GPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, 0.1).to(args.device) for m in get_sparse_linears(model): m.chunk_size = args.chunk_size masker = ChunkMasker(model, policy, args.active_fraction, args.chunk_size, args.device) if policy != "dense_baseline" else None opt = ChunkedAdam(model, lr=5e-4, chunk_size=args.chunk_size) if args.benchmark_sync: sync_device(args.device) t0 = time.perf_counter() measured_steps = args.steps for step in range(args.steps): if step == args.warmup_steps + args.anneal_steps: if args.benchmark_sync: sync_device(args.device) t0 = time.perf_counter() measured_steps = args.steps - step x, y = corpus.get_batch("train", args.batch_size, generator=make_cpu_generator(step)) if masker: masker.choose_active(step, warmup_steps=args.warmup_steps, anneal_steps=args.anneal_steps) for m in get_sparse_linears(model): m.sparse_enabled = True m.sparse_dx = (bwd_mode == "sparse_dW_sparse_dX") else: for m in get_sparse_linears(model): m.sparse_enabled = False m.active_chunks = None opt.zero_grad() _, loss = model(x, y) loss.backward() if masker: masker.update_predictor() opt.step() # Optional: Print progress every 100 steps if step % 200 == 0: print(f" [Progress] {bwd_mode} step {step}/{args.steps} | Loss: {loss.item():.4f}", end="\r") if args.benchmark_sync: sync_device(args.device) t_elapsed = time.perf_counter() - t0 # Eval loss model.eval() with torch.no_grad(): # Eval loss model.eval() with torch.no_grad(): val_x, val_y = corpus.get_batch("val", args.batch_size, generator=make_cpu_generator(999)) _, val_loss = model(val_x, val_y) # Clear the progress line print(" " * 60, end="\r") bwd_str = bwd_mode if bwd_mode == "dense_baseline" else ("sparse_full_dX" if "full_dX" in bwd_mode else "sparse_sparse_dX") print(f"{bwd_str:>20s} | {t_elapsed:10.2f} | {1000*t_elapsed/max(1, measured_steps):10.2f} | {val_loss.item():8.4f}") if __name__ == "__main__": main()