Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
"""
Sparse Transformer: Real-World Benchmark on Tiny Shakespeare using GPT-2 BPE.

This script scales the architecture to a 6-layer, 512-dim GPT and trains on
real natural language. It applies our Hardware-Sympathetic Chunked Sparse
backward pass, cosine annealing of the active-chunk fraction, and a Chunked
Adam optimizer.

Run:
    python3 sparse_transformer_shakespeare.py --device mps --benchmark_sync
"""
| import argparse | |
| import math | |
| import os | |
| import random | |
| import time | |
| import urllib.request | |
| from typing import Dict, List, Literal, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| try: | |
| import tiktoken | |
| except ImportError: | |
| raise ImportError("Please install tiktoken: pip install tiktoken") | |
| torch.set_num_threads(1) | |
def sync_device(device: str) -> None:
    """Block until all work queued on ``device`` has finished.

    Used around timing measurements so that asynchronous accelerator kernels
    are included in the measured wall-clock time.

    Args:
        device: A torch device string such as ``"cpu"``, ``"cuda"``,
            ``"cuda:1"`` or ``"mps"``.

    The call is a no-op for CPU devices and for accelerators that are not
    available in the running process.
    """
    # Parse the string so that indexed devices ("cuda:0") are recognised too;
    # a plain equality check against "cuda" would silently skip them.
    device_type = torch.device(device).type
    if device_type == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize(device)
    elif device_type == "mps":
        # ``hasattr(torch, "mps")`` is true on builds without a usable MPS
        # backend, so ask the backend itself whether it is available.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            torch.mps.synchronize()
# Names of the chunk-selection strategies a ChunkMasker can be asked to use.
Policy = Literal[
    "predicted_magnitude",
    "oracle_current",
    "random",
]

# Names of the backward-pass variants compared by the benchmark.
BackwardMode = Literal[
    "dense_baseline",
    "sparse_dW_full_dX",
    "sparse_dW_sparse_dX",
]
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch's global RNG with ``seed``."""
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a fresh CPU ``torch.Generator`` whose state is seeded with ``seed``."""
    # ``manual_seed`` returns the generator itself, so the call can be chained.
    return torch.Generator(device="cpu").manual_seed(seed)
| # ----------------------------- | |
| # Real-World Data Pipeline | |
| # ----------------------------- | |
class ShakespeareCorpus:
    """Tiny Shakespeare text tokenized with GPT-2 BPE and split into train/val.

    On construction the corpus file ``input.txt`` is downloaded into the
    current working directory if it is not already there, tokenized with
    ``tiktoken``'s ``"gpt2"`` encoding, and split 90/10 into ``train_data`` and
    ``val_data`` (1-D ``torch.long`` tensors kept on the CPU).

    Attributes set by ``__init__``: ``block_size``, ``device``, ``vocab_size``,
    ``train_data``, ``val_data``.
    """

    def __init__(self, block_size: int, device: str):
        """Load (downloading if needed) and tokenize the corpus.

        Args:
            block_size: Number of tokens in each sampled context window.
            device: Torch device string that sampled batches are moved to.
        """
        self.block_size = block_size
        self.device = device

        # 1. Download Tiny Shakespeare if not exists
        data_path = "input.txt"
        if not os.path.exists(data_path):
            print("Downloading Tiny Shakespeare...")
            url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
            # Download to a temporary name first so that an interrupted
            # transfer never leaves a truncated ``input.txt`` that later runs
            # would silently reuse.
            partial_path = data_path + ".part"
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, data_path)

        # 2. Tokenize using GPT-2 BPE
        print("Tokenizing data...")
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.read()
        enc = tiktoken.get_encoding("gpt2")
        tokens = enc.encode(text)
        self.vocab_size = enc.n_vocab

        # 3. Split 90/10 Train/Val
        data = torch.tensor(tokens, dtype=torch.long)
        split_idx = int(0.9 * len(data))
        self.train_data = data[:split_idx]
        self.val_data = data[split_idx:]
        print(f"Dataset loaded. Vocab size: {self.vocab_size:,}. Train tokens: {len(self.train_data):,}")

    def get_batch(self, split: str, batch_size: int, generator: Optional[torch.Generator] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` random windows and their next-token targets.

        Args:
            split: ``"train"`` selects the training split; any other value
                selects the validation split.
            batch_size: Number of windows to draw.
            generator: Optional CPU generator controlling the random offsets.

        Returns:
            ``(x, y)``, both of shape ``(batch_size, block_size)`` on
            ``self.device``; ``y`` is ``x`` shifted one token to the right.

        Raises:
            ValueError: If the chosen split holds fewer than
                ``block_size + 1`` tokens.
        """
        data = self.train_data if split == "train" else self.val_data

        # A window starting at i needs tokens i .. i + block_size (inclusive)
        # for the shifted targets, so valid starts are 0 .. len - block_size - 1.
        n_windows = len(data) - self.block_size
        if n_windows <= 0:
            raise ValueError(
                f"split {split!r} has {len(data)} tokens but block_size "
                f"{self.block_size} needs at least {self.block_size + 1}"
            )
        ix = torch.randint(n_windows, (batch_size,), generator=generator)

        # Gather all windows at once: row b holds tokens ix[b] .. ix[b] + block_size.
        offsets = torch.arange(self.block_size + 1)
        windows = data[ix[:, None] + offsets[None, :]]
        x = windows[:, :-1].contiguous()
        y = windows[:, 1:].contiguous()
        return x.to(self.device), y.to(self.device)
