""" Sparse Transformer v7: discovery stress tests for stable gradient-support masks. This version follows the v6 result: oracle_current works far better than random, so useful sparse support exists; predicted_magnitude without warmup does not reliably discover that support. v7 focuses on discovery mechanisms: 1. predicted_magnitude Exploit rows with the largest EMA-observed gradient mass. 2. ucb_magnitude A bandit-style selector: EMA mass + an uncertainty bonus for under-observed rows. This is meant to discover useful rows without dense refresh. First observation initializes EMA scale immediately. 3. stale_current A renamed diagnostic control: use the previous full-gradient mass. It is not practical because it relies on dense audit gradients, but it tells us whether one-step lag is too noisy. 4. oracle_current True current top-k by dense gradient mass. Upper bound only. 5. random Control. Important limitation -------------------- This still calls loss.backward(), so PyTorch computes dense gradients. Dense current gradients are used for audit metrics and for oracle/stale controls. The practical selectors only update their EMA statistics from active rows. Actual speedup would require structured partial backward/custom kernels. 
Example runs ------------ Smoke test: python3 sparse_transformer_v7.py --quick No-warmup discovery test: python3 sparse_transformer_v7.py --steps 1000 \ --active_fractions 0.10 0.05 0.02 \ --policies predicted_magnitude ucb_magnitude oracle_current random \ --warmup_steps_list 0 5 50 --explore_fractions 0.10 0.30 Warm-start separation test: python3 sparse_transformer_v7.py --steps 1000 \ --active_fractions 0.10 0.05 0.02 \ --policies predicted_magnitude ucb_magnitude oracle_current random \ --warmup_steps_list 0 5 50 200 --explore_fractions 0.10 """ from __future__ import annotations import argparse import math import random from typing import Dict, List, Literal, Optional, Tuple import torch torch.set_num_threads(1) import torch.nn as nn import torch.nn.functional as F Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"] def set_seed(seed: int) -> None: random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def device() -> str: return "cuda" if torch.cuda.is_available() else "cpu" # ----------------------------- # Data # ----------------------------- def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str: rng = random.Random(seed) names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"] verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"] objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"] adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"] clauses = [ "when the loss falls", "after the mask shifts", "before the model answers", "while the signal drifts", "if the pattern repeats", "because the tail is noisy", ] symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"] lines: List[str] = [] for _ in range(n_sentences): t = rng.randrange(6) if t == 0: line = f"{rng.choice(names)} 
{rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}." elif t == 1: line = f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}." elif t == 2: a, b = rng.sample(symbols, 2) line = f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}." elif t == 3: line = f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}." elif t == 4: seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7))) line = f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}." else: line = f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait." lines.append(line) return "\n".join(lines) + "\n" class CharCorpus: def __init__(self, text: str, block_size: int, device: str): chars = sorted(set(text)) self.stoi = {ch: i for i, ch in enumerate(chars)} self.itos = {i: ch for ch, i in self.stoi.items()} self.vocab_size = len(chars) self.block_size = block_size self.device = device data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long) split = int(0.9 * len(data)) self.train_data = data[:split] self.val_data = data[split:] def get_batch(self, split: str, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]: data = self.train_data if split == "train" else self.val_data max_start = len(data) - self.block_size - 1 if max_start <= 0: raise ValueError("Corpus too small for block_size") ix = torch.randint(max_start, (batch_size,)) x = torch.stack([data[i : i + self.block_size] for i in ix]) y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix]) return x.to(self.device), y.to(self.device) def load_text(args: argparse.Namespace) -> str: if args.text_path: with open(args.text_path, "r", encoding="utf-8") as f: return f.read() return make_synthetic_corpus(args.synthetic_sentences, args.seed) # ----------------------------- # Mini GPT # 
# -----------------------------


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention using an explicit lower-triangular mask buffer."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        if n_embd % n_head != 0:
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Shape (1, 1, block_size, block_size) so it broadcasts over batch and heads.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(C, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Block attention to future positions.
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear(4x expansion) -> GELU -> Linear -> Dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))


class Block(nn.Module):
    """Pre-LayerNorm transformer block with residual attention and MLP sublayers."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class MiniGPT(nn.Module):
    """Small GPT-style language model with learned token and position embeddings."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return ``(logits, loss)``.

        ``idx`` is indexed as ``(B, T)`` token ids; ``loss`` is the mean
        cross-entropy against ``targets`` or ``None`` when no targets are given.

        Raises:
            ValueError: If ``T`` exceeds ``block_size`` (no position embedding exists).
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {self.block_size}")
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


