"""
Sparse Transformer v6: stable predicted-magnitude masks, no dense refresh by default.

This prototype is designed to test the next hypothesis after the spiral/MLP runs:

    The important gradient support is heavy-tailed and temporally stable enough
    that we can select active parameter blocks from history, freeze the rest,
    and still train a harder sequence model.

Key fixes versus v5
-------------------
1. Harder model: a small causal Transformer language model.
2. No periodic dense refresh by default: --warmup_steps 0.
3. The selector only learns from blocks it actually observes/updates.
4. Inactive Linear rows are truly frozen by MaskedAdam. This matters because
   ordinary Adam can still move parameters with zero gradients through momentum.
5. A true current-step oracle is included as an audit upper bound.
6. Random masks are included as a control.

Selector and optimizer details
------------------------------
* Each Linear output row (weight row plus its bias entry) is one "block".
  Unobserved blocks start with an optimistic predicted mass of 1.0, so every
  block gets tried early. A block's first observation replaces that prior;
  later observations are blended in with an EMA (--mass_beta).
* --unobserved_decay is applied only to blocks that were not observed.
* MaskedAdam uses Adam bias correction with a per-element step count. A row
  that has been active k times is therefore corrected as an Adam state of
  age k. Pass --no_bias_correction to disable it.

Important limitation
--------------------
This still calls loss.backward(), so PyTorch computes dense gradients. Those
full current gradients are used for audit metrics and for the oracle run only.
The practical predicted_magnitude selector is not allowed to update its
statistics from inactive full gradients. An actual speedup would require a
structured partial backward or custom kernels.

Run
---
python3 sparse_transformer_v6.py --quick
python3 sparse_transformer_v6.py --steps 1000 --active_fractions 0.10 0.05 0.02
python3 sparse_transformer_v6.py --text_path input.txt --steps 2000
"""

from __future__ import annotations

import argparse
import math
import random
from typing import Dict, List, Literal, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# Restrict PyTorch CPU intra-op parallelism to a single thread.
torch.set_num_threads(1)

Policy = Literal["predicted_magnitude", "oracle_current", "random"]


def set_seed(seed: int) -> None:
    """Seed Python's and PyTorch's RNGs (and every CUDA device when available)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def device() -> str:
    """Return ``"cuda"`` when a CUDA device is available, otherwise ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"


# -----------------------------
# Data
# -----------------------------


def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Generate a templated toy corpus with one sentence per line.

    Each sentence uses one of six templates, chosen uniformly. The output is
    deterministic for a given ``seed`` because it uses a private
    ``random.Random`` instance rather than the global RNG.
    """
    rng = random.Random(seed)
    names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
    verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
    objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
    adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
    clauses = [
        "when the loss falls",
        "after the mask shifts",
        "before the model answers",
        "while the signal drifts",
        "if the pattern repeats",
        "because the tail is noisy",
    ]
    symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]

    lines: List[str] = []
    for _ in range(n_sentences):
        t = rng.randrange(6)
        if t == 0:
            line = f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}."
        elif t == 1:
            line = f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}."
        elif t == 2:
            a, b = rng.sample(symbols, 2)
            line = f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}."
        elif t == 3:
            line = f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}."
        elif t == 4:
            seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7)))
            line = f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}."
        else:
            line = f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait."
        lines.append(line)
    return "\n".join(lines) + "\n"


class CharCorpus:
    """Character-level corpus: vocabulary maps plus a 90/10 train/val split of token ids."""

    def __init__(self, text: str, block_size: int, device: str):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.device = device
        data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split = int(0.9 * len(data))
        self.train_data = data[:split]
        self.val_data = data[split:]

    def get_batch(self, split: str, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` random windows from the chosen split.

        Any ``split`` other than ``"train"`` reads the validation data.

        Returns:
            ``(x, y)``, each of shape ``(batch_size, block_size)`` on
            ``self.device``, where ``y`` is ``x`` shifted by one token.

        Raises:
            ValueError: if the split holds fewer than ``block_size + 1`` tokens.
        """
        data = self.train_data if split == "train" else self.val_data
        # Largest start i whose target window data[i + 1 : i + block_size + 1] still fits.
        max_start = len(data) - self.block_size - 1
        if max_start < 0:
            raise ValueError("Corpus too small for block_size")
        # randint's upper bound is exclusive; +1 makes max_start itself reachable.
        ix = torch.randint(max_start + 1, (batch_size,))
        x = torch.stack([data[i : i + self.block_size] for i in ix])
        y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix])
        return x.to(self.device), y.to(self.device)