| # ----------------------------- | |
| # Chunked Sparse Autograd | |
| # ----------------------------- | |
class ChunkedMaskedLinear(torch.autograd.Function):
    """Linear layer with a dense forward pass and a chunk-sparse backward pass.

    The output features (rows of ``weight``) are grouped into contiguous
    chunks of ``chunk_size`` rows. In the backward pass only the chunks listed
    in ``active_chunks`` receive weight/bias gradients; every other row of the
    returned gradients is zero.
    """

    # ``torch.autograd.Function`` requires ``forward``/``backward`` to be
    # static methods; ``ctx`` is passed explicitly by autograd.
    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], active_chunks: torch.Tensor, chunk_size: int, sparse_dx: bool) -> torch.Tensor:
        """Compute ``F.linear(x, weight, bias)`` and stash state for backward.

        Args:
            ctx: Autograd context.
            x: Input whose last dimension matches ``weight.shape[1]``.
            weight: Weight matrix of shape ``(out_features, in_features)``.
            bias: Optional bias of shape ``(out_features,)``.
            active_chunks: 1-D integer tensor of chunk indices (into the
                output-feature dimension) that should receive gradients.
            chunk_size: Number of output rows per chunk.
            sparse_dx: If true, the input gradient is accumulated from the
                active chunks only; otherwise it is the full dense gradient.
        """
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = sparse_dx
        ctx.chunk_size = chunk_size
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        """Return gradients for ``(x, weight, bias)``; the rest are ``None``.

        ``grad_w`` and ``grad_b`` are zero outside the active chunks. The
        input gradient is either dense (``grad_y @ weight``) or, when
        ``sparse_dx`` was set, the sum of contributions of active chunks only.
        """
        x, weight, active_chunks = ctx.saved_tensors
        chunk_size = ctx.chunk_size

        # Collapse all leading dimensions so every product is a 2-D matmul.
        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])

        grad_w = torch.zeros_like(weight)
        grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if ctx.has_bias else None

        if ctx.sparse_dx:
            grad_x_flat = torch.zeros_like(x_flat)
        else:
            grad_x_flat = gy_flat @ weight

        # Each active chunk is a contiguous row range, so the slices below are
        # views (no copies) that feed ordinary dense matmuls.
        # NOTE: ``tolist()`` copies the indices to the host.
        for c_idx in active_chunks.tolist():
            start = c_idx * chunk_size
            end = start + chunk_size
            gy_slice = gy_flat[:, start:end]
            grad_w[start:end, :] = gy_slice.t() @ x_flat
            if ctx.has_bias:
                grad_b[start:end] = gy_slice.sum(dim=0)
            if ctx.sparse_dx:
                grad_x_flat += gy_slice @ weight[start:end, :]

        # One gradient slot per forward argument; the three non-tensor /
        # index arguments are not differentiable.
        return grad_x_flat.reshape(x.shape), grad_w, grad_b, None, None, None