def named_linear_modules(model: nn.Module) -> List[Tuple[str, nn.Linear]]:
    """Return ``(name, module)`` for every ``nn.Linear`` in ``model``, in traversal order."""
    return [(name, m) for name, m in model.named_modules() if isinstance(m, nn.Linear)]


def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]:
    """Return ``(total_params, linear_params, linear_params / total_params)``.

    ``linear_params`` counts weights and biases of all ``nn.Linear`` modules,
    i.e. the parameters that row masks can govern.
    """
    total = sum(p.numel() for p in model.parameters())
    linear = 0
    for _, m in named_linear_modules(model):
        linear += m.weight.numel()
        if m.bias is not None:
            linear += m.bias.numel()
    return total, linear, linear / max(1, total)


# -----------------------------
# Mask selector
# -----------------------------


class RowMasker:
    """Chooses which output rows of every ``nn.Linear`` may be updated each step.

    Each row of each Linear weight (together with its bias entry) is one "block"
    with a global id in ``[0, n_blocks)``. Per step the training loop calls
    ``choose_pre_backward`` (select rows from information available before the
    current gradient) and then ``audit_and_update`` (measure the selection against
    the dense gradient and update the selector's statistics).

    For warmup steps (``step < warmup_steps``) every row is active.
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        active_fraction: float,
        explore_fraction: float,
        mass_beta: float,
        unobserved_decay: float,
        warmup_steps: int,
        ucb_alpha: float,
        mass_init: float,
        device: str,
    ):
        self.model = model
        self.policy = policy
        self.active_fraction = active_fraction
        self.explore_fraction = explore_fraction
        self.mass_beta = mass_beta
        self.unobserved_decay = unobserved_decay
        self.warmup_steps = warmup_steps
        self.ucb_alpha = ucb_alpha
        self.mass_init = mass_init
        self.device = device
        self.step_index = 0

        # Assign each Linear a contiguous slice of global row ids.
        self.linear_modules = [m for _, m in named_linear_modules(model)]
        self.module_to_ids: Dict[nn.Linear, torch.Tensor] = {}
        offset = 0
        for m in self.linear_modules:
            n = m.weight.shape[0]
            self.module_to_ids[m] = torch.arange(offset, offset + n, device=device)
            offset += n
        self.n_blocks = offset

        # EMA of observed row-gradient norms (only updated for active rows).
        self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device)
        # Previous step's dense row norms; only stale_current selects from it.
        self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device)
        # How many times each row has been active/observed.
        self.observed_count = torch.zeros(self.n_blocks, device=device)
        # EMA of mean observed row mass; sets the scale of the UCB bonus.
        self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device)
        # False until the first observation replaces the placeholder scale above.
        self.global_mass_initialized = False
        self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.row_masks: Dict[nn.Linear, torch.Tensor] = {
            m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
        }

    def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
        """Boolean mask of the ``max(1, int(fraction * n))`` largest entries of ``values``."""
        k = max(1, int(fraction * values.numel()))
        mask = torch.zeros_like(values, dtype=torch.bool)
        # Tie-breaking noise matters when many rows have identical initial scores.
        noisy = values + 1e-9 * torch.rand_like(values)
        mask[torch.topk(noisy, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        """Jaccard similarity of two boolean masks (0.0 when both are empty)."""
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    def _set_active(self, active: torch.Tensor) -> None:
        """Store the global active mask and slice it into per-module row masks."""
        self.active = active
        self.row_masks = {m: active[ids] for m, ids in self.module_to_ids.items()}

    def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor:
        """Pick top-scoring rows, then fill ``explore_fraction`` of the budget uniformly at random."""
        n = self.n_blocks
        k_total = max(1, int(self.active_fraction * n))
        k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
        k_exploit = k_total - k_explore
        active = torch.zeros(n, dtype=torch.bool, device=self.device)
        if k_exploit > 0:
            active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True
        if k_explore > 0:
            remaining = torch.nonzero(~active, as_tuple=False).flatten()
            pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]
            active[pick] = True
        return active

    def choose_pre_backward(self, step: int) -> None:
        """Select the active rows for ``step`` without looking at its gradients.

        ``oracle_current`` cannot select yet; it gets an empty mask here and is
        filled in by ``audit_and_update``.

        Raises:
            ValueError: For an unknown policy.
        """
        self.step_index = step
        if step < self.warmup_steps:
            self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "oracle_current":
            # Cannot select until after current gradients are known.
            self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "random":
            self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device)))
            return

        if self.policy == "stale_current":
            self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction))
            return

        if self.policy == "predicted_magnitude":
            self._set_active(self._sample_exploit_explore(self.predicted_mass))
            return

        if self.policy == "ucb_magnitude":
            t = max(1, step - self.warmup_steps + 1)
            log_term = torch.log(torch.tensor(float(t + 2), device=self.device))
            bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8)
            # UCB-style bonus: shrinks as 1/sqrt(observations) for each row.
            bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0))
            scores = self.predicted_mass + bonus
            self._set_active(self._sample_exploit_explore(scores))
            return

        raise ValueError(f"Unknown policy: {self.policy}")

    @torch.no_grad()
    def current_gradient_mass(self) -> torch.Tensor:
        """Per-row L2 norm of the current dense gradient (weight row + bias entry).

        Rows of modules whose weight has no ``.grad`` are left at 0.
        """
        mass = torch.zeros(self.n_blocks, device=self.device)
        for m, ids in self.module_to_ids.items():
            if m.weight.grad is None:
                continue
            row_sq = m.weight.grad.square().sum(dim=1)
            if m.bias is not None and m.bias.grad is not None:
                row_sq = row_sq + m.bias.grad.square()
            mass[ids] = torch.sqrt(row_sq + 1e-30)
        return mass

    @torch.no_grad()
    def audit_and_update(self, step: int) -> Dict[str, float]:
        """Finalize this step's mask, compute audit metrics, and update statistics.

        Must be called after ``loss.backward()`` so dense gradients are present.

        Returns:
            Metric dict with keys ``cosine``, ``norm_ratio``, ``top20_mass``,
            ``jacc_oracle``, ``stability``, ``active_fraction_real``, ``coverage``,
            ``avg_obs_count``, ``new_active_fraction``.
        """
        mass = self.current_gradient_mass()

        if step < self.warmup_steps:
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active

        # ---- Audit metrics (these use the dense gradient) ----
        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        # Row masking keeps a subset of gradient coordinates, so
        # cos(g_masked, g) = ||g_masked|| / ||g||: cosine and norm ratio coincide.
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())
        norm_ratio = cosine

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()

        # Share of total row mass held by the top 20% of rows.
        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        new_active = active & (self.observed_count == 0)

        # ---- Statistics update ----
        # Strict rule for practical policies: update stats only from active rows.
        # oracle_current and stale_current also update only active rows for consistency;
        # stale_current separately records last_full_mass as a diagnostic signal.
        observed = active
        # Decay only rows that were NOT observed this step; observed rows are
        # refreshed by the EMA below and must not be shrunk first.
        self.predicted_mass[~observed] *= self.unobserved_decay
        if bool(observed.any().item()):
            obs_mass = mass[observed]
            first_seen = self.observed_count[observed] == 0
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            # First observation should establish the real scale immediately.
            # Otherwise a beta=0.95 EMA needs many observations to climb from zero.
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0

            step_mean = obs_mass.mean()
            if self.global_mass_initialized:
                self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * step_mean
            else:
                # Same reasoning for the UCB bonus scale: starting from the 1e-6
                # placeholder would make the exploration bonus negligible early on.
                self.global_mass_ema = step_mean.clone()
                self.global_mass_initialized = True

        # Dense audit signal. Only stale_current is allowed to use this for next-step selection.
        self.last_full_mass = mass.detach().clone()

        coverage = float((self.observed_count > 0).float().mean().item())
        avg_obs_count = float(self.observed_count.mean().item())
        new_active_fraction = float((new_active.float().mean()).item())

        return {
            "cosine": cosine,
            "norm_ratio": norm_ratio,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
            "coverage": coverage,
            "avg_obs_count": avg_obs_count,
            "new_active_fraction": new_active_fraction,
        }

    def row_mask_for(self, module: nn.Linear) -> Optional[torch.Tensor]:
        """Return the current boolean row mask for ``module`` (None if not a tracked Linear)."""
        return self.row_masks.get(module)


# -----------------------------
# Masked Adam
# -----------------------------


class MaskedAdam:
    """Adam-style optimizer whose Linear-parameter updates are restricted to active rows.

    Differences from ``torch.optim.Adam`` (kept deliberately so dense and sparse
    runs are comparable): no bias correction, and weight decay is added to the
    gradient (L2 style) rather than decoupled. For masked rows the moments
    ``m``/``v`` are left untouched (not decayed) and the parameters do not move.
    Non-Linear parameters are updated densely unless
    ``freeze_non_linear_when_sparse`` is set and a masker is present.
    """

    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        freeze_non_linear_when_sparse: bool = False,
    ):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Maps each Linear weight/bias to (owning module, "weight" | "bias").
        self.linear_param: Dict[nn.Parameter, Tuple[nn.Linear, str]] = {}
        for _, m in named_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        """Drop all parameter gradients (set ``.grad`` to None)."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self) -> None:
        """Apply one optimizer step using the masker's current row masks."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param:
                # Optional stricter mode: freeze embeddings/layernorm/etc. in sparse runs.
                continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                g = g.add(p, alpha=self.weight_decay)

            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    # Weight rows need a broadcastable (rows, 1, ...) view; biases are 1-D already.
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base

            if row_mask is None:
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                mask = row_mask.expand_as(p)
                if not bool(mask.any().item()):
                    continue
                new_m = self.beta1 * m + (1.0 - self.beta1) * g
                new_v = self.beta2 * v + (1.0 - self.beta2) * g * g
                m[mask] = new_m[mask]
                v[mask] = new_v[mask]
                update = m / (torch.sqrt(v) + self.eps)
                p[mask] = p[mask] - self.lr * update[mask]


# -----------------------------
# Training
# -----------------------------


@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int) -> Dict[str, float]:
    """Average loss over ``eval_iters`` random batches for both splits (model in eval mode)."""
    model.eval()
    out = {}
    for split in ["train", "val"]:
        losses = []
        for _ in range(eval_iters):
            x, y = corpus.get_batch(split, batch_size)
            _, loss = model(x, y)
            losses.append(float(loss.item()))
        out[split] = sum(losses) / len(losses)
    model.train()
    return out


def train_run(
    corpus: CharCorpus,
    args: argparse.Namespace,
    policy: Optional[Policy],
    active_fraction: float,
    warmup_steps: int,
    explore_fraction: float,
    seed_offset: int,
) -> Dict[str, float | str]:
    """Train one fresh MiniGPT and return a summary row.

    ``policy=None`` runs the dense baseline. Mask metrics are averaged over the
    non-warmup steps only; they are NaN when there are none (or for dense runs).
    """
    set_seed(args.seed + seed_offset)
    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)

    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=warmup_steps,
            ucb_alpha=args.ucb_alpha,
            mass_init=args.mass_init,
            device=dev,
        )

    opt = MaskedAdam(
        model,
        masker,
        lr=args.lr,
        weight_decay=args.weight_decay,
        freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse,
    )

    metric_keys = [
        "cosine",
        "norm_ratio",
        "top20_mass",
        "jacc_oracle",
        "stability",
        "active_fraction_real",
        "coverage",
        "avg_obs_count",
        "new_active_fraction",
    ]
    sums = {k: 0.0 for k in metric_keys}
    count = 0

    for step in range(args.steps):
        x, y = corpus.get_batch("train", args.batch_size)
        if masker is not None:
            masker.choose_pre_backward(step)

        _, loss = model(x, y)
        opt.zero_grad()
        loss.backward()

        if masker is not None:
            metrics = masker.audit_and_update(step)
            if step >= warmup_steps:
                for k in sums:
                    sums[k] += metrics[k]
                count += 1

        opt.step()

        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
            name = "dense" if policy is None else policy
            print(
                f"{name:20s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} "
                f"train={losses['train']:.4f} val={losses['val']:.4f}"
            )

    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "target_active": 1.0 if policy is None else active_fraction,
        "warmup": warmup_steps,
        "explore": explore_fraction if policy is not None else 0.0,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None or count == 0:
        row.update({k: float("nan") for k in metric_keys})
        row["active_fraction_real"] = 1.0
    else:
        for k, total in sums.items():
            row[k] = total / count
    return row


def print_summary(rows: List[Dict[str, float | str]]) -> None:
    """Print one fixed-width table line per run row."""
    print("\nSummary")
    header = (
        f"{'run':>22s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} "
        f"{'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} "
        f"{'stable':>7s} {'cover':>7s} {'new':>7s}"
    )
    print(header)
    print("-" * len(header))
    for r in rows:
        print(
            f"{str(r['run']):>22s} "
            f"{float(r['target_active']):7.3f} "
            f"{float(r['active_fraction_real']):7.3f} "
            f"{int(float(r['warmup'])):5d} "
            f"{float(r['explore']):5.2f} "
            f"{float(r['val_loss']):8.4f} "
            f"{float(r['train_loss']):8.4f} "
            f"{float(r['cosine']):7.3f} "
            f"{float(r['top20_mass']):7.3f} "
            f"{float(r['jacc_oracle']):7.3f} "
            f"{float(r['stability']):7.3f} "
            f"{float(r['coverage']):7.3f} "
            f"{float(r['new_active_fraction']):7.3f}"
        )


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the experiment sweep."""
    p = argparse.ArgumentParser()
    p.add_argument("--text_path", type=str, default=None)
    p.add_argument("--synthetic_sentences", type=int, default=12000)
    p.add_argument("--steps", type=int, default=1000)
    p.add_argument("--quick", action="store_true")
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--block_size", type=int, default=64)
    p.add_argument("--n_layer", type=int, default=2)
    p.add_argument("--n_head", type=int, default=4)
    p.add_argument("--n_embd", type=int, default=64)
    p.add_argument("--dropout", type=float, default=0.0)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--weight_decay", type=float, default=0.0)
    p.add_argument("--active_fractions", type=float, nargs="+", default=[0.10, 0.05, 0.02])
    p.add_argument(
        "--policies",
        type=str,
        nargs="+",
        default=["oracle_current", "predicted_magnitude", "ucb_magnitude", "random"],
    )
    p.add_argument("--explore_fractions", type=float, nargs="+", default=[0.10])
    p.add_argument("--warmup_steps_list", type=int, nargs="+", default=[5])
    p.add_argument("--mass_beta", type=float, default=0.95)
    p.add_argument("--unobserved_decay", type=float, default=1.0)
    p.add_argument("--mass_init", type=float, default=0.0)
    p.add_argument("--ucb_alpha", type=float, default=1.0)
    p.add_argument("--freeze_non_linear_when_sparse", action="store_true")
    p.add_argument("--eval_interval", type=int, default=200)
    p.add_argument("--eval_iters", type=int, default=20)
    p.add_argument("--seed", type=int, default=7)
    p.add_argument("--verbose", action="store_true")
    return p.parse_args()


