""" Sparse Transformer v8: from masked-optimizer simulation to real sparse Linear backward. v7 showed that Transformer Linear-row gradient support is heavy-tailed and stable, and that a practical EMA selector can nearly match an oracle selector after a tiny warmup. But v7 still computed dense gradients and only masked the optimizer step. v8 tests the next question: Can the sparse row mask be moved into the Linear backward pass itself? Backward modes -------------- 1. masked_optimizer v7-style control. Compute dense backward, but MaskedAdam only updates active Linear rows. This should match the previous simulation behavior. 2. sparse_dW_full_dX Custom autograd Linear computes grad_weight / grad_bias only for active output rows, while still propagating full grad_input backward. This is the conservative real-backward mode. It targets the dW part of Linear backward only. 3. sparse_dW_sparse_dX Custom autograd Linear computes grad_weight only for active rows and also propagates grad_input only through active output rows. This is the aggressive mode. It may save more backward compute in a real kernel, but it can damage upstream learning. Important caveat ---------------- This script still performs a dense audit backward pass each training step to: - compute oracle metrics, - support oracle_current and stale_current controls, - update practical EMA statistics only for active/observed rows. The actual training update in sparse_dW_* modes comes from the custom sparse backward pass, not from the dense audit gradients. This is a correctness and semantics experiment, not a wall-clock benchmark. Example ------- Smoke test: python3 sparse_transformer_v8.py --quick Main comparison: python3 sparse_transformer_v8.py \ --steps 2000 \ --active_fractions 0.05 0.02 \ --warmup_steps_list 5 \ --explore_fractions 0.00 \ --policies oracle_current predicted_magnitude random \ --backward_modes masked_optimizer sparse_dW_full_dX sparse_dW_sparse_dX """ from __future__ import annotations import argparse import math import random from typing import Dict, List, Literal, Optional, Tuple import torch torch.set_num_threads(1) import torch.nn as nn import torch.nn.functional as F Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"] BackwardMode = Literal["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"] # ----------------------------- # Reproducibility and device # ----------------------------- def set_seed(seed: int) -> None: random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def device() -> str: return "cuda" if torch.cuda.is_available() else "cpu" def make_cpu_generator(seed: int) -> torch.Generator: gen = torch.Generator(device="cpu") gen.manual_seed(seed) return gen # ----------------------------- # Data # ----------------------------- def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str: rng = random.Random(seed) names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"] verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"] objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"] adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"] clauses = [ "when the loss falls", "after the mask shifts", "before the model answers", "while the signal drifts", "if the pattern repeats", "because the tail is noisy", ] symbols = ["alpha", "beta", "gamma", 
"delta", "omega", "sigma"] lines: List[str] = [] for _ in range(n_sentences): t = rng.randrange(6) if t == 0: line = f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}." elif t == 1: line = f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}." elif t == 2: a, b = rng.sample(symbols, 2) line = f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}." elif t == 3: line = f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}." elif t == 4: seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7))) line = f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}." else: line = f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait." lines.append(line) return "\n".join(lines) + "\n" class CharCorpus: def __init__(self, text: str, block_size: int, device: str): chars = sorted(set(text)) self.stoi = {ch: i for i, ch in enumerate(chars)} self.itos = {i: ch for ch, i in self.stoi.items()} self.vocab_size = len(chars) self.block_size = block_size self.device = device data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long) split = int(0.9 * len(data)) self.train_data = data[:split] self.val_data = data[split:] def get_batch( self, split: str, batch_size: int, generator: Optional[torch.Generator] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: data = self.train_data if split == "train" else self.val_data max_start = len(data) - self.block_size - 1 if max_start <= 0: raise ValueError("Corpus too small for block_size") ix = torch.randint(max_start, (batch_size,), generator=generator) x = torch.stack([data[i : i + self.block_size] for i in ix]) y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix]) return x.to(self.device), y.to(self.device) def load_text(args: argparse.Namespace) -> str: if args.text_path: with open(args.text_path, "r", encoding="utf-8") as f: return f.read() return make_synthetic_corpus(args.synthetic_sentences, args.seed) # ----------------------------- # Sparse Linear autograd # ----------------------------- class MaskedLinearFunction(torch.autograd.Function): @staticmethod def forward( # type: ignore[override] ctx, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], active_rows: torch.Tensor, sparse_dx: bool, ) -> torch.Tensor: ctx.save_for_backward(x, weight, active_rows) ctx.has_bias = bias is not None ctx.sparse_dx = bool(sparse_dx) return F.linear(x, weight, bias) @staticmethod def backward(ctx, grad_y: torch.Tensor): # type: ignore[override] x, weight, active_rows = ctx.saved_tensors sparse_dx = bool(ctx.sparse_dx) has_bias = bool(ctx.has_bias) x_shape = x.shape x_flat = x.reshape(-1, x.shape[-1]) gy_flat = grad_y.reshape(-1, grad_y.shape[-1]) active_idx = torch.nonzero(active_rows, as_tuple=False).flatten() grad_weight = torch.zeros_like(weight) grad_bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if has_bias else None if active_idx.numel() > 0: gy_active = gy_flat[:, active_idx] grad_weight[active_idx] = gy_active.transpose(0, 1) @ x_flat if grad_bias is not None: grad_bias[active_idx] = gy_active.sum(dim=0) if sparse_dx: grad_x_flat = gy_active @ weight[active_idx] else: grad_x_flat = gy_flat @ weight else: # This can happen when a global top-k mask selects no rows from a # particular layer. 
    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):  # type: ignore[override]
        x, weight, active_rows = ctx.saved_tensors
        sparse_dx = bool(ctx.sparse_dx)
        has_bias = bool(ctx.has_bias)
        x_shape = x.shape
        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])
        active_idx = torch.nonzero(active_rows, as_tuple=False).flatten()
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if has_bias else None
        if active_idx.numel() > 0:
            gy_active = gy_flat[:, active_idx]
            grad_weight[active_idx] = gy_active.transpose(0, 1) @ x_flat
            if grad_bias is not None:
                grad_bias[active_idx] = gy_active.sum(dim=0)
            if sparse_dx:
                grad_x_flat = gy_active @ weight[active_idx]
            else:
                grad_x_flat = gy_flat @ weight
        else:
            # This can happen when a global top-k mask selects no rows from a
            # particular layer. Conservative full_dX still propagates through
            # that layer; aggressive sparse_dX cuts it off for that layer.
            if sparse_dx:
                grad_x_flat = torch.zeros_like(x_flat)
            else:
                grad_x_flat = gy_flat @ weight
        grad_x = grad_x_flat.reshape(x_shape)
        return grad_x, grad_weight, grad_bias, None, None


class SparseLinear(nn.Linear):
    """nn.Linear with an optional row-sparse backward pass."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_rows: Optional[torch.Tensor] = None

    def set_sparse_backward(self, enabled: bool, active_rows: Optional[torch.Tensor], sparse_dx: bool) -> None:
        self.sparse_enabled = bool(enabled)
        self.sparse_dx = bool(sparse_dx)
        self.active_rows = active_rows

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.sparse_enabled or self.active_rows is None:
            return F.linear(x, self.weight, self.bias)
        return MaskedLinearFunction.apply(x, self.weight, self.bias, self.active_rows, self.sparse_dx)


# -----------------------------
# Mini GPT
# -----------------------------
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


class FeedForward(nn.Module):
    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))


class Block(nn.Module):
    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class MiniGPT(nn.Module):
    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = SparseLinear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

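
# Every Linear in the model (attention c_attn / c_proj, MLP c_fc / c_proj, and
# lm_head) is a SparseLinear, so the RowMasker below can address all Linear
# output rows through one global index space (see module_to_ids).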
def named_sparse_linear_modules(model: nn.Module) -> List[Tuple[str, SparseLinear]]:
    return [(name, m) for name, m in model.named_modules() if isinstance(m, SparseLinear)]


def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]:
    total = sum(p.numel() for p in model.parameters())
    linear = 0
    for _, m in named_sparse_linear_modules(model):
        linear += m.weight.numel()
        if m.bias is not None:
            linear += m.bias.numel()
    return total, linear, linear / max(1, total)


def configure_sparse_linears(
    model: nn.Module,
    masker: Optional["RowMasker"],
    enabled: bool,
    backward_mode: Optional[str],
) -> None:
    sparse_dx = backward_mode == "sparse_dW_sparse_dX"
    for _, m in named_sparse_linear_modules(model):
        active = masker.row_mask_for(m) if masker is not None else None
        m.set_sparse_backward(enabled=enabled, active_rows=active, sparse_dx=sparse_dx)


# -----------------------------
# Mask selector
# -----------------------------
class RowMasker:
    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        active_fraction: float,
        explore_fraction: float,
        mass_beta: float,
        unobserved_decay: float,
        warmup_steps: int,
        ucb_alpha: float,
        mass_init: float,
        device: str,
    ):
        self.model = model
        self.policy = policy
        self.active_fraction = active_fraction
        self.explore_fraction = explore_fraction
        self.mass_beta = mass_beta
        self.unobserved_decay = unobserved_decay
        self.warmup_steps = warmup_steps
        self.ucb_alpha = ucb_alpha
        self.mass_init = mass_init
        self.device = device
        self.step_index = 0
        self.linear_modules = [m for _, m in named_sparse_linear_modules(model)]
        self.module_to_ids: Dict[SparseLinear, torch.Tensor] = {}
        ids = []
        offset = 0
        for m in self.linear_modules:
            n = m.weight.shape[0]
            block_ids = torch.arange(offset, offset + n, device=device)
            self.module_to_ids[m] = block_ids
            ids.append(block_ids)
            offset += n
        self.n_blocks = offset
        self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.observed_count = torch.zeros(self.n_blocks, device=device)
        self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device)
        self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.row_masks: Dict[SparseLinear, torch.Tensor] = {
            m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
        }

    def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
        k = max(1, int(fraction * values.numel()))
        mask = torch.zeros_like(values, dtype=torch.bool)
        noisy = values + 1e-9 * torch.rand_like(values)
        mask[torch.topk(noisy, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    def _set_active(self, active: torch.Tensor) -> None:
        self.active = active
        self.row_masks = {}
        for m, ids in self.module_to_ids.items():
            self.row_masks[m] = active[ids]
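    # Budget split used by the score-based policies: k_total = active_fraction * n_blocks
    # rows per step, with an explore_fraction share of that budget drawn uniformly
    # from rows not already picked by the top-k exploit step.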
    def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor:
        n = self.n_blocks
        k_total = max(1, int(self.active_fraction * n))
        k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
        k_exploit = k_total - k_explore
        active = torch.zeros(n, dtype=torch.bool, device=self.device)
        if k_exploit > 0:
            active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True
        if k_explore > 0:
            remaining = torch.nonzero(~active, as_tuple=False).flatten()
            pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]
            active[pick] = True
        return active

    def choose_pre_backward(self, step: int) -> None:
        self.step_index = step
        if step < self.warmup_steps:
            self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
            return
        if self.policy == "oracle_current":
            # Oracle cannot choose until the dense audit gradient is known.
            self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
            return
        if self.policy == "random":
            self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device)))
            return
        if self.policy == "stale_current":
            self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction))
            return
        if self.policy == "predicted_magnitude":
            self._set_active(self._sample_exploit_explore(self.predicted_mass))
            return
        if self.policy == "ucb_magnitude":
            t = max(1, step - self.warmup_steps + 1)
            log_term = torch.log(torch.tensor(float(t + 2), device=self.device))
            bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8)
            bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0))
            self._set_active(self._sample_exploit_explore(self.predicted_mass + bonus))
            return
        raise ValueError(f"Unknown policy: {self.policy}")

    @torch.no_grad()
    def current_gradient_mass_from_grads(self) -> torch.Tensor:
        mass = torch.zeros(self.n_blocks, device=self.device)
        for m, ids in self.module_to_ids.items():
            if m.weight.grad is None:
                continue
            row_sq = m.weight.grad.square().sum(dim=1)
            if m.bias is not None and m.bias.grad is not None:
                row_sq = row_sq + m.bias.grad.square()
            mass[ids] = torch.sqrt(row_sq + 1e-30)
        return mass

    @torch.no_grad()
    def audit_and_update_from_mass(self, step: int, mass: torch.Tensor) -> Dict[str, float]:
        if step < self.warmup_steps:
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active
        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())
        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()
        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())
        new_active = active & (self.observed_count == 0)
        # Practical rule: update predicted statistics only for active/observed rows.
        self.predicted_mass.mul_(self.unobserved_decay)
        observed = active
        if bool(observed.any().item()):
            obs_mass = mass[observed]
            first_seen = self.observed_count[observed] == 0
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0
            self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()
        # Dense audit signal; only stale_current is allowed to use this for selection.
        self.last_full_mass = mass.detach().clone()
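        # The masked gradient and the dropped remainder live on disjoint rows, so
        # the cosine between masked and dense gradients equals the retained-norm
        # fraction; "cosine" and "norm_ratio" below therefore carry the same value.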
        return {
            "cosine": cosine,
            "norm_ratio": cosine,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
            "coverage": float((self.observed_count > 0).float().mean().item()),
            "avg_obs_count": float(self.observed_count.mean().item()),
            "new_active_fraction": float(new_active.float().mean().item()),
        }

    def row_mask_for(self, module: SparseLinear) -> Optional[torch.Tensor]:
        return self.row_masks.get(module)


# -----------------------------
# Masked Adam
# -----------------------------
class MaskedAdam:
    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        freeze_non_linear_when_sparse: bool = False,
    ):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        self.linear_param: Dict[nn.Parameter, Tuple[SparseLinear, str]] = {}
        for _, m in named_sparse_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self) -> None:
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param:
                continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                g = g.add(p, alpha=self.weight_decay)
            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base
            if row_mask is None:
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                mask = row_mask.expand_as(p)
                if not bool(mask.any().item()):
                    continue
                new_m = self.beta1 * m + (1.0 - self.beta1) * g
                new_v = self.beta2 * v + (1.0 - self.beta2) * g * g
                m[mask] = new_m[mask]
                v[mask] = new_v[mask]
                update = m / (torch.sqrt(v) + self.eps)
                p[mask] = p[mask] - self.lr * update[mask]

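
# MaskedAdam only touches moments and weights for active rows, so inactive rows
# keep stale m/v state until they are selected again. Per-row semantics, shown
# as an illustration only (no bias correction, matching the dense branch above):
#     m[i] = beta1 * m[i] + (1 - beta1) * g[i]        where row i is active
#     v[i] = beta2 * v[i] + (1 - beta2) * g[i] ** 2   where row i is active
#     p[i] -= lr * m[i] / (sqrt(v[i]) + eps)          where row i is active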

# -----------------------------
# Training utilities
# -----------------------------
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int, seed: int) -> Dict[str, float]:
    model.eval()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    out = {}
    for split in ["train", "val"]:
        losses = []
        gen = make_cpu_generator(seed + (0 if split == "train" else 100000))
        for _ in range(eval_iters):
            x, y = corpus.get_batch(split, batch_size, generator=gen)
            _, loss = model(x, y)
            losses.append(float(loss.item()))
        out[split] = sum(losses) / len(losses)
    model.train()
    return out


def dense_audit_pass(model: nn.Module, corpus_batch: Tuple[torch.Tensor, torch.Tensor], opt: MaskedAdam, masker: RowMasker) -> torch.Tensor:
    x, y = corpus_batch
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    opt.zero_grad()
    _, audit_loss = model(x, y)
    audit_loss.backward()
    mass = masker.current_gradient_mass_from_grads()
    opt.zero_grad()
    return mass


def sparse_training_backward(
    model: nn.Module,
    corpus_batch: Tuple[torch.Tensor, torch.Tensor],
    opt: MaskedAdam,
    masker: Optional[RowMasker],
    backward_mode: Optional[BackwardMode],
) -> float:
    x, y = corpus_batch
    opt.zero_grad()
    if masker is None or backward_mode is None or backward_mode == "masked_optimizer":
        configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    else:
        configure_sparse_linears(model, masker=masker, enabled=True, backward_mode=backward_mode)
    _, loss = model(x, y)
    loss.backward()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    return float(loss.item())


def train_run(
    corpus: CharCorpus,
    args: argparse.Namespace,
    policy: Optional[Policy],
    backward_mode: Optional[BackwardMode],
    active_fraction: float,
    warmup_steps: int,
    explore_fraction: float,
    seed_offset: int,
) -> Dict[str, float | str]:
    # Same model initialization and same minibatch sequence for every run by default.
    set_seed(args.seed + (seed_offset if args.unpaired_seeds else 0))
    data_gen = make_cpu_generator(args.seed + 12345)
    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=warmup_steps,
            ucb_alpha=args.ucb_alpha,
            mass_init=args.mass_init,
            device=dev,
        )
    opt = MaskedAdam(
        model,
        masker,
        lr=args.lr,
        weight_decay=args.weight_decay,
        freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse,
    )
    sums = {
        "cosine": 0.0,
        "norm_ratio": 0.0,
        "top20_mass": 0.0,
        "jacc_oracle": 0.0,
        "stability": 0.0,
        "active_fraction_real": 0.0,
        "coverage": 0.0,
        "avg_obs_count": 0.0,
        "new_active_fraction": 0.0,
    }
    count = 0
    for step in range(args.steps):
        batch = corpus.get_batch("train", args.batch_size, generator=data_gen)
        if masker is None:
            loss_value = sparse_training_backward(model, batch, opt, masker=None, backward_mode=None)
            opt.step()
        else:
            masker.choose_pre_backward(step)
            full_mass = dense_audit_pass(model, batch, opt, masker)
            metrics = masker.audit_and_update_from_mass(step, full_mass)
            if step >= warmup_steps:
                for k in sums:
                    sums[k] += metrics[k]
                count += 1
            loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode=backward_mode)
            opt.step()
        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 555)
            name = "dense" if policy is None else f"{policy}/{backward_mode}"
            print(
                f"{name:38s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} "
                f"loss={loss_value:.4f} train={losses['train']:.4f} val={losses['val']:.4f}"
            )
    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 999)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "mode": "dense" if backward_mode is None else backward_mode,
        "target_active": 1.0 if policy is None else active_fraction,
        "warmup": warmup_steps,
        "explore": explore_fraction if policy is not None else 0.0,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None or count == 0:
        row.update({
            "cosine": float("nan"),
            "norm_ratio": float("nan"),
            "top20_mass": float("nan"),
            "jacc_oracle": float("nan"),
            "stability": float("nan"),
            "active_fraction_real": 1.0,
            "coverage": float("nan"),
            "avg_obs_count": float("nan"),
"new_active_fraction": float("nan"), }) else: for k, v in sums.items(): row[k] = v / count return row def print_summary(rows: List[Dict[str, float | str]]) -> None: print("\nSummary") header = ( f"{'run':>22s} {'mode':>19s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} " f"{'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} " f"{'stable':>7s} {'cover':>7s} {'new':>7s}" ) print(header) print("-" * len(header)) for r in rows: print( f"{str(r['run']):>22s} " f"{str(r['mode']):>19s} " f"{float(r['target_active']):7.3f} " f"{float(r['active_fraction_real']):7.3f} " f"{int(float(r['warmup'])):5d} " f"{float(r['explore']):5.2f} " f"{float(r['val_loss']):8.4f} " f"{float(r['train_loss']):8.4f} " f"{float(r['cosine']):7.3f} " f"{float(r['top20_mass']):7.3f} " f"{float(r['jacc_oracle']):7.3f} " f"{float(r['stability']):7.3f} " f"{float(r['coverage']):7.3f} " f"{float(r['new_active_fraction']):7.3f}" ) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser() p.add_argument("--text_path", type=str, default=None) p.add_argument("--synthetic_sentences", type=int, default=12000) p.add_argument("--steps", type=int, default=1000) p.add_argument("--quick", action="store_true") p.add_argument("--batch_size", type=int, default=32) p.add_argument("--block_size", type=int, default=64) p.add_argument("--n_layer", type=int, default=2) p.add_argument("--n_head", type=int, default=4) p.add_argument("--n_embd", type=int, default=64) p.add_argument("--dropout", type=float, default=0.0) p.add_argument("--lr", type=float, default=3e-4) p.add_argument("--weight_decay", type=float, default=0.0) p.add_argument("--active_fractions", type=float, nargs="+", default=[0.05, 0.02]) p.add_argument("--policies", type=str, nargs="+", default=["oracle_current", "predicted_magnitude", "random"]) p.add_argument( "--backward_modes", type=str, nargs="+", default=["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"], ) p.add_argument("--explore_fractions", type=float, nargs="+", default=[0.0]) p.add_argument("--warmup_steps_list", type=int, nargs="+", default=[5]) p.add_argument("--mass_beta", type=float, default=0.95) p.add_argument("--unobserved_decay", type=float, default=1.0) p.add_argument("--mass_init", type=float, default=0.0) p.add_argument("--ucb_alpha", type=float, default=1.0) p.add_argument("--freeze_non_linear_when_sparse", action="store_true") p.add_argument("--eval_interval", type=int, default=200) p.add_argument("--eval_iters", type=int, default=20) p.add_argument("--seed", type=int, default=7) p.add_argument("--unpaired_seeds", action="store_true", help="Use different init seeds per run instead of paired seeds.") p.add_argument("--verbose", action="store_true") return p.parse_args() def main() -> None: args = parse_args() if args.quick: args.steps = 40 args.eval_iters = 2 args.batch_size = 8 args.block_size = 32 args.n_layer = 1 args.n_embd = 32 args.n_head = 4 args.synthetic_sentences = 1200 args.active_fractions = [0.05] args.policies = ["predicted_magnitude", "random"] args.backward_modes = ["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"] args.explore_fractions = [0.0] args.warmup_steps_list = [5] valid_policies = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"} valid_modes = {"masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"} for pol in args.policies: if pol not in valid_policies: raise ValueError(f"Unknown policy {pol!r}. 
    for mode in args.backward_modes:
        if mode not in valid_modes:
            raise ValueError(f"Unknown backward mode {mode!r}. Valid modes: {sorted(valid_modes)}")
    set_seed(args.seed)
    dev = device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"policies={args.policies}")
    print(f"backward_modes={args.backward_modes}")
    print(f"active_fractions={args.active_fractions}")
    print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}")
    print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}")
    print(f"paired_seeds={not args.unpaired_seeds}")
    tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    total_params, linear_params, linear_frac = parameter_fractions(tmp_model)
    del tmp_model
    print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}")
    if args.freeze_non_linear_when_sparse:
        print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs")
    else:
        print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely")
    if args.dropout != 0.0:
        print("warning: dropout is nonzero; dense audit and sparse training passes may see different dropout masks")
    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(
        train_run(
            corpus,
            args,
            policy=None,
            backward_mode=None,
            active_fraction=1.0,
            warmup_steps=0,
            explore_fraction=0.0,
            seed_offset=0,
        )
    )
    seed_offset = 100
    for mode in args.backward_modes:
        for af in args.active_fractions:
            for pol in args.policies:
                explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0]
                for warmup in args.warmup_steps_list:
                    for explore in explore_values:
                        print(
                            f"\nRunning mode={mode}, policy={pol}, "
                            f"active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}"
                        )
                        rows.append(
                            train_run(
                                corpus,
                                args,
                                policy=pol,  # type: ignore[arg-type]
                                backward_mode=mode,  # type: ignore[arg-type]
                                active_fraction=af,
                                warmup_steps=warmup,
                                explore_fraction=explore,
                                seed_offset=seed_offset,
                            )
                        )
                        seed_offset += 1
    print_summary(rows)
    print("\nNotes")
    print(" masked_optimizer is the v7-style dense-backward simulation control.")
    print(" sparse_dW_full_dX uses custom Linear backward: sparse weight/bias grads, full input gradient.")
    print(" sparse_dW_sparse_dX uses custom Linear backward: sparse weight/bias grads and sparse input gradient.")
    print(" oracle_current uses dense audit gradients to choose rows; it is an upper bound.")
    print(" predicted_magnitude uses EMA mass from active/observed rows only.")
    print(" random is the sparse-support control.")
    print(" dense audit gradients are still computed every step for metrics/control; this is not a speed benchmark.")
    print(" The key comparison is masked_optimizer vs sparse_dW_full_dX. If they match, the v7 effect survives real dW sparsification.")


if __name__ == "__main__":
    main()