class SparseLinear(nn.Linear):
    """``nn.Linear`` that can switch to a chunk-sparse backward pass.

    The layer behaves exactly like ``nn.Linear`` until ``sparse_enabled`` is
    set to true *and* ``active_chunks`` is assigned; from then on the forward
    pass is routed through ``ChunkedMaskedLinear`` so that only the listed
    output-row chunks receive gradients.

    Attributes (all may be reassigned between steps):
        sparse_enabled: Master switch for the sparse backward pass.
        sparse_dx: Whether the input gradient is also restricted to the
            active chunks.
        active_chunks: 1-D tensor of chunk indices local to this layer, or
            ``None`` for dense behaviour.
        chunk_size: Number of output rows per chunk.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, chunk_size: int = 64):
        """Create the layer.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: Whether to learn an additive bias.
            chunk_size: Rows per chunk for the sparse backward pass. Declared
                here (defaulting to the previous implicit fallback of 64) so
                the attribute always exists; callers may still overwrite it.
        """
        super().__init__(in_features, out_features, bias=bias)
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_chunks: Optional[torch.Tensor] = None
        self.chunk_size = chunk_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear map, using the sparse backward path when enabled."""
        if not self.sparse_enabled or self.active_chunks is None:
            return F.linear(x, self.weight, self.bias)
        return ChunkedMaskedLinear.apply(x, self.weight, self.bias, self.active_chunks, self.chunk_size, self.sparse_dx)