def main() -> None:
    """Run the dense baseline plus every (active_fraction, policy, warmup, explore) combination."""
    args = parse_args()
    if args.quick:
        # Smoke-test configuration; overrides the corresponding CLI values.
        args.steps = 60
        args.eval_iters = 3
        args.batch_size = 16
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 2000
        args.active_fractions = [0.10, 0.02]
        args.policies = ["oracle_current", "predicted_magnitude", "ucb_magnitude", "random"]
        args.explore_fractions = [0.10]
        args.warmup_steps_list = [0]

    # Validate policy strings early.
    valid = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"}
    for pol in args.policies:
        if pol not in valid:
            raise ValueError(f"Unknown policy {pol!r}. Valid policies: {sorted(valid)}")

    set_seed(args.seed)
    dev = device()
    print(f"device={dev}")

    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"policies={args.policies}")
    print(f"active_fractions={args.active_fractions}")
    print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}")
    print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}")

    # Report how much of the model is governed by row masks.
    tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    total_params, linear_params, linear_frac = parameter_fractions(tmp_model)
    del tmp_model
    print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}")
    if args.freeze_non_linear_when_sparse:
        print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs")
    else:
        print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely")

    rows: List[Dict[str, float | str]] = []

    print("\nRunning dense baseline")
    rows.append(
        train_run(corpus, args, policy=None, active_fraction=1.0, warmup_steps=0, explore_fraction=0.0, seed_offset=0)
    )

    seed_offset = 100
    for af in args.active_fractions:
        for pol in args.policies:
            # oracle_current and stale_current do not use explore_fraction; random does not either.
            explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0]
            # Warmup matters for every sparse policy, so keep it in the loop.
            for warmup in args.warmup_steps_list:
                for explore in explore_values:
                    print(f"\nRunning policy={pol}, active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}")
                    rows.append(
                        train_run(
                            corpus,
                            args,
                            policy=pol,  # type: ignore[arg-type]
                            active_fraction=af,
                            warmup_steps=warmup,
                            explore_fraction=explore,
                            seed_offset=seed_offset,
                        )
                    )
                    seed_offset += 1

    print_summary(rows)
    print("\nNotes")
    print("  oracle_current uses current dense gradients to choose rows; it is the true upper bound.")
    print("  stale_current uses previous-step dense gradient mass; it is a renamed stale/noisy control.")
    print("  predicted_magnitude uses only EMA mass from active/observed rows.")
    print("  ucb_magnitude adds an uncertainty bonus for under-observed rows to improve discovery.")
    print("  coverage is the fraction of Linear rows that have ever been observed/active.")
    print("  new is the average fraction of rows newly observed per non-warmup step.")
    print("  dense gradients are still computed for audit; this is not a wall-clock benchmark yet.")


if __name__ == "__main__":
    main()