def load_text(args: argparse.Namespace) -> str:
    """Read ``args.text_path`` as UTF-8, or generate the synthetic corpus when it is unset."""
    if args.text_path:
        with open(args.text_path, "r", encoding="utf-8") as f:
            return f.read()
    return make_synthetic_corpus(args.synthetic_sentences, args.seed)


# -----------------------------
# Mini GPT
# -----------------------------


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear(4x expansion) -> GELU -> Linear -> Dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))


class Block(nn.Module):
    """Pre-LayerNorm Transformer block with residual attention and MLP sublayers."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class MiniGPT(nn.Module):
    """Small decoder-only language model with learned token and position embeddings."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return ``(logits, loss)``; ``loss`` is None unless ``targets`` is given.

        Raises:
            ValueError: if the sequence is longer than ``block_size`` (there is
                no position embedding for it).
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {self.block_size}")
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


def named_linear_modules(model: nn.Module) -> List[Tuple[str, nn.Linear]]:
    """Return every ``nn.Linear`` in ``model`` (including ``model`` itself), in module order."""
    return [(name, m) for name, m in model.named_modules() if isinstance(m, nn.Linear)]


# -----------------------------
# Mask selector
# -----------------------------


class RowMasker:
    """Select which nn.Linear output rows ("blocks") are trainable at each step.

    Each row of every Linear weight, together with its bias entry, is one block.
    Blocks get global ids in ``named_linear_modules`` order.

    Policies:
        predicted_magnitude: top-k by ``predicted_mass`` (learned only from
            blocks that were active), plus a random exploration budget.
        oracle_current: top-k by the current step's full gradient (audit only).
        random: k uniformly random blocks.
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        active_fraction: float,
        explore_fraction: float,
        mass_beta: float,
        unobserved_decay: float,
        warmup_steps: int,
        device: str,
    ):
        self.model = model
        self.policy = policy
        self.active_fraction = active_fraction
        self.explore_fraction = explore_fraction
        self.mass_beta = mass_beta
        self.unobserved_decay = unobserved_decay
        self.warmup_steps = warmup_steps
        self.device = device

        self.linear_modules = [m for _, m in named_linear_modules(model)]
        self.module_to_ids: Dict[nn.Linear, torch.Tensor] = {}
        offset = 0
        for m in self.linear_modules:
            n = m.weight.shape[0]
            self.module_to_ids[m] = torch.arange(offset, offset + n, device=device)
            offset += n
        self.n_blocks = offset

        # Optimistic prior for never-observed blocks; replaced on first observation.
        self.predicted_mass = torch.ones(self.n_blocks, device=device)
        # Number of steps on which each block was active (i.e. observed).
        self.observed_count = torch.zeros(self.n_blocks, dtype=torch.long, device=device)
        self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.row_masks: Dict[nn.Linear, torch.Tensor] = {
            m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
        }

    def _k_for(self, fraction: float, n: int) -> int:
        """Number of blocks to keep: ``fraction * n`` rounded down, clamped to ``[1, n]`` (0 when n == 0)."""
        return min(n, max(1, int(fraction * n)))

    def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
        """Boolean mask selecting the largest ``fraction`` of ``values``."""
        k = self._k_for(fraction, values.numel())
        mask = torch.zeros_like(values, dtype=torch.bool)
        mask[torch.topk(values, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        """Jaccard similarity of two boolean masks (0.0 when both are empty)."""
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    def _set_active(self, active: torch.Tensor) -> None:
        """Store the global block mask and split it into per-module row masks."""
        self.active = active
        self.row_masks = {}
        for m, ids in self.module_to_ids.items():
            self.row_masks[m] = active[ids]

    def choose_pre_backward(self, step: int) -> None:
        """Choose the active blocks for this step before gradients are computed.

        During warmup every block is active. ``oracle_current`` selects nothing
        here; its mask is set in ``audit_and_update`` once gradients exist.

        Raises:
            ValueError: for an unknown policy.
        """
        if step < self.warmup_steps:
            self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
            return
        if self.policy == "oracle_current":
            # Cannot select until after current gradients are known.
            self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        k_total = self._k_for(self.active_fraction, self.n_blocks)
        if self.policy == "random":
            active = torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device)
            active[torch.randperm(self.n_blocks, device=self.device)[:k_total]] = True
            self._set_active(active)
            return
        if self.policy != "predicted_magnitude":
            raise ValueError(f"Unknown policy: {self.policy}")

        k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
        k_exploit = k_total - k_explore
        active = torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device)
        # Tiny noise breaks ties between equal predictions randomly.
        scores = self.predicted_mass + 1e-8 * torch.rand_like(self.predicted_mass)
        if k_exploit > 0:
            active[torch.topk(scores, k=k_exploit).indices] = True
        if k_explore > 0:
            remaining = torch.nonzero(~active, as_tuple=False).flatten()
            active[remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]] = True
        self._set_active(active)

    @torch.no_grad()
    def current_gradient_mass(self) -> torch.Tensor:
        """Per-block L2 norm of the current gradient (weight row plus bias entry).

        Blocks of modules whose weight has no gradient are reported as 0.
        """
        mass = torch.zeros(self.n_blocks, device=self.device)
        for m, ids in self.module_to_ids.items():
            if m.weight.grad is None:
                continue
            row_sq = m.weight.grad.square().sum(dim=1)
            if m.bias is not None and m.bias.grad is not None:
                row_sq = row_sq + m.bias.grad.square()
            mass[ids] = torch.sqrt(row_sq + 1e-30)
        return mass

    @torch.no_grad()
    def audit_and_update(self, step: int) -> Dict[str, float]:
        """Audit the active mask against the full gradient, then update selector statistics.

        Call after ``loss.backward()`` and before the optimizer step. During
        warmup and for ``oracle_current``, this is where the mask used by the
        optimizer is set.

        Returns per-step metrics:
            cosine / norm_ratio: ||g_active|| / ||g|| (equal here, since the
                masked gradient is an orthogonal projection of the full one).
            top20_mass: share of summed row-norm mass held by the top 20% of blocks.
            jacc_oracle: Jaccard overlap of the active set with the current top-k set.
            stability: Jaccard overlap with the previous step's active set.
            active_fraction_real: fraction of blocks actually active.
        """
        mass = self.current_gradient_mass()
        if step < self.warmup_steps:
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active

        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())
        # With zero inactive blocks and active blocks using true gradient, cosine == norm ratio.
        norm_ratio = cosine
        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()

        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        # Strict rule: do not update stats from inactive full gradients.
        observed = active
        blended = self.mass_beta * self.predicted_mass + (1.0 - self.mass_beta) * mass
        # A block's first observation replaces the optimistic prior. If it were
        # EMA-blended instead, the prior would keep mass_beta**n of the weight
        # after n observations (about 36% after 20 at mass_beta=0.95). When row
        # norms are far below 1, rows would then be ranked mostly by how rarely
        # they had been observed, not by their gradient magnitude.
        observed_value = torch.where(self.observed_count > 0, blended, mass)
        decayed = self.predicted_mass * self.unobserved_decay
        self.predicted_mass.copy_(torch.where(observed, observed_value, decayed))
        self.observed_count[observed] += 1

        return {
            "cosine": cosine,
            "norm_ratio": norm_ratio,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
        }

    def row_mask_for(self, module: nn.Linear) -> Optional[torch.Tensor]:
        """Current boolean row mask for ``module``, or None if it is not tracked."""
        return self.row_masks.get(module)


# -----------------------------
# Masked Adam
# -----------------------------


class MaskedAdam:
    """Adam that updates only the Linear rows selected by a RowMasker.

    For a Linear weight or bias with a row mask, inactive rows stay completely
    untouched: parameter, first moment, second moment, and step count. Every
    other parameter gets a dense update. This covers embeddings and LayerNorms,
    and all parameters when ``masker`` is None.

    Weight decay is L2-style: it is added to the gradient before the moment
    updates. Bias correction uses a per-element step count, so a row that has
    been active k times is corrected as an Adam state of age k no matter how
    many optimizer steps have elapsed. ``bias_correction=False`` gives the
    update without bias correction.
    """

    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        bias_correction: bool = True,
    ):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.bias_correction = bias_correction
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        self.linear_param: Dict[nn.Parameter, Tuple[nn.Linear, str]] = {}
        for _, m in named_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        """Drop every parameter gradient (set to None)."""
        for p in self.model.parameters():
            p.grad = None

    def _element_mask(self, p: torch.Tensor) -> Optional[torch.Tensor]:
        """Row mask for ``p`` broadcast to ``p``'s shape, or None for a dense update."""
        if self.masker is None or p not in self.linear_param:
            return None
        module, kind = self.linear_param[p]
        base = self.masker.row_mask_for(module)
        if base is None:
            return None
        if kind == "weight":
            base = base.view(-1, *([1] * (p.ndim - 1)))
        return base.expand_as(p)

    def _direction(self, m: torch.Tensor, v: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Adam step direction from moments ``m``, ``v`` with per-element step counts ``t`` (>= 1)."""
        if self.bias_correction:
            m = m / (1.0 - self.beta1 ** t)
            v = v / (1.0 - self.beta2 ** t)
        return m / (torch.sqrt(v) + self.eps)

    @torch.no_grad()
    def step(self) -> None:
        """Apply one (masked) Adam update to every parameter that has a gradient."""
        b1, b2 = self.beta1, self.beta2
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p), "t": torch.zeros_like(p)}
            state = self.state[p]
            m, v, t = state["m"], state["v"], state["t"]

            g = p.grad
            if self.weight_decay:
                g = g.add(p, alpha=self.weight_decay)

            mask = self._element_mask(p)
            if mask is None:
                m.mul_(b1).add_(g, alpha=1.0 - b1)
                v.mul_(b2).addcmul_(g, g, value=1.0 - b2)
                t.add_(1.0)
                p.add_(self._direction(m, v, t), alpha=-self.lr)
                continue

            if not bool(mask.any().item()):
                continue
            # Update only the selected elements, so that frozen rows keep their
            # parameters AND their Adam state (m, v, t) exactly unchanged.
            g_sel = g[mask]
            m_sel = b1 * m[mask] + (1.0 - b1) * g_sel
            v_sel = b2 * v[mask] + (1.0 - b2) * g_sel * g_sel
            t_sel = t[mask] + 1.0
            m[mask] = m_sel
            v[mask] = v_sel
            t[mask] = t_sel
            p[mask] = p[mask] - self.lr * self._direction(m_sel, v_sel, t_sel)


# -----------------------------
# Training
# -----------------------------


@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int) -> Dict[str, float]:
    """Mean loss over ``eval_iters`` random batches for the "train" and "val" splits.

    The model runs in eval mode and is returned to its previous mode afterwards.
    The global CPU RNG is forked, so evaluating (e.g. under --verbose) does not
    change which training batches or random masks are drawn later.

    Raises:
        ValueError: if ``eval_iters`` is less than 1.
    """
    if eval_iters < 1:
        raise ValueError("eval_iters must be >= 1")
    was_training = model.training
    model.eval()
    out = {}
    with torch.random.fork_rng(devices=[]):
        for split in ["train", "val"]:
            losses = []
            for _ in range(eval_iters):
                x, y = corpus.get_batch(split, batch_size)
                _, loss = model(x, y)
                losses.append(float(loss.item()))
            out[split] = sum(losses) / len(losses)
    model.train(was_training)
    return out


def train_run(corpus: CharCorpus, args: argparse.Namespace, policy: Optional[Policy], active_fraction: float, seed_offset: int) -> Dict[str, float | str]:
    """Train one fresh MiniGPT and return its final losses and mean mask metrics.

    ``policy=None`` trains a dense baseline. Otherwise a RowMasker with the
    given policy restricts updates, and the audit metrics are averaged over
    post-warmup steps.
    """
    set_seed(args.seed + seed_offset)
    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=args.explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=args.warmup_steps,
            device=dev,
        )
    opt = MaskedAdam(
        model,
        masker,
        lr=args.lr,
        weight_decay=args.weight_decay,
        bias_correction=not getattr(args, "no_bias_correction", False),
    )

    sums = {"cosine": 0.0, "norm_ratio": 0.0, "top20_mass": 0.0, "jacc_oracle": 0.0, "stability": 0.0, "active_fraction_real": 0.0}
    count = 0
    for step in range(args.steps):
        x, y = corpus.get_batch("train", args.batch_size)
        if masker is not None:
            masker.choose_pre_backward(step)
        _, loss = model(x, y)
        opt.zero_grad()
        loss.backward()
        if masker is not None:
            metrics = masker.audit_and_update(step)
            if step >= args.warmup_steps:
                for k in sums:
                    sums[k] += metrics[k]
                count += 1
        opt.step()
        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
            name = "dense" if policy is None else policy
            print(f"{name:20s} step={step:5d} train={losses['train']:.4f} val={losses['val']:.4f}")

    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "target_active": 1.0 if policy is None else active_fraction,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None or count == 0:
        row.update({"cosine": float("nan"), "norm_ratio": float("nan"), "top20_mass": float("nan"), "jacc_oracle": float("nan"), "stability": float("nan"), "active_fraction_real": 1.0})
    else:
        for k, v in sums.items():
            row[k] = v / count
    return row


def print_summary(rows: List[Dict[str, float | str]]) -> None:
    """Print one aligned table row per run (the "cos" column is the norm ratio)."""
    print("\nSummary")
    header = f"{'run':>22s} {'target':>7s} {'actual':>7s} {'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} {'stable':>7s}"
    print(header)
    print("-" * len(header))
    for r in rows:
        print(
            f"{str(r['run']):>22s} "
            f"{float(r['target_active']):7.3f} "
            f"{float(r['active_fraction_real']):7.3f} "
            f"{float(r['val_loss']):8.4f} "
            f"{float(r['train_loss']):8.4f} "
            f"{float(r['cosine']):7.3f} "
            f"{float(r['top20_mass']):7.3f} "
            f"{float(r['jacc_oracle']):7.3f} "
            f"{float(r['stability']):7.3f}"
        )


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the experiment."""
    p = argparse.ArgumentParser()
    p.add_argument("--text_path", type=str, default=None)
    p.add_argument("--synthetic_sentences", type=int, default=12000)
    p.add_argument("--steps", type=int, default=1000)
    p.add_argument("--quick", action="store_true")
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--block_size", type=int, default=64)
    p.add_argument("--n_layer", type=int, default=2)
    p.add_argument("--n_head", type=int, default=4)
    p.add_argument("--n_embd", type=int, default=64)
    p.add_argument("--dropout", type=float, default=0.0)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--weight_decay", type=float, default=0.0)
    p.add_argument("--no_bias_correction", action="store_true", help="Disable Adam bias correction in MaskedAdam.")
    p.add_argument("--active_fractions", type=float, nargs="+", default=[0.10, 0.05, 0.02])
    p.add_argument("--explore_fraction", type=float, default=0.10)
    p.add_argument("--mass_beta", type=float, default=0.95)
    p.add_argument("--unobserved_decay", type=float, default=0.999)
    p.add_argument("--warmup_steps", type=int, default=0)
    p.add_argument("--eval_interval", type=int, default=200)
    p.add_argument("--eval_iters", type=int, default=20)
    p.add_argument("--seed", type=int, default=7)
    p.add_argument("--verbose", action="store_true")
    return p.parse_args()


def main() -> None:
    """Run the dense baseline, then every policy at every active fraction, and print a summary."""
    args = parse_args()
    if args.quick:
        args.steps = 60
        args.eval_iters = 3
        args.batch_size = 16
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 2000
        args.active_fractions = [0.10, 0.02]

    set_seed(args.seed)
    dev = device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"warmup_steps={args.warmup_steps} explore_fraction={args.explore_fraction}")

    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(train_run(corpus, args, policy=None, active_fraction=1.0, seed_offset=0))

    seed_offset = 100
    for af in args.active_fractions:
        for policy in ["oracle_current", "predicted_magnitude", "random"]:
            print(f"\nRunning policy={policy}, active_fraction={af:.3f}")
            rows.append(train_run(corpus, args, policy=policy, active_fraction=af, seed_offset=seed_offset))
            seed_offset += 1

    print_summary(rows)
    print("\nNotes")
    print("  oracle_current uses the current full gradient to choose rows; it is an upper bound, not a practical selector.")
    print("  predicted_magnitude chooses from EMA mass only, plus a small random exploration budget.")
    print("  EMA mass is updated only for active/observed rows, not all rows.")
    print("  inactive Linear rows are frozen by MaskedAdam, including Adam state; zero grad alone is not enough.")
    print("  dense gradients are still computed for audit, so this is not a wall-clock speed benchmark yet.")


if __name__ == "__main__":
    main()