| # ----------------------------- | |
| # GPT Architecture | |
| # ----------------------------- | |
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only attends to itself and earlier positions."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        """Build the fused QKV projection, the output projection and the causal mask.

        ``n_embd`` must be divisible by ``n_head``; ``block_size`` is the
        side length of the stored lower-triangular mask.
        """
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", lower_triangle.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape ``(B, T, C)`` and return a tensor of the same shape."""
        batch, seq_len, width = x.shape
        per_head = (batch, seq_len, self.n_head, self.head_dim)

        # Split the fused projection into Q, K, V and move heads in front of time.
        q, k, v = (
            part.view(per_head).transpose(1, 2)
            for part in self.c_attn(x).split(width, dim=2)
        )

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Entries where the stored mask is 0 are future positions.
        is_future = self.mask[:, :, :seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Merge the heads back into a single feature dimension.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``4 * n_embd``, GELU, project back, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        """Create the expansion (``c_fc``) and contraction (``c_proj``) layers."""
        super().__init__()
        hidden_width = 4 * n_embd
        self.c_fc = SparseLinear(n_embd, hidden_width)
        self.c_proj = SparseLinear(hidden_width, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``."""
        hidden = F.gelu(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.dropout(projected)
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each wrapped in a residual connection."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        """Create the two LayerNorms, the attention module and the MLP."""
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual updates."""
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.mlp(self.ln2(after_attention))
class GPT(nn.Module):
    """Decoder-only transformer language model with learned positional embeddings."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        """Build the embeddings, ``n_layer`` transformer blocks and the LM head.

        Args:
            vocab_size: Number of token ids.
            block_size: Maximum sequence length (size of the position table).
            n_layer: Number of ``Block`` modules stacked in ``self.blocks``.
            n_head: Attention heads per block.
            n_embd: Embedding / residual-stream width.
            dropout: Dropout probability passed to every block.
        """
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        # LM head is Dense! Needs full output dist for CrossEntropy loss
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Compute next-token logits and, if ``targets`` is given, the loss.

        Args:
            idx: Token ids of shape ``(B, T)`` with ``T <= block_size``.
            targets: Optional token ids of shape ``(B, T)``.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is the mean cross-entropy over
            all positions, or ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: If ``T`` exceeds ``block_size``.
        """
        B, T = idx.shape
        # Fail early with a clear message: indexing the position table out of
        # range would otherwise surface as an opaque index error (or a
        # device-side assert on accelerators).
        if T > self.block_size:
            raise ValueError(
                f"Sequence length {T} exceeds the model's block_size {self.block_size}"
            )
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # Flatten batch and time so every position is one classification example.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
def get_sparse_linears(model):
    """Return all ``SparseLinear`` submodules of ``model``, in ``model.modules()`` order."""
    sparse_layers = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            sparse_layers.append(module)
    return sparse_layers
| # ----------------------------- | |
| # Chunk Masker with Annealing | |
| # ----------------------------- | |
class ChunkMasker:
    """Chooses, per step, which output-row chunks of the sparse linears are trained.

    Every ``SparseLinear`` in the model is divided into chunks of
    ``chunk_size`` output rows. All chunks of all layers share one global id
    space; a single global budget of active chunks is selected each step and
    then translated back into per-layer local chunk indices that are written
    to each module's ``active_chunks`` attribute.

    The fraction of active chunks is 1.0 during warmup, follows a cosine
    curve down to ``target_fraction`` during annealing, and stays at
    ``target_fraction`` afterwards.
    """

    # Policies that ``choose_active`` actually implements. "oracle_current" is
    # part of the ``Policy`` literal but has no selection rule here.
    _SUPPORTED_POLICIES = ("random", "predicted_magnitude")

    def __init__(self, model: nn.Module, policy: "Policy", target_fraction: float, chunk_size: int, device: str):
        """Index the model's sparse linears and allocate selection state.

        Args:
            model: Model whose ``SparseLinear`` layers are to be masked.
            policy: ``"random"`` or ``"predicted_magnitude"``.
            target_fraction: Final fraction of chunks kept active.
            chunk_size: Output rows per chunk.
            device: Device on which selection tensors are allocated.

        Raises:
            ValueError: If ``policy`` is not supported, or a layer's
                ``out_features`` is not a multiple of ``chunk_size``.
        """
        if policy not in self._SUPPORTED_POLICIES:
            raise ValueError(
                f"Unsupported policy {policy!r}; expected one of {self._SUPPORTED_POLICIES}"
            )
        self.policy = policy
        self.target_fraction = target_fraction
        self.chunk_size = chunk_size
        self.device = device
        self.linears = get_sparse_linears(model)

        # Map each module to the contiguous range of global chunk ids it owns.
        self.module_to_chunk_ids = {}
        offset = 0
        for m in self.linears:
            if m.out_features % chunk_size != 0:
                raise ValueError(f"out_features {m.out_features} not divisible by chunk size {chunk_size}")
            n_chunks = m.out_features // chunk_size
            self.module_to_chunk_ids[m] = torch.arange(offset, offset + n_chunks, device=device)
            offset += n_chunks

        self.n_chunks = offset
        # Running estimate of each chunk's gradient norm, and the current mask.
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)

    def _scheduled_fraction(self, step: int, warmup_steps: int, anneal_steps: int) -> float:
        """Return the fraction of chunks that should be active at ``step``."""
        if step < warmup_steps:
            return 1.0
        if step < warmup_steps + anneal_steps:
            progress = (step - warmup_steps) / anneal_steps
            cosine_mult = 0.5 * (1.0 + math.cos(math.pi * progress))
            return self.target_fraction + (1.0 - self.target_fraction) * cosine_mult
        return self.target_fraction

    def choose_active(self, step: int, warmup_steps: int, anneal_steps: int):
        """Select the active chunks for ``step`` and publish them to the modules.

        Updates ``self.active_chunks`` (global boolean mask) and each module's
        ``active_chunks`` (1-D tensor of local chunk indices).

        Raises:
            ValueError: If ``self.policy`` has no selection rule; previously
                such a policy silently deactivated every chunk after warmup.
        """
        if self.policy not in self._SUPPORTED_POLICIES:
            raise ValueError(
                f"Unsupported policy {self.policy!r}; expected one of {self._SUPPORTED_POLICIES}"
            )

        current_fraction = self._scheduled_fraction(step, warmup_steps, anneal_steps)

        # Effectively dense: activate everything without scoring.
        if current_fraction >= 0.999:
            self.active_chunks.fill_(True)
            for m, ids in self.module_to_chunk_ids.items():
                m.active_chunks = torch.arange(len(ids), device=self.device)
            return

        self.active_chunks.fill_(False)
        if self.n_chunks > 0:
            k = max(1, int(current_fraction * self.n_chunks))
            if self.policy == "random":
                chosen = torch.randperm(self.n_chunks, device=self.device)[:k]
            else:  # "predicted_magnitude"
                # Tiny random jitter breaks ties between equal predictions.
                scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
                chosen = torch.topk(scores, k=k).indices
            self.active_chunks[chosen] = True

        # Translate the global mask into per-module local chunk indices.
        for m, ids in self.module_to_chunk_ids.items():
            global_active = self.active_chunks[ids]
            local_ids = torch.arange(len(ids), device=self.device)
            m.active_chunks = local_ids[global_active]

    def update_predictor(self, mass_beta=0.95):
        """Fold the current gradient norms of active chunks into ``predicted_mass``.

        For every module with a weight gradient, the per-chunk L2 norm of the
        weight (and bias, if present) gradient is computed. Only chunks that
        are currently active are updated, using an exponential moving average
        with decay ``mass_beta``; inactive chunks keep their old prediction.
        """
        current_mass = torch.zeros_like(self.predicted_mass)
        for m, ids in self.module_to_chunk_ids.items():
            if m.weight.grad is None:
                continue
            n_local = len(ids)
            w_sq = m.weight.grad.square().reshape(n_local, self.chunk_size, -1).sum(dim=(1, 2))
            if m.bias is not None and m.bias.grad is not None:
                w_sq += m.bias.grad.square().reshape(n_local, self.chunk_size).sum(dim=1)
            # Small epsilon keeps the square root well-defined at exactly zero.
            current_mass[ids] = torch.sqrt(w_sq + 1e-30)

        observed = self.active_chunks
        self.predicted_mass[observed] = mass_beta * self.predicted_mass[observed] + (1.0 - mass_beta) * current_mass[observed]
| # ----------------------------- | |
| # Chunked Adam | |
| # ----------------------------- | |
class ChunkedAdam:
    """Adam-style optimizer that updates only the active chunks of sparse layers.

    Parameters that belong to a ``SparseLinear`` whose ``active_chunks`` is set
    are updated row-chunk by row-chunk: moments and weights of inactive chunks
    are left untouched. All other parameters receive a dense update.

    NOTE: the update is ``m / (sqrt(v) + eps)`` with no bias correction of the
    moment estimates.
    """

    def __init__(self, model, lr=5e-4, chunk_size=64, betas=(0.9, 0.999), eps=1e-8):
        """Create the optimizer.

        Args:
            model: Module whose parameters are optimized.
            lr: Learning rate.
            chunk_size: Rows per chunk; must match the value used by the
                sparse layers and the masker.
            betas: Decay rates of the first and second moment averages.
            eps: Constant added to the denominator for numerical stability.
        """
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        self.betas = betas
        self.eps = eps
        self.state = {}

        # Lets ``step`` look up the owning sparse module of a parameter.
        self.param_to_sparse_module = {}
        for m in get_sparse_linears(model):
            if m.weight is not None:
                self.param_to_sparse_module[m.weight] = m
            if m.bias is not None:
                self.param_to_sparse_module[m.bias] = m

    def zero_grad(self):
        """Drop all gradients by setting every ``p.grad`` to ``None``."""
        for p in self.model.parameters():
            p.grad = None

    def _apply_update(self, param, grad, exp_avg, exp_avg_sq):
        """Update the moments and ``param`` in place for one tensor (or chunk view)."""
        beta1, beta2 = self.betas
        exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
        update = exp_avg / (torch.sqrt(exp_avg_sq) + self.eps)
        param.sub_(update, alpha=self.lr)

    # In-place updates of leaf parameters that require grad are rejected by
    # autograd unless gradient tracking is disabled.
    @torch.no_grad()
    def step(self):
        """Apply one optimization step to every parameter that has a gradient."""
        for p in self.model.parameters():
            if p.grad is None:
                continue

            state = self.state.get(p)
            if state is None:
                state = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
                self.state[p] = state
            exp_avg, exp_avg_sq = state["m"], state["v"]

            sparse_module = self.param_to_sparse_module.get(p)
            active_chunks = getattr(sparse_module, "active_chunks", None) if sparse_module is not None else None

            if active_chunks is None:
                # Dense update
                self._apply_update(p, p.grad, exp_avg, exp_avg_sq)
                continue

            # Sparse update: slicing dim 0 yields views, so the in-place ops
            # below write straight into the parameter and its moment buffers.
            for local_c in active_chunks.tolist():
                rows = slice(local_c * self.chunk_size, (local_c + 1) * self.chunk_size)
                self._apply_update(p[rows], p.grad[rows], exp_avg[rows], exp_avg_sq[rows])
| # ----------------------------- | |
| # Training | |
| # ----------------------------- | |
def main():
    """Train one model per backward mode and print a timing / val-loss table.

    For each of the three configurations (dense baseline, sparse dW with full
    dX, sparse dW with sparse dX) a fresh, identically seeded GPT is trained
    for ``--steps`` steps. The reported time covers only the steps after the
    warmup and annealing phases (when there are any), so that the sparse runs
    are timed at their target active fraction.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--n_layer", type=int, default=6)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=50)
    parser.add_argument("--anneal_steps", type=int, default=200)
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    corpus = ShakespeareCorpus(args.block_size, args.device)

    # (chunk-selection policy, backward mode) pairs to benchmark.
    modes = [
        ("dense_baseline", "dense_baseline"),
        ("predicted_magnitude", "sparse_dW_full_dX"),
        ("predicted_magnitude", "sparse_dW_sparse_dX"),
    ]

    print(f"\nModel: {args.n_layer} layers, {args.n_embd} d_model, {args.chunk_size} chunk_size")
    print(f"Batch: {args.batch_size}, Block: {args.block_size}. Target Active: {args.active_fraction*100}%")
    print(f"Annealing: {args.warmup_steps} warmup steps, {args.anneal_steps} anneal steps.\n")
    print(f"{'Run':>20s} | {'Time (s)':>10s} | {'Step (ms)':>10s} | {'Val Loss':>8s}")
    print("-" * 55)

    # First step at which the active fraction has reached its target.
    steady_state_step = args.warmup_steps + args.anneal_steps

    for policy, bwd_mode in modes:
        # Same seed for every mode so all runs start from identical weights.
        set_seed(42)
        model = GPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, 0.1).to(args.device)

        # The per-layer flags do not change during a run, so set them once
        # here instead of on every training step.
        use_sparse = policy != "dense_baseline"
        for m in get_sparse_linears(model):
            m.chunk_size = args.chunk_size
            m.sparse_enabled = use_sparse
            m.sparse_dx = bwd_mode == "sparse_dW_sparse_dX"
            m.active_chunks = None

        masker = ChunkMasker(model, policy, args.active_fraction, args.chunk_size, args.device) if use_sparse else None
        opt = ChunkedAdam(model, lr=5e-4, chunk_size=args.chunk_size)

        if args.benchmark_sync:
            sync_device(args.device)
        t0 = time.perf_counter()
        measured_steps = args.steps

        for step in range(args.steps):
            # Restart the clock once warmup/annealing are over so only
            # steady-state steps are timed.
            if step == steady_state_step:
                if args.benchmark_sync:
                    sync_device(args.device)
                t0 = time.perf_counter()
                measured_steps = args.steps - step

            # The batch depends only on the step index, so every mode sees
            # the same data sequence.
            x, y = corpus.get_batch("train", args.batch_size, generator=make_cpu_generator(step))

            if masker is not None:
                masker.choose_active(step, warmup_steps=args.warmup_steps, anneal_steps=args.anneal_steps)

            opt.zero_grad()
            _, loss = model(x, y)
            loss.backward()

            # The predictor reads the fresh gradients, so it must run before
            # the optimizer step.
            if masker is not None:
                masker.update_predictor()
            opt.step()

            # Print progress every 200 steps
            if step % 200 == 0:
                print(f" [Progress] {bwd_mode} step {step}/{args.steps} | Loss: {loss.item():.4f}", end="\r")

        if args.benchmark_sync:
            sync_device(args.device)
        t_elapsed = time.perf_counter() - t0

        # Eval loss on one fixed validation batch
        model.eval()
        with torch.no_grad():
            val_x, val_y = corpus.get_batch("val", args.batch_size, generator=make_cpu_generator(999))
            _, val_loss = model(val_x, val_y)

        # Clear the progress line
        print(" " * 60, end="\r")
        bwd_str = bwd_mode if bwd_mode == "dense_baseline" else ("sparse_full_dX" if "full_dX" in bwd_mode else "sparse_sparse_dX")
        print(f"{bwd_str:>20s} | {t_elapsed:10.2f} | {1000*t_elapsed/max(1, measured_steps):10.2f} | {val_loss.item():8.4f}")
| if __name__ == "__main__": | |
| main() |