Add sparse transformer v9 with an EMA row-importance selector and multiple sparse backward modes. Includes utilities for synthetic data generation and model training. Implements a row-sparse custom Linear backward, a masked Adam optimizer, and optional dense audits used for measurement only.
bc1b8eb
"""
Sparse Transformer v9: no-audit sparse training after a dense warmup.

v8 showed that the row-sparse mask can be moved into a custom Linear backward.
v9 removes the remaining dense-audit crutch.

Default behavior
----------------
1. Run a short dense warmup (5 steps by default).
2. Initialize the EMA row-importance predictor from those dense warmup gradients.
3. After warmup, choose active rows from the predictor.
4. Train using the sparse backward.
5. Update EMA statistics only from rows that were actually active/observed.
6. Do not compute dense gradients unless --audit_every > 0.

Audit behavior
--------------
--audit_every 0
    No dense audit after warmup. Cosine/Jaccard/top20 are unavailable and reported as nan.
--audit_every N
    Every N steps, run an extra dense backward pass on the same batch, used only to
    measure cosine/top20/Jaccard. The audit is NOT used to update the selector,
    except for oracle_current, which is explicitly an upper-bound control.

This is still not a wall-clock benchmark on vanilla PyTorch/MPS/CPU. The custom
backward uses indexing and ordinary PyTorch matmuls. The goal is to verify that
the method survives without dense information after warmup.

Examples
--------
No-audit practical run:

    python3 sparse_transformer_v9.py \
        --device mps \
        --steps 2000 \
        --active_fractions 0.05 0.02 \
        --warmup_steps_list 5 \
        --policies predicted_magnitude random \
        --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
        --audit_every 0

Occasional audit for measurement only:

    python3 sparse_transformer_v9.py \
        --steps 2000 \
        --active_fractions 0.05 0.02 \
        --warmup_steps_list 5 \
        --policies predicted_magnitude random \
        --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
        --audit_every 100
"""
| from __future__ import annotations | |
| import argparse | |
| import math | |
| import random | |
| from typing import Dict, List, Literal, Optional, Tuple | |
| import torch | |
| torch.set_num_threads(1) | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"] | |
| BackwardMode = Literal["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"] | |
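# The module docstring above describes the v9 loop in prose. The function below is a minimal,
# self-contained toy sketch of the same idea on a single weight matrix (it is not used by the
# experiments in this file): bootstrap per-row statistics from a few dense steps, then keep
# updating an EMA only for rows that the sparse backward actually observed. The helper name
# `_toy_observed_only_ema` and the toy dimensions are illustrative assumptions, not part of the
# experiment code.
def _toy_observed_only_ema(steps: int = 20, n_rows: int = 64, warmup: int = 5, frac: float = 0.1) -> torch.Tensor:
    """Illustrative sketch: dense-warmup bootstrap followed by observed-only EMA updates."""
    predicted = torch.zeros(n_rows)   # EMA of per-row gradient mass
    seen = torch.zeros(n_rows)        # how often each row has been observed
    beta = 0.95
    k = max(1, int(frac * n_rows))
    for step in range(steps):
        true_mass = torch.rand(n_rows)  # stand-in for this step's per-row gradient norms
        if step < warmup:
            observed = torch.ones(n_rows, dtype=torch.bool)  # dense warmup: every row observed
        else:
            active = torch.zeros(n_rows, dtype=torch.bool)
            active[torch.topk(predicted, k).indices] = True   # select rows from the predictor
            observed = active                                 # sparse backward only reveals these rows
        first = seen[observed] == 0
        ema = beta * predicted[observed] + (1.0 - beta) * true_mass[observed]
        predicted[observed] = torch.where(first, true_mass[observed], ema)
        seen[observed] += 1.0
    return predicted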
| # ----------------------------- | |
| # Reproducibility and device | |
| # ----------------------------- | |
| def set_seed(seed: int) -> None: | |
| random.seed(seed) | |
| torch.manual_seed(seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(seed) | |
| def default_device() -> str: | |
| if torch.cuda.is_available(): | |
| return "cuda" | |
| if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): | |
| return "mps" | |
| return "cpu" | |
| def make_cpu_generator(seed: int) -> torch.Generator: | |
| gen = torch.Generator(device="cpu") | |
| gen.manual_seed(seed) | |
| return gen | |
| # ----------------------------- | |
| # Data | |
| # ----------------------------- | |
| def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str: | |
| rng = random.Random(seed) | |
| names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"] | |
| verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"] | |
| objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"] | |
| adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"] | |
| clauses = [ | |
| "when the loss falls", | |
| "after the mask shifts", | |
| "before the model answers", | |
| "while the signal drifts", | |
| "if the pattern repeats", | |
| "because the tail is noisy", | |
| ] | |
| symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"] | |
| lines: List[str] = [] | |
| for _ in range(n_sentences): | |
| t = rng.randrange(6) | |
| if t == 0: | |
| line = f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}." | |
| elif t == 1: | |
| line = f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}." | |
| elif t == 2: | |
| a, b = rng.sample(symbols, 2) | |
| line = f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}." | |
| elif t == 3: | |
| line = f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}." | |
| elif t == 4: | |
| seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7))) | |
| line = f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}." | |
| else: | |
| line = f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait." | |
| lines.append(line) | |
| return "\n".join(lines) + "\n" | |
| class CharCorpus: | |
| def __init__(self, text: str, block_size: int, device: str): | |
| chars = sorted(set(text)) | |
| self.stoi = {ch: i for i, ch in enumerate(chars)} | |
| self.itos = {i: ch for ch, i in self.stoi.items()} | |
| self.vocab_size = len(chars) | |
| self.block_size = block_size | |
| self.device = device | |
| data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long) | |
| split = int(0.9 * len(data)) | |
| self.train_data = data[:split] | |
| self.val_data = data[split:] | |
| def get_batch( | |
| self, | |
| split: str, | |
| batch_size: int, | |
| generator: Optional[torch.Generator] = None, | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| data = self.train_data if split == "train" else self.val_data | |
| max_start = len(data) - self.block_size - 1 | |
| if max_start <= 0: | |
| raise ValueError("Corpus too small for block_size") | |
| ix = torch.randint(max_start, (batch_size,), generator=generator) | |
| x = torch.stack([data[i : i + self.block_size] for i in ix]) | |
| y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix]) | |
| return x.to(self.device), y.to(self.device) | |
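# Minimal usage sketch (illustrative, not executed here): build a corpus from the synthetic text
# and draw one reproducible batch. Shapes follow directly from batch_size and block_size.
#
#     corpus = CharCorpus(make_synthetic_corpus(1000), block_size=64, device="cpu")
#     gen = make_cpu_generator(0)
#     x, y = corpus.get_batch("train", batch_size=8, generator=gen)
#     # x.shape == y.shape == (8, 64); y is x shifted one character to the right.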
| def load_text(args: argparse.Namespace) -> str: | |
| if args.text_path: | |
| with open(args.text_path, "r", encoding="utf-8") as f: | |
| return f.read() | |
| return make_synthetic_corpus(args.synthetic_sentences, args.seed) | |
| # ----------------------------- | |
| # Sparse Linear autograd | |
| # ----------------------------- | |
class MaskedLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_rows: torch.Tensor,
        sparse_dx: bool,
    ) -> torch.Tensor:
        ctx.save_for_backward(x, weight, active_rows)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        return F.linear(x, weight, bias)
    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):  # type: ignore[override]
        x, weight, active_rows = ctx.saved_tensors
        sparse_dx = bool(ctx.sparse_dx)
        has_bias = bool(ctx.has_bias)
        x_shape = x.shape
        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])
        active_idx = torch.nonzero(active_rows, as_tuple=False).flatten()
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if has_bias else None
        if active_idx.numel() > 0:
            gy_active = gy_flat[:, active_idx]
            grad_weight[active_idx] = gy_active.transpose(0, 1) @ x_flat
            if grad_bias is not None:
                grad_bias[active_idx] = gy_active.sum(dim=0)
            if sparse_dx:
                grad_x_flat = gy_active @ weight[active_idx]
            else:
                grad_x_flat = gy_flat @ weight
        else:
            # This can happen when a global top-k mask selects no rows from a
            # particular layer. Conservative full_dX still propagates through that
            # layer; aggressive sparse_dX cuts it off for that layer.
            if sparse_dx:
                grad_x_flat = torch.zeros_like(x_flat)
            else:
                grad_x_flat = gy_flat @ weight
        grad_x = grad_x_flat.reshape(x_shape)
        return grad_x, grad_weight, grad_bias, None, None
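# Sanity-check sketch for the custom backward, kept out of the training path. It compares
# MaskedLinearFunction against an ordinary dense backward: weight/bias gradients should match the
# dense ones on active rows and be exactly zero elsewhere, and with sparse_dx=False the input
# gradient should match the dense one. The helper name `_check_masked_linear_backward` is
# illustrative and the function is not called anywhere in this file.
def _check_masked_linear_backward(n_in: int = 8, n_out: int = 6, batch: int = 4) -> None:
    torch.manual_seed(0)
    x = torch.randn(batch, n_in, requires_grad=True)
    w = torch.randn(n_out, n_in, requires_grad=True)
    b = torch.randn(n_out, requires_grad=True)
    active = torch.zeros(n_out, dtype=torch.bool)
    active[: n_out // 2] = True
    # Dense reference gradients.
    y_ref = F.linear(x, w, b)
    y_ref.sum().backward()
    gx_ref, gw_ref, gb_ref = x.grad.clone(), w.grad.clone(), b.grad.clone()
    x.grad = w.grad = b.grad = None
    # Masked backward with a full input gradient (the sparse_dW_full_dX behaviour).
    y = MaskedLinearFunction.apply(x, w, b, active, False)
    y.sum().backward()
    assert torch.allclose(x.grad, gx_ref)
    assert torch.allclose(w.grad[active], gw_ref[active]) and torch.all(w.grad[~active] == 0)
    assert torch.allclose(b.grad[active], gb_ref[active]) and torch.all(b.grad[~active] == 0)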
| class SparseLinear(nn.Linear): | |
| """nn.Linear with an optional row-sparse backward pass.""" | |
| def __init__(self, in_features: int, out_features: int, bias: bool = True): | |
| super().__init__(in_features, out_features, bias=bias) | |
| self.sparse_enabled = False | |
| self.sparse_dx = False | |
| self.active_rows: Optional[torch.Tensor] = None | |
| def set_sparse_backward(self, enabled: bool, active_rows: Optional[torch.Tensor], sparse_dx: bool) -> None: | |
| self.sparse_enabled = bool(enabled) | |
| self.sparse_dx = bool(sparse_dx) | |
| self.active_rows = active_rows | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| if not self.sparse_enabled or self.active_rows is None: | |
| return F.linear(x, self.weight, self.bias) | |
| return MaskedLinearFunction.apply(x, self.weight, self.bias, self.active_rows, self.sparse_dx) | |
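# Usage sketch for the layer-level API (illustrative only): the layer behaves exactly like
# nn.Linear until set_sparse_backward(...) installs a row mask, after which only the masked output
# rows receive weight/bias gradients. configure_sparse_linears() further below applies this toggle
# to every SparseLinear in the model at once.
#
#     layer = SparseLinear(16, 32)
#     rows = torch.zeros(32, dtype=torch.bool)
#     rows[:4] = True
#     layer.set_sparse_backward(enabled=True, active_rows=rows, sparse_dx=False)
#     out = layer(torch.randn(2, 16))   # forward is unchanged; only backward is row-sparse
#     layer.set_sparse_backward(enabled=False, active_rows=None, sparse_dx=False)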
| # ----------------------------- | |
| # Mini GPT | |
| # ----------------------------- | |
| class CausalSelfAttention(nn.Module): | |
| def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float): | |
| super().__init__() | |
| assert n_embd % n_head == 0 | |
| self.n_head = n_head | |
| self.head_dim = n_embd // n_head | |
| self.c_attn = SparseLinear(n_embd, 3 * n_embd) | |
| self.c_proj = SparseLinear(n_embd, n_embd) | |
| self.dropout = nn.Dropout(dropout) | |
| self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| B, T, C = x.shape | |
| qkv = self.c_attn(x) | |
| q, k, v = qkv.split(C, dim=2) | |
| q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) | |
| k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) | |
| v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) | |
| att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim) | |
| att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf")) | |
| att = F.softmax(att, dim=-1) | |
| att = self.dropout(att) | |
| y = att @ v | |
| y = y.transpose(1, 2).contiguous().view(B, T, C) | |
| return self.c_proj(y) | |
| class FeedForward(nn.Module): | |
| def __init__(self, n_embd: int, dropout: float): | |
| super().__init__() | |
| self.c_fc = SparseLinear(n_embd, 4 * n_embd) | |
| self.c_proj = SparseLinear(4 * n_embd, n_embd) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return self.dropout(self.c_proj(F.gelu(self.c_fc(x)))) | |
| class Block(nn.Module): | |
| def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float): | |
| super().__init__() | |
| self.ln1 = nn.LayerNorm(n_embd) | |
| self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout) | |
| self.ln2 = nn.LayerNorm(n_embd) | |
| self.mlp = FeedForward(n_embd, dropout) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = x + self.attn(self.ln1(x)) | |
| x = x + self.mlp(self.ln2(x)) | |
| return x | |
| class MiniGPT(nn.Module): | |
| def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float): | |
| super().__init__() | |
| self.block_size = block_size | |
| self.tok_emb = nn.Embedding(vocab_size, n_embd) | |
| self.pos_emb = nn.Embedding(block_size, n_embd) | |
| self.drop = nn.Dropout(dropout) | |
| self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]) | |
| self.ln_f = nn.LayerNorm(n_embd) | |
| self.lm_head = SparseLinear(n_embd, vocab_size) | |
| def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None): | |
| B, T = idx.shape | |
| pos = torch.arange(T, device=idx.device) | |
| x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :] | |
| x = self.drop(x) | |
| x = self.blocks(x) | |
| x = self.ln_f(x) | |
| logits = self.lm_head(x) | |
| loss = None | |
| if targets is not None: | |
| loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) | |
| return logits, loss | |
| def named_sparse_linear_modules(model: nn.Module) -> List[Tuple[str, SparseLinear]]: | |
| return [(name, m) for name, m in model.named_modules() if isinstance(m, SparseLinear)] | |
| def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]: | |
| total = sum(p.numel() for p in model.parameters()) | |
| linear = 0 | |
| for _, m in named_sparse_linear_modules(model): | |
| linear += m.weight.numel() | |
| if m.bias is not None: | |
| linear += m.bias.numel() | |
| return total, linear, linear / max(1, total) | |
| def configure_sparse_linears( | |
| model: nn.Module, | |
| masker: Optional["RowMasker"], | |
| enabled: bool, | |
| backward_mode: Optional[str], | |
| ) -> None: | |
| sparse_dx = backward_mode == "sparse_dW_sparse_dX" | |
| for _, m in named_sparse_linear_modules(model): | |
| active = masker.row_mask_for(m) if masker is not None else None | |
| m.set_sparse_backward(enabled=enabled, active_rows=active, sparse_dx=sparse_dx) | |
| # ----------------------------- | |
| # Mask selector | |
| # ----------------------------- | |
| class RowMasker: | |
| def __init__( | |
| self, | |
| model: nn.Module, | |
| policy: Policy, | |
| active_fraction: float, | |
| explore_fraction: float, | |
| mass_beta: float, | |
| unobserved_decay: float, | |
| warmup_steps: int, | |
| ucb_alpha: float, | |
| mass_init: float, | |
| device: str, | |
| ): | |
| self.model = model | |
| self.policy = policy | |
| self.active_fraction = active_fraction | |
| self.explore_fraction = explore_fraction | |
| self.mass_beta = mass_beta | |
| self.unobserved_decay = unobserved_decay | |
| self.warmup_steps = warmup_steps | |
| self.ucb_alpha = ucb_alpha | |
| self.mass_init = mass_init | |
| self.device = device | |
| self.step_index = 0 | |
| self.linear_modules = [m for _, m in named_sparse_linear_modules(model)] | |
| self.module_to_ids: Dict[SparseLinear, torch.Tensor] = {} | |
| ids = [] | |
| offset = 0 | |
| for m in self.linear_modules: | |
| n = m.weight.shape[0] | |
| block_ids = torch.arange(offset, offset + n, device=device) | |
| self.module_to_ids[m] = block_ids | |
| ids.append(block_ids) | |
| offset += n | |
| self.n_blocks = offset | |
| self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device) | |
| self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device) | |
| self.observed_count = torch.zeros(self.n_blocks, device=device) | |
| self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device) | |
| self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device) | |
| self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device) | |
| self.row_masks: Dict[SparseLinear, torch.Tensor] = { | |
| m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules | |
| } | |
| def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor: | |
| k = max(1, int(fraction * values.numel())) | |
| mask = torch.zeros_like(values, dtype=torch.bool) | |
| noisy = values + 1e-9 * torch.rand_like(values) | |
| mask[torch.topk(noisy, k=k).indices] = True | |
| return mask | |
    def _jaccard(self, a: torch.Tensor, b: torch.Tensor) -> float:
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())
| def _set_active(self, active: torch.Tensor) -> None: | |
| self.active = active | |
| self.row_masks = {} | |
| for m, ids in self.module_to_ids.items(): | |
| self.row_masks[m] = active[ids] | |
| def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor: | |
| n = self.n_blocks | |
| k_total = max(1, int(self.active_fraction * n)) | |
| k_explore = min(k_total, max(0, int(self.explore_fraction * k_total))) | |
| k_exploit = k_total - k_explore | |
| active = torch.zeros(n, dtype=torch.bool, device=self.device) | |
| if k_exploit > 0: | |
| active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True | |
| if k_explore > 0: | |
| remaining = torch.nonzero(~active, as_tuple=False).flatten() | |
| pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]] | |
| active[pick] = True | |
| return active | |
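# Worked example for the exploit/explore split above (illustrative numbers): if the active
# fraction gives k_total = 50 and explore_fraction = 0.2, then k_explore = int(0.2 * 50) = 10 rows
# are drawn uniformly from the currently inactive set and the remaining k_exploit = 40 rows are
# the top-scoring ones, so exploration never displaces more than that fixed share of the budget.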
| def choose_pre_backward(self, step: int) -> None: | |
| self.step_index = step | |
| if step < self.warmup_steps: | |
| self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)) | |
| return | |
| if self.policy == "oracle_current": | |
| # Oracle cannot choose until the dense audit gradient is known. | |
| self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device)) | |
| return | |
| if self.policy == "random": | |
| self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device))) | |
| return | |
| if self.policy == "stale_current": | |
| self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction)) | |
| return | |
| if self.policy == "predicted_magnitude": | |
| self._set_active(self._sample_exploit_explore(self.predicted_mass)) | |
| return | |
| if self.policy == "ucb_magnitude": | |
| t = max(1, step - self.warmup_steps + 1) | |
| log_term = torch.log(torch.tensor(float(t + 2), device=self.device)) | |
| bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8) | |
| bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0)) | |
| self._set_active(self._sample_exploit_explore(self.predicted_mass + bonus)) | |
| return | |
| raise ValueError(f"Unknown policy: {self.policy}") | |
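# Worked example for the ucb_magnitude bonus above (numbers are illustrative): with ucb_alpha=1.0,
# a global mass EMA of 0.05, and t = 98 steps past warmup, log(t + 2) = log(100) ~= 4.605, so a
# never-observed row gets a bonus of roughly 0.107 while a row observed 50 times gets roughly
# 0.015; rarely observed rows are boosted relative to frequently observed ones with equal
# predicted mass.
#
#     bonus_unseen = 1.0 * 0.05 * math.sqrt(math.log(100.0) / 1.0)    # ~= 0.107
#     bonus_seen_50 = 1.0 * 0.05 * math.sqrt(math.log(100.0) / 51.0)  # ~= 0.015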
| def current_gradient_mass_from_grads(self) -> torch.Tensor: | |
| mass = torch.zeros(self.n_blocks, device=self.device) | |
| for m, ids in self.module_to_ids.items(): | |
| if m.weight.grad is None: | |
| continue | |
| row_sq = m.weight.grad.square().sum(dim=1) | |
| if m.bias is not None and m.bias.grad is not None: | |
| row_sq = row_sq + m.bias.grad.square() | |
| mass[ids] = torch.sqrt(row_sq + 1e-30) | |
| return mass | |
| def update_predictor_from_observed_mass(self, mass: torch.Tensor, observed: Optional[torch.Tensor] = None) -> Dict[str, float]: | |
| """Update EMA statistics only for observed rows. | |
| After warmup, sparse backward only gives trustworthy gradients for active | |
| rows, so only those rows are allowed to update predicted_mass. | |
| """ | |
| if observed is None: | |
| observed = self.active | |
| new_active = observed & (self.observed_count == 0) | |
| self.predicted_mass.mul_(self.unobserved_decay) | |
| if bool(observed.any().item()): | |
| obs_mass = mass[observed] | |
| first_seen = self.observed_count[observed] == 0 | |
| ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass | |
| self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass) | |
| self.observed_count[observed] += 1.0 | |
| self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean() | |
| stability = self._jaccard(self.active, self.prev_active) | |
| self.prev_active = self.active.clone() | |
| return { | |
| "stability": stability, | |
| "active_fraction_real": float(self.active.float().mean().item()), | |
| "coverage": float((self.observed_count > 0).float().mean().item()), | |
| "avg_obs_count": float(self.observed_count.mean().item()), | |
| "new_active_fraction": float(new_active.float().mean().item()), | |
| } | |
| def audit_metrics_from_mass(self, mass: torch.Tensor) -> Dict[str, float]: | |
| """Compute dense-audit metrics without updating the practical selector.""" | |
| active = self.active | |
| true_sq = mass.square().sum() | |
| approx_sq = mass[active].square().sum() | |
| cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item()) | |
| oracle_mask = self._topk_mask(mass, self.active_fraction) | |
| jacc = self._jaccard(active, oracle_mask) | |
| k20 = max(1, int(0.2 * self.n_blocks)) | |
| sorted_mass = torch.sort(mass, descending=True).values | |
| top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item()) | |
| return { | |
| "cosine": cosine, | |
| "norm_ratio": cosine, | |
| "top20_mass": top20_mass, | |
| "jacc_oracle": jacc, | |
| } | |
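# Note on the "cosine" metric above: because the sparse gradient is exactly the dense gradient
# with inactive rows zeroed, cos(g_sparse, g_dense) = ||g_sparse|| / ||g_dense||, i.e. the square
# root of the captured squared-mass fraction, which is why cosine and norm_ratio share one value
# here. Tiny illustrative check (not executed in this file):
#
#     g = torch.tensor([3.0, 4.0, 0.0, 12.0])
#     mask = torch.tensor([True, False, False, True])
#     g_sparse = torch.where(mask, g, torch.zeros_like(g))
#     F.cosine_similarity(g_sparse, g, dim=0)   # ~= 0.9515
#     g_sparse.norm() / g.norm()                # ~= 0.9515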
| def audit_and_update_from_mass(self, step: int, mass: torch.Tensor) -> Dict[str, float]: | |
| if step < self.warmup_steps: | |
| active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device) | |
| self._set_active(active) | |
| elif self.policy == "oracle_current": | |
| active = self._topk_mask(mass, self.active_fraction) | |
| self._set_active(active) | |
| else: | |
| active = self.active | |
| true_sq = mass.square().sum() | |
| approx_sq = mass[active].square().sum() | |
| cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item()) | |
| oracle_mask = self._topk_mask(mass, self.active_fraction) | |
| jacc = self._jaccard(active, oracle_mask) | |
| stability = self._jaccard(active, self.prev_active) | |
| self.prev_active = active.clone() | |
| k20 = max(1, int(0.2 * self.n_blocks)) | |
| sorted_mass = torch.sort(mass, descending=True).values | |
| top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item()) | |
| new_active = active & (self.observed_count == 0) | |
| # Practical rule: update predicted statistics only for active/observed rows. | |
| self.predicted_mass.mul_(self.unobserved_decay) | |
| observed = active | |
| if bool(observed.any().item()): | |
| obs_mass = mass[observed] | |
| first_seen = self.observed_count[observed] == 0 | |
| ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass | |
| self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass) | |
| self.observed_count[observed] += 1.0 | |
| self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean() | |
| # Dense audit signal; only stale_current is allowed to use this for selection. | |
| self.last_full_mass = mass.detach().clone() | |
| return { | |
| "cosine": cosine, | |
| "norm_ratio": cosine, | |
| "top20_mass": top20_mass, | |
| "jacc_oracle": jacc, | |
| "stability": stability, | |
| "active_fraction_real": float(active.float().mean().item()), | |
| "coverage": float((self.observed_count > 0).float().mean().item()), | |
| "avg_obs_count": float(self.observed_count.mean().item()), | |
| "new_active_fraction": float(new_active.float().mean().item()), | |
| } | |
| def row_mask_for(self, module: SparseLinear) -> Optional[torch.Tensor]: | |
| return self.row_masks.get(module) | |
| # ----------------------------- | |
| # Masked Adam | |
| # ----------------------------- | |
| class MaskedAdam: | |
| def __init__( | |
| self, | |
| model: nn.Module, | |
| masker: Optional[RowMasker], | |
| lr: float, | |
| betas=(0.9, 0.95), | |
| eps=1e-8, | |
| weight_decay=0.0, | |
| freeze_non_linear_when_sparse: bool = False, | |
| ): | |
| self.model = model | |
| self.masker = masker | |
| self.lr = lr | |
| self.beta1, self.beta2 = betas | |
| self.eps = eps | |
| self.weight_decay = weight_decay | |
| self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse | |
| self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {} | |
| self.linear_param: Dict[nn.Parameter, Tuple[SparseLinear, str]] = {} | |
| for _, m in named_sparse_linear_modules(model): | |
| self.linear_param[m.weight] = (m, "weight") | |
| if m.bias is not None: | |
| self.linear_param[m.bias] = (m, "bias") | |
| def zero_grad(self) -> None: | |
| for p in self.model.parameters(): | |
| p.grad = None | |
    @torch.no_grad()
    def step(self) -> None:
| for p in self.model.parameters(): | |
| if p.grad is None: | |
| continue | |
| if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param: | |
| continue | |
| if p not in self.state: | |
| self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)} | |
| m = self.state[p]["m"] | |
| v = self.state[p]["v"] | |
| g = p.grad | |
| if self.weight_decay: | |
| g = g.add(p, alpha=self.weight_decay) | |
| row_mask = None | |
| if self.masker is not None and p in self.linear_param: | |
| module, kind = self.linear_param[p] | |
| base = self.masker.row_mask_for(module) | |
| if base is not None: | |
| row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base | |
| if row_mask is None: | |
| m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1) | |
| v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2) | |
| p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr) | |
| else: | |
| # MPS can mis-handle expanded boolean masks for row-wise assignment | |
| # (e.g. reporting nonsense out-of-bounds indices). Use explicit | |
| # row indices and index_copy_ instead. This also avoids materializing | |
| # a full expanded mask for weight matrices. | |
| active_rows = row_mask.reshape(-1).nonzero(as_tuple=False).flatten() | |
| if active_rows.numel() == 0: | |
| continue | |
| m_rows = m.index_select(0, active_rows) | |
| v_rows = v.index_select(0, active_rows) | |
| g_rows = g.index_select(0, active_rows) | |
| new_m_rows = self.beta1 * m_rows + (1.0 - self.beta1) * g_rows | |
| new_v_rows = self.beta2 * v_rows + (1.0 - self.beta2) * g_rows * g_rows | |
| update_rows = new_m_rows / (torch.sqrt(new_v_rows) + self.eps) | |
| p_rows = p.index_select(0, active_rows) - self.lr * update_rows | |
| m.index_copy_(0, active_rows, new_m_rows) | |
| v.index_copy_(0, active_rows, new_v_rows) | |
| p.index_copy_(0, active_rows, p_rows) | |
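# Equivalence sketch for the row-indexed update above (illustrative, runs on CPU): updating
# selected rows via index_copy_ gives the same result as boolean-mask row assignment, without
# relying on the expanded boolean-mask indexing that the comment above flags as unreliable on MPS.
# The helper name `_check_rowwise_update_equivalence` is illustrative and not called in this file.
def _check_rowwise_update_equivalence(n_rows: int = 8, n_cols: int = 4) -> None:
    torch.manual_seed(0)
    p = torch.randn(n_rows, n_cols)
    g = torch.randn(n_rows, n_cols)
    mask = torch.rand(n_rows) > 0.5
    # Reference: boolean row-mask assignment.
    ref = p.clone()
    ref[mask] = ref[mask] - 0.1 * g[mask]
    # index_copy_ path, as used in MaskedAdam.step.
    out = p.clone()
    rows = mask.nonzero(as_tuple=False).flatten()
    out.index_copy_(0, rows, out.index_select(0, rows) - 0.1 * g.index_select(0, rows))
    assert torch.allclose(ref, out)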
| # ----------------------------- | |
| # Training utilities | |
| # ----------------------------- | |
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int, seed: int) -> Dict[str, float]:
| model.eval() | |
| configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None) | |
| out = {} | |
| for split in ["train", "val"]: | |
| losses = [] | |
| gen = make_cpu_generator(seed + (0 if split == "train" else 100000)) | |
| for _ in range(eval_iters): | |
| x, y = corpus.get_batch(split, batch_size, generator=gen) | |
| _, loss = model(x, y) | |
| losses.append(float(loss.item())) | |
| out[split] = sum(losses) / len(losses) | |
| model.train() | |
| return out | |
| def dense_audit_pass(model: nn.Module, corpus_batch: Tuple[torch.Tensor, torch.Tensor], opt: MaskedAdam, masker: RowMasker) -> torch.Tensor: | |
| x, y = corpus_batch | |
| configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None) | |
| opt.zero_grad() | |
| _, audit_loss = model(x, y) | |
| audit_loss.backward() | |
| mass = masker.current_gradient_mass_from_grads() | |
| opt.zero_grad() | |
| return mass | |
| def sparse_training_backward( | |
| model: nn.Module, | |
| corpus_batch: Tuple[torch.Tensor, torch.Tensor], | |
| opt: MaskedAdam, | |
| masker: Optional[RowMasker], | |
| backward_mode: Optional[BackwardMode], | |
| ) -> float: | |
| x, y = corpus_batch | |
| opt.zero_grad() | |
| if masker is None or backward_mode is None or backward_mode == "masked_optimizer": | |
| configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None) | |
| else: | |
| configure_sparse_linears(model, masker=masker, enabled=True, backward_mode=backward_mode) | |
| _, loss = model(x, y) | |
| loss.backward() | |
| configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None) | |
| return float(loss.item()) | |
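# One-step wiring sketch (illustrative; `gen` stands for a CPU generator as returned by
# make_cpu_generator): select rows, run the sparse training backward, then apply MaskedAdam.
# This mirrors what the post-warmup branch of train_run below does each step.
#
#     masker.choose_pre_backward(step)
#     batch = corpus.get_batch("train", args.batch_size, generator=gen)
#     loss = sparse_training_backward(model, batch, opt, masker, "sparse_dW_full_dX")
#     opt.step()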
| def train_run( | |
| corpus: CharCorpus, | |
| args: argparse.Namespace, | |
| policy: Optional[Policy], | |
| backward_mode: Optional[BackwardMode], | |
| active_fraction: float, | |
| warmup_steps: int, | |
| explore_fraction: float, | |
| seed_offset: int, | |
| ) -> Dict[str, float | str]: | |
| # Same model initialization and same minibatch sequence for every run by default. | |
| set_seed(args.seed + (seed_offset if args.unpaired_seeds else 0)) | |
| data_gen = make_cpu_generator(args.seed + 12345) | |
| dev = corpus.device | |
| model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev) | |
| masker = None | |
| if policy is not None: | |
| masker = RowMasker( | |
| model=model, | |
| policy=policy, | |
| active_fraction=active_fraction, | |
| explore_fraction=explore_fraction, | |
| mass_beta=args.mass_beta, | |
| unobserved_decay=args.unobserved_decay, | |
| warmup_steps=warmup_steps, | |
| ucb_alpha=args.ucb_alpha, | |
| mass_init=args.mass_init, | |
| device=dev, | |
| ) | |
| opt = MaskedAdam( | |
| model, | |
| masker, | |
| lr=args.lr, | |
| weight_decay=args.weight_decay, | |
| freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse, | |
| ) | |
| sums = { | |
| "cosine": 0.0, | |
| "norm_ratio": 0.0, | |
| "top20_mass": 0.0, | |
| "jacc_oracle": 0.0, | |
| "stability": 0.0, | |
| "active_fraction_real": 0.0, | |
| "coverage": 0.0, | |
| "avg_obs_count": 0.0, | |
| "new_active_fraction": 0.0, | |
| } | |
| counts = {k: 0 for k in sums} | |
| def add_metrics(metrics: Dict[str, float]) -> None: | |
| for k, v in metrics.items(): | |
| if k in sums: | |
| sums[k] += float(v) | |
| counts[k] += 1 | |
| for step in range(args.steps): | |
| batch = corpus.get_batch("train", args.batch_size, generator=data_gen) | |
| if masker is None: | |
| loss_value = sparse_training_backward(model, batch, opt, masker=None, backward_mode=None) | |
| opt.step() | |
| else: | |
| if step < warmup_steps: | |
| # Dense bootstrap. Every row is active and every row updates the predictor. | |
| masker._set_active(torch.ones(masker.n_blocks, dtype=torch.bool, device=dev)) | |
| loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode="masked_optimizer") | |
| full_mass = masker.current_gradient_mass_from_grads() | |
| masker.last_full_mass = full_mass.detach().clone() | |
| add_metrics(masker.audit_metrics_from_mass(full_mass)) | |
| add_metrics(masker.update_predictor_from_observed_mass(full_mass, observed=masker.active)) | |
| opt.step() | |
| else: | |
| masker.choose_pre_backward(step) | |
| if policy == "oracle_current": | |
| # Explicit upper bound. Oracle necessarily computes dense gradients to choose rows. | |
| full_mass = dense_audit_pass(model, batch, opt, masker) | |
| masker._set_active(masker._topk_mask(full_mass, active_fraction)) | |
| masker.last_full_mass = full_mass.detach().clone() | |
| add_metrics(masker.audit_metrics_from_mass(full_mass)) | |
| elif args.audit_every > 0 and ((step - warmup_steps) % args.audit_every == 0): | |
| # Measurement only. Do not update predicted_magnitude/ucb/random with this dense mass. | |
| full_mass = dense_audit_pass(model, batch, opt, masker) | |
| add_metrics(masker.audit_metrics_from_mass(full_mass)) | |
| if policy == "stale_current": | |
| masker.last_full_mass = full_mass.detach().clone() | |
| loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode=backward_mode) | |
| # Practical selector update: only active rows were observed by the training backward pass. | |
| observed_mass = masker.current_gradient_mass_from_grads() | |
| add_metrics(masker.update_predictor_from_observed_mass(observed_mass, observed=masker.active)) | |
| opt.step() | |
| if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1): | |
| losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 555) | |
| name = "dense" if policy is None else f"{policy}/{backward_mode}" | |
| print( | |
| f"{name:38s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} " | |
| f"loss={loss_value:.4f} train={losses['train']:.4f} val={losses['val']:.4f}" | |
| ) | |
| losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 999) | |
| row: Dict[str, float | str] = { | |
| "run": "dense_baseline" if policy is None else policy, | |
| "mode": "dense" if backward_mode is None else backward_mode, | |
| "target_active": 1.0 if policy is None else active_fraction, | |
| "warmup": warmup_steps, | |
| "explore": explore_fraction if policy is not None else 0.0, | |
| "train_loss": losses["train"], | |
| "val_loss": losses["val"], | |
| } | |
| if masker is None: | |
| row.update({ | |
| "cosine": float("nan"), | |
| "norm_ratio": float("nan"), | |
| "top20_mass": float("nan"), | |
| "jacc_oracle": float("nan"), | |
| "stability": float("nan"), | |
| "active_fraction_real": 1.0, | |
| "coverage": float("nan"), | |
| "avg_obs_count": float("nan"), | |
| "new_active_fraction": float("nan"), | |
| }) | |
| else: | |
| for k in sums: | |
| row[k] = (sums[k] / counts[k]) if counts[k] > 0 else float("nan") | |
| return row | |
| def print_summary(rows: List[Dict[str, float | str]]) -> None: | |
| print("\nSummary") | |
| header = ( | |
| f"{'run':>22s} {'mode':>19s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} " | |
| f"{'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} " | |
| f"{'stable':>7s} {'cover':>7s} {'new':>7s}" | |
| ) | |
| print(header) | |
| print("-" * len(header)) | |
| for r in rows: | |
| print( | |
| f"{str(r['run']):>22s} " | |
| f"{str(r['mode']):>19s} " | |
| f"{float(r['target_active']):7.3f} " | |
| f"{float(r['active_fraction_real']):7.3f} " | |
| f"{int(float(r['warmup'])):5d} " | |
| f"{float(r['explore']):5.2f} " | |
| f"{float(r['val_loss']):8.4f} " | |
| f"{float(r['train_loss']):8.4f} " | |
| f"{float(r['cosine']):7.3f} " | |
| f"{float(r['top20_mass']):7.3f} " | |
| f"{float(r['jacc_oracle']):7.3f} " | |
| f"{float(r['stability']):7.3f} " | |
| f"{float(r['coverage']):7.3f} " | |
| f"{float(r['new_active_fraction']):7.3f}" | |
| ) | |
| def parse_args() -> argparse.Namespace: | |
| p = argparse.ArgumentParser() | |
| p.add_argument("--text_path", type=str, default=None) | |
| p.add_argument("--synthetic_sentences", type=int, default=12000) | |
| p.add_argument("--steps", type=int, default=1000) | |
| p.add_argument("--quick", action="store_true") | |
| p.add_argument("--batch_size", type=int, default=32) | |
| p.add_argument("--block_size", type=int, default=64) | |
| p.add_argument("--n_layer", type=int, default=2) | |
| p.add_argument("--n_head", type=int, default=4) | |
| p.add_argument("--n_embd", type=int, default=64) | |
| p.add_argument("--dropout", type=float, default=0.0) | |
| p.add_argument("--lr", type=float, default=3e-4) | |
| p.add_argument("--weight_decay", type=float, default=0.0) | |
| p.add_argument("--active_fractions", type=float, nargs="+", default=[0.05, 0.02]) | |
| p.add_argument("--policies", type=str, nargs="+", default=["oracle_current", "predicted_magnitude", "random"]) | |
| p.add_argument( | |
| "--backward_modes", | |
| type=str, | |
| nargs="+", | |
| default=["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"], | |
| ) | |
| p.add_argument("--explore_fractions", type=float, nargs="+", default=[0.0]) | |
| p.add_argument("--warmup_steps_list", type=int, nargs="+", default=[5]) | |
| p.add_argument("--mass_beta", type=float, default=0.95) | |
| p.add_argument("--unobserved_decay", type=float, default=1.0) | |
| p.add_argument("--mass_init", type=float, default=0.0) | |
| p.add_argument("--ucb_alpha", type=float, default=1.0) | |
| p.add_argument("--freeze_non_linear_when_sparse", action="store_true") | |
| p.add_argument("--eval_interval", type=int, default=200) | |
| p.add_argument("--eval_iters", type=int, default=20) | |
| p.add_argument("--seed", type=int, default=7) | |
| p.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda", "mps"]) | |
| p.add_argument("--audit_every", type=int, default=0, help="Dense audit interval after warmup. 0 disables audits except oracle_current.") | |
| p.add_argument("--unpaired_seeds", action="store_true", help="Use different init seeds per run instead of paired seeds.") | |
| p.add_argument("--verbose", action="store_true") | |
| return p.parse_args() | |
| def main() -> None: | |
| args = parse_args() | |
| if args.quick: | |
| args.steps = 40 | |
| args.eval_iters = 2 | |
| args.batch_size = 8 | |
| args.block_size = 32 | |
| args.n_layer = 1 | |
| args.n_embd = 32 | |
| args.n_head = 4 | |
| args.synthetic_sentences = 1200 | |
| args.active_fractions = [0.05] | |
| args.policies = ["predicted_magnitude", "random"] | |
| args.backward_modes = ["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"] | |
| args.explore_fractions = [0.0] | |
| args.warmup_steps_list = [5] | |
| args.audit_every = 10 | |
| valid_policies = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"} | |
| valid_modes = {"masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"} | |
| for pol in args.policies: | |
| if pol not in valid_policies: | |
| raise ValueError(f"Unknown policy {pol!r}. Valid policies: {sorted(valid_policies)}") | |
| for mode in args.backward_modes: | |
| if mode not in valid_modes: | |
| raise ValueError(f"Unknown backward mode {mode!r}. Valid modes: {sorted(valid_modes)}") | |
| set_seed(args.seed) | |
| dev = args.device if args.device != "auto" else default_device() | |
| print(f"device={dev}") | |
| corpus = CharCorpus(load_text(args), args.block_size, dev) | |
| print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}") | |
| print(f"policies={args.policies}") | |
| print(f"backward_modes={args.backward_modes}") | |
| print(f"active_fractions={args.active_fractions}") | |
| print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}") | |
| print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}") | |
| print(f"paired_seeds={not args.unpaired_seeds}") | |
| print(f"audit_every={args.audit_every} (0 means no dense audit after warmup, except oracle_current)") | |
| tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev) | |
| total_params, linear_params, linear_frac = parameter_fractions(tmp_model) | |
| del tmp_model | |
| print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}") | |
| if args.freeze_non_linear_when_sparse: | |
| print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs") | |
| else: | |
| print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely") | |
| if args.dropout != 0.0: | |
| print("warning: dropout is nonzero; dense audit and sparse training passes may see different dropout masks") | |
| rows: List[Dict[str, float | str]] = [] | |
| print("\nRunning dense baseline") | |
| rows.append( | |
| train_run( | |
| corpus, | |
| args, | |
| policy=None, | |
| backward_mode=None, | |
| active_fraction=1.0, | |
| warmup_steps=0, | |
| explore_fraction=0.0, | |
| seed_offset=0, | |
| ) | |
| ) | |
| seed_offset = 100 | |
| for mode in args.backward_modes: | |
| for af in args.active_fractions: | |
| for pol in args.policies: | |
| explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0] | |
| for warmup in args.warmup_steps_list: | |
| for explore in explore_values: | |
| print( | |
| f"\nRunning mode={mode}, policy={pol}, " | |
| f"active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}" | |
| ) | |
| rows.append( | |
| train_run( | |
| corpus, | |
| args, | |
| policy=pol, # type: ignore[arg-type] | |
| backward_mode=mode, # type: ignore[arg-type] | |
| active_fraction=af, | |
| warmup_steps=warmup, | |
| explore_fraction=explore, | |
| seed_offset=seed_offset, | |
| ) | |
| ) | |
| seed_offset += 1 | |
| print_summary(rows) | |
| print("\nNotes") | |
| print(" masked_optimizer is the v7-style dense-backward simulation control.") | |
| print(" sparse_dW_full_dX uses custom Linear backward: sparse weight/bias grads, full input gradient.") | |
| print(" sparse_dW_sparse_dX uses custom Linear backward: sparse weight/bias grads and sparse input gradient.") | |
| print(" oracle_current uses dense audit gradients to choose rows; it is an upper bound.") | |
| print(" predicted_magnitude uses EMA mass from active/observed rows only.") | |
| print(" random is the sparse-support control.") | |
| print(" v9 does not compute dense audit gradients after warmup unless --audit_every > 0, except oracle_current.") | |
| print(" predicted_magnitude updates EMA statistics only from active rows observed by the training backward pass.") | |
| print(" cosine/top20/jacc are nan when --audit_every 0 because no dense reference gradient is computed.") | |
| print(" This is still not a wall-clock benchmark: PyTorch indexing may not accelerate on CPU/MPS without a custom Metal kernel.") | |
| if __name__ == "__main__": | |
| main() | |