Add sparse transformer v7: discovery stress tests for stable gradient-support masks, with selectable row-selection policies (predicted_magnitude, ucb_magnitude, oracle_current, stale_current, random) and a row-masked Adam optimizer. Includes utilities for synthetic data generation and model training. Gradient statistics are updated only from active rows; dense gradients are still computed for audit, so this is not yet a wall-clock speedup.
"""
Sparse Transformer v7: discovery stress tests for stable gradient-support masks.

This version follows the v6 result: oracle_current works far better than random,
so a useful sparse support exists, but predicted_magnitude without warmup does
not reliably discover that support.

v7 focuses on discovery mechanisms:

1. predicted_magnitude
   Exploit the rows with the largest EMA-observed gradient mass.
2. ucb_magnitude
   A bandit-style selector: EMA mass plus an uncertainty bonus for under-observed
   rows, meant to discover useful rows without a dense refresh. The first
   observation of a row initializes its EMA scale immediately. (See the score
   sketch after this list.)
3. stale_current
   A renamed diagnostic control: select by the previous step's full gradient
   mass. It is not practical, because it relies on dense audit gradients, but it
   tells us whether a one-step lag is already too noisy.
4. oracle_current
   True current top-k by dense gradient mass. Upper bound only.
5. random
   Control.
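
UCB score (schematic)
---------------------
For reference, the per-row selection score used by ucb_magnitude (mirroring
RowMasker.choose_pre_backward below) is, schematically:

    t     = max(1, step - warmup_steps + 1)
    bonus = ucb_alpha * global_mass_ema * sqrt(log(t + 2) / (observed_count + 1))
    score = predicted_mass + bonus

where predicted_mass is the per-row EMA of observed gradient mass, observed_count
is how often that row has been active, and global_mass_ema sets the bonus scale.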

Important limitation
--------------------
This script still calls loss.backward(), so PyTorch computes dense gradients.
The dense current gradients are used for audit metrics and for the oracle/stale
controls; the practical selectors update their EMA statistics only from active
rows. An actual speedup would require a structured partial backward or custom
kernels.
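
Audit metrics (schematic)
-------------------------
The main per-step audit metrics computed in RowMasker.audit_and_update are,
schematically (over Linear-layer gradients only):

    cosine    = ||g restricted to active rows|| / ||g||
    jacc      = |active & oracle_topk| / |active | oracle_topk|
    stability = |active & previous_active| / |active | previous_active|
    coverage  = fraction of rows that have ever been active

Because the row-masked gradient is an orthogonal projection of the full
gradient, cosine here equals the cosine similarity between the two.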

Example runs
------------
Smoke test:

    python3 sparse_transformer_v7.py --quick

No-warmup discovery test:

    python3 sparse_transformer_v7.py --steps 1000 \
        --active_fractions 0.10 0.05 0.02 \
        --policies predicted_magnitude ucb_magnitude oracle_current random \
        --warmup_steps_list 0 5 50 --explore_fractions 0.10 0.30

Warm-start separation test:

    python3 sparse_transformer_v7.py --steps 1000 \
        --active_fractions 0.10 0.05 0.02 \
        --policies predicted_magnitude ucb_magnitude oracle_current random \
        --warmup_steps_list 0 5 50 200 --explore_fractions 0.10
"""

from __future__ import annotations

import argparse
import math
import random
from typing import Dict, List, Literal, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_num_threads(1)

Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"]


def set_seed(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def device() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"


# -----------------------------
# Data
# -----------------------------
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    rng = random.Random(seed)
    names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
    verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
    objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
    adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
    clauses = [
        "when the loss falls",
        "after the mask shifts",
        "before the model answers",
        "while the signal drifts",
        "if the pattern repeats",
        "because the tail is noisy",
    ]
    symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]
    lines: List[str] = []
    for _ in range(n_sentences):
        t = rng.randrange(6)
        if t == 0:
            line = f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}."
        elif t == 1:
            line = f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}."
        elif t == 2:
            a, b = rng.sample(symbols, 2)
            line = f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}."
        elif t == 3:
            line = f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}."
        elif t == 4:
            seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7)))
            line = f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}."
        else:
            line = f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait."
        lines.append(line)
    return "\n".join(lines) + "\n"


class CharCorpus:
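    """Character-level corpus with a 90/10 train/val split.

    Builds a char vocabulary from the raw text and serves random contiguous
    (input, next-char target) batches of length block_size from either split.
    """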

    def __init__(self, text: str, block_size: int, device: str):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.device = device
        data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split = int(0.9 * len(data))
        self.train_data = data[:split]
        self.val_data = data[split:]

    def get_batch(self, split: str, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        data = self.train_data if split == "train" else self.val_data
        max_start = len(data) - self.block_size - 1
        if max_start <= 0:
            raise ValueError("Corpus too small for block_size")
        ix = torch.randint(max_start, (batch_size,))
        x = torch.stack([data[i : i + self.block_size] for i in ix])
        y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix])
        return x.to(self.device), y.to(self.device)


def load_text(args: argparse.Namespace) -> str:
    if args.text_path:
        with open(args.text_path, "r", encoding="utf-8") as f:
            return f.read()
    return make_synthetic_corpus(args.synthetic_sentences, args.seed)


# -----------------------------
# Mini GPT
# -----------------------------
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


class FeedForward(nn.Module):
    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))


class Block(nn.Module):
    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class MiniGPT(nn.Module):
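    """Minimal GPT-style decoder: token + learned positional embeddings,
    n_layer pre-LN transformer blocks, a final LayerNorm, and an untied LM head.
    Returns (logits, loss); loss is cross-entropy when targets are given, else None.
    """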

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


def named_linear_modules(model: nn.Module) -> List[Tuple[str, nn.Linear]]:
    return [(name, m) for name, m in model.named_modules() if isinstance(m, nn.Linear)]


def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]:
    total = sum(p.numel() for p in model.parameters())
    linear = 0
    for _, m in named_linear_modules(model):
        linear += m.weight.numel()
        if m.bias is not None:
            linear += m.bias.numel()
    return total, linear, linear / max(1, total)


# -----------------------------
# Mask selector
# -----------------------------
class RowMasker:
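    """Selects which output rows of every nn.Linear are "active" each step.

    Every Linear output row (weight row plus its bias entry) gets a global block
    id. The masker tracks, per row, an EMA of observed gradient mass
    (predicted_mass), an observation count, and the last dense gradient mass
    (last_full_mass, used for selection only by the stale_current control).
    choose_pre_backward picks the active set before the backward pass according
    to the policy; audit_and_update then measures mask quality against the dense
    gradient (and, for oracle_current, selects after gradients exist) and updates
    the EMA statistics from active rows only.
    """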

    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        active_fraction: float,
        explore_fraction: float,
        mass_beta: float,
        unobserved_decay: float,
        warmup_steps: int,
        ucb_alpha: float,
        mass_init: float,
        device: str,
    ):
        self.model = model
        self.policy = policy
        self.active_fraction = active_fraction
        self.explore_fraction = explore_fraction
        self.mass_beta = mass_beta
        self.unobserved_decay = unobserved_decay
        self.warmup_steps = warmup_steps
        self.ucb_alpha = ucb_alpha
        self.mass_init = mass_init
        self.device = device
        self.step_index = 0
        self.linear_modules = [m for _, m in named_linear_modules(model)]
        self.module_to_ids: Dict[nn.Linear, torch.Tensor] = {}
        ids = []
        offset = 0
        for m in self.linear_modules:
            n = m.weight.shape[0]
            block_ids = torch.arange(offset, offset + n, device=device)
            self.module_to_ids[m] = block_ids
            ids.append(block_ids)
            offset += n
        self.n_blocks = offset
        self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.observed_count = torch.zeros(self.n_blocks, device=device)
        self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device)
        self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.row_masks: Dict[nn.Linear, torch.Tensor] = {
            m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
        }

    def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
        k = max(1, int(fraction * values.numel()))
        mask = torch.zeros_like(values, dtype=torch.bool)
        # Tie-breaking noise matters when many rows have identical initial scores.
        noisy = values + 1e-9 * torch.rand_like(values)
        mask[torch.topk(noisy, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    def _set_active(self, active: torch.Tensor) -> None:
        self.active = active
        self.row_masks = {}
        for m, ids in self.module_to_ids.items():
            self.row_masks[m] = active[ids]

    def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor:
        n = self.n_blocks
        k_total = max(1, int(self.active_fraction * n))
        k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
        k_exploit = k_total - k_explore
        active = torch.zeros(n, dtype=torch.bool, device=self.device)
        if k_exploit > 0:
            active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True
        if k_explore > 0:
            remaining = torch.nonzero(~active, as_tuple=False).flatten()
            pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]
            active[pick] = True
        return active

    def choose_pre_backward(self, step: int) -> None:
        self.step_index = step
        if step < self.warmup_steps:
            self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
            return
        if self.policy == "oracle_current":
            # Cannot select until after current gradients are known.
            self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
            return
        if self.policy == "random":
            self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device)))
            return
        if self.policy == "stale_current":
            self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction))
            return
        if self.policy == "predicted_magnitude":
            self._set_active(self._sample_exploit_explore(self.predicted_mass))
            return
        if self.policy == "ucb_magnitude":
            t = max(1, step - self.warmup_steps + 1)
            log_term = torch.log(torch.tensor(float(t + 2), device=self.device))
            bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8)
            bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0))
            scores = self.predicted_mass + bonus
            self._set_active(self._sample_exploit_explore(scores))
            return
        raise ValueError(f"Unknown policy: {self.policy}")

    def current_gradient_mass(self) -> torch.Tensor:
        mass = torch.zeros(self.n_blocks, device=self.device)
        for m, ids in self.module_to_ids.items():
            if m.weight.grad is None:
                continue
            row_sq = m.weight.grad.square().sum(dim=1)
            if m.bias is not None and m.bias.grad is not None:
                row_sq = row_sq + m.bias.grad.square()
            mass[ids] = torch.sqrt(row_sq + 1e-30)
        return mass

    def audit_and_update(self, step: int) -> Dict[str, float]:
        mass = self.current_gradient_mass()
        if step < self.warmup_steps:
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active
        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())
        norm_ratio = cosine
        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()
        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())
        new_active = active & (self.observed_count == 0)
        # Strict rule for practical policies: update stats only from active rows.
        # oracle_current and stale_current also update only active rows for consistency;
        # stale_current separately records last_full_mass as a diagnostic signal.
        self.predicted_mass.mul_(self.unobserved_decay)
        observed = active
        if bool(observed.any().item()):
            obs_mass = mass[observed]
            first_seen = self.observed_count[observed] == 0
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            # First observation should establish the real scale immediately.
            # Otherwise a beta=0.95 EMA needs many observations to climb from zero.
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0
            self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()
        # Dense audit signal. Only stale_current is allowed to use this for next-step selection.
        self.last_full_mass = mass.detach().clone()
        coverage = float((self.observed_count > 0).float().mean().item())
        avg_obs_count = float(self.observed_count.mean().item())
        new_active_fraction = float((new_active.float().mean()).item())
        return {
            "cosine": cosine,
            "norm_ratio": norm_ratio,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
            "coverage": coverage,
            "avg_obs_count": avg_obs_count,
            "new_active_fraction": new_active_fraction,
        }

    def row_mask_for(self, module: nn.Linear) -> Optional[torch.Tensor]:
        return self.row_masks.get(module)


# -----------------------------
# Masked Adam
# -----------------------------
class MaskedAdam:
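    """Adam-style optimizer whose updates to nn.Linear rows are gated by a RowMasker.

    For a masked Linear parameter, only the entries whose output row is active
    this step get their first/second-moment state and their weights updated
    (no bias correction is applied). Non-Linear parameters are always updated
    densely unless freeze_non_linear_when_sparse is set.
    """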

    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        freeze_non_linear_when_sparse: bool = False,
    ):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        self.linear_param: Dict[nn.Parameter, Tuple[nn.Linear, str]] = {}
        for _, m in named_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()  # parameters are modified in place, so autograd must be disabled here
    def step(self) -> None:
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param:
                # Optional stricter mode: freeze embeddings/layernorm/etc. in sparse runs.
                continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                g = g.add(p, alpha=self.weight_decay)
            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base
            if row_mask is None:
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                mask = row_mask.expand_as(p)
                if not bool(mask.any().item()):
                    continue
                new_m = self.beta1 * m + (1.0 - self.beta1) * g
                new_v = self.beta2 * v + (1.0 - self.beta2) * g * g
                m[mask] = new_m[mask]
                v[mask] = new_v[mask]
                update = m / (torch.sqrt(v) + self.eps)
                p[mask] = p[mask] - self.lr * update[mask]
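

# Per-step protocol used in train_run below (the intended wiring of RowMasker
# and MaskedAdam):
#
#     masker.choose_pre_backward(step)   # pick active rows for this step
#     _, loss = model(x, y)
#     opt.zero_grad()
#     loss.backward()                    # still dense; see docstring limitation
#     metrics = masker.audit_and_update(step)
#     opt.step()                         # masked update of Linear rows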


# -----------------------------
# Training
# -----------------------------
@torch.no_grad()  # evaluation only; no autograd graph needed
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int) -> Dict[str, float]:
    model.eval()
    out = {}
    for split in ["train", "val"]:
        losses = []
        for _ in range(eval_iters):
            x, y = corpus.get_batch(split, batch_size)
            _, loss = model(x, y)
            losses.append(float(loss.item()))
        out[split] = sum(losses) / len(losses)
    model.train()
    return out


def train_run(
    corpus: CharCorpus,
    args: argparse.Namespace,
    policy: Optional[Policy],
    active_fraction: float,
    warmup_steps: int,
    explore_fraction: float,
    seed_offset: int,
) -> Dict[str, float | str]:
    set_seed(args.seed + seed_offset)
    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=warmup_steps,
            ucb_alpha=args.ucb_alpha,
            mass_init=args.mass_init,
            device=dev,
        )
    opt = MaskedAdam(
        model,
        masker,
        lr=args.lr,
        weight_decay=args.weight_decay,
        freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse,
    )
    sums = {
        "cosine": 0.0,
        "norm_ratio": 0.0,
        "top20_mass": 0.0,
        "jacc_oracle": 0.0,
        "stability": 0.0,
        "active_fraction_real": 0.0,
        "coverage": 0.0,
        "avg_obs_count": 0.0,
        "new_active_fraction": 0.0,
    }
    count = 0
    for step in range(args.steps):
        x, y = corpus.get_batch("train", args.batch_size)
        if masker is not None:
            masker.choose_pre_backward(step)
        _, loss = model(x, y)
        opt.zero_grad()
        loss.backward()
        if masker is not None:
            metrics = masker.audit_and_update(step)
            if step >= warmup_steps:
                for k in sums:
                    sums[k] += metrics[k]
                count += 1
        opt.step()
        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
            name = "dense" if policy is None else policy
            print(
                f"{name:20s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} "
                f"train={losses['train']:.4f} val={losses['val']:.4f}"
            )
    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "target_active": 1.0 if policy is None else active_fraction,
        "warmup": warmup_steps,
        "explore": explore_fraction if policy is not None else 0.0,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None or count == 0:
        row.update({
            "cosine": float("nan"),
            "norm_ratio": float("nan"),
            "top20_mass": float("nan"),
            "jacc_oracle": float("nan"),
            "stability": float("nan"),
            "active_fraction_real": 1.0,
            "coverage": float("nan"),
            "avg_obs_count": float("nan"),
            "new_active_fraction": float("nan"),
        })
    else:
        for k, v in sums.items():
            row[k] = v / count
    return row


def print_summary(rows: List[Dict[str, float | str]]) -> None:
    print("\nSummary")
    header = (
        f"{'run':>22s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} "
        f"{'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} "
        f"{'stable':>7s} {'cover':>7s} {'new':>7s}"
    )
    print(header)
    print("-" * len(header))
    for r in rows:
        print(
            f"{str(r['run']):>22s} "
            f"{float(r['target_active']):7.3f} "
            f"{float(r['active_fraction_real']):7.3f} "
            f"{int(float(r['warmup'])):5d} "
            f"{float(r['explore']):5.2f} "
            f"{float(r['val_loss']):8.4f} "
            f"{float(r['train_loss']):8.4f} "
            f"{float(r['cosine']):7.3f} "
            f"{float(r['top20_mass']):7.3f} "
            f"{float(r['jacc_oracle']):7.3f} "
            f"{float(r['stability']):7.3f} "
            f"{float(r['coverage']):7.3f} "
            f"{float(r['new_active_fraction']):7.3f}"
        )


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--text_path", type=str, default=None)
    p.add_argument("--synthetic_sentences", type=int, default=12000)
    p.add_argument("--steps", type=int, default=1000)
    p.add_argument("--quick", action="store_true")
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--block_size", type=int, default=64)
    p.add_argument("--n_layer", type=int, default=2)
    p.add_argument("--n_head", type=int, default=4)
    p.add_argument("--n_embd", type=int, default=64)
    p.add_argument("--dropout", type=float, default=0.0)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--weight_decay", type=float, default=0.0)
    p.add_argument("--active_fractions", type=float, nargs="+", default=[0.10, 0.05, 0.02])
    p.add_argument("--policies", type=str, nargs="+", default=["oracle_current", "predicted_magnitude", "ucb_magnitude", "random"])
    p.add_argument("--explore_fractions", type=float, nargs="+", default=[0.10])
    p.add_argument("--warmup_steps_list", type=int, nargs="+", default=[5])
    p.add_argument("--mass_beta", type=float, default=0.95)
    p.add_argument("--unobserved_decay", type=float, default=1.0)
    p.add_argument("--mass_init", type=float, default=0.0)
    p.add_argument("--ucb_alpha", type=float, default=1.0)
    p.add_argument("--freeze_non_linear_when_sparse", action="store_true")
    p.add_argument("--eval_interval", type=int, default=200)
    p.add_argument("--eval_iters", type=int, default=20)
    p.add_argument("--seed", type=int, default=7)
    p.add_argument("--verbose", action="store_true")
    return p.parse_args()


def main() -> None:
    args = parse_args()
    if args.quick:
        args.steps = 60
        args.eval_iters = 3
        args.batch_size = 16
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 2000
        args.active_fractions = [0.10, 0.02]
        args.policies = ["oracle_current", "predicted_magnitude", "ucb_magnitude", "random"]
        args.explore_fractions = [0.10]
        args.warmup_steps_list = [0]
    # Validate policy strings early.
    valid = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"}
    for pol in args.policies:
        if pol not in valid:
            raise ValueError(f"Unknown policy {pol!r}. Valid policies: {sorted(valid)}")
    set_seed(args.seed)
    dev = device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"policies={args.policies}")
    print(f"active_fractions={args.active_fractions}")
    print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}")
    print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}")
    # Report how much of the model is governed by row masks.
    tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    total_params, linear_params, linear_frac = parameter_fractions(tmp_model)
    del tmp_model
    print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}")
    if args.freeze_non_linear_when_sparse:
        print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs")
    else:
        print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely")
    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(train_run(corpus, args, policy=None, active_fraction=1.0, warmup_steps=0, explore_fraction=0.0, seed_offset=0))
    seed_offset = 100
    for af in args.active_fractions:
        for pol in args.policies:
            # oracle_current and stale_current do not use explore_fraction; random does not either.
            explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0]
            # Warmup matters for every sparse policy, so keep it in the loop.
            for warmup in args.warmup_steps_list:
                for explore in explore_values:
                    print(f"\nRunning policy={pol}, active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}")
                    rows.append(
                        train_run(
                            corpus,
                            args,
                            policy=pol,  # type: ignore[arg-type]
                            active_fraction=af,
                            warmup_steps=warmup,
                            explore_fraction=explore,
                            seed_offset=seed_offset,
                        )
                    )
                    seed_offset += 1
    print_summary(rows)
    print("\nNotes")
    print(" oracle_current uses current dense gradients to choose rows; it is the true upper bound.")
    print(" stale_current uses previous-step dense gradient mass; it is a renamed stale/noisy control.")
    print(" predicted_magnitude uses only EMA mass from active/observed rows.")
    print(" ucb_magnitude adds an uncertainty bonus for under-observed rows to improve discovery.")
    print(" coverage is the fraction of Linear rows that have ever been observed/active.")
    print(" new is the average fraction of rows newly observed per non-warmup step.")
    print(" dense gradients are still computed for audit; this is not a wall-clock benchmark yet.")


if __name__ == "__main__":
    main()