"""
Sparse Transformer v15: Inactive-Update Prediction Diagnostics.

Tests two simple ideas:

1. Correlated-neighbor prediction:
   Use active chunks as sensors. For each inactive chunk, find historically
   correlated active chunks and predict its update magnitude from them.

2. Graph / boundary interpolation:
   Treat chunks as nodes in a learned similarity graph. Active chunks are
   boundary values. Inactive chunk magnitudes are filled in by diffusion.

This is intentionally a diagnostic script, not a speed benchmark. It computes
dense gradients every step so we can measure whether inactive updates are
predictable.

Run:
    python3 sparse_transformer_v15_inactive_prediction.py --device mps --benchmark_sync

Good first runs:
    python3 sparse_transformer_v15_inactive_prediction.py --device mps --steps 300 --n_embd 512
    python3 sparse_transformer_v15_inactive_prediction.py --device mps --steps 300 --n_embd 1024
"""

import argparse
import math
import random
import time
from typing import Dict, Iterable, List, Literal, Optional, Tuple

import torch

torch.set_num_threads(1)

import torch.nn as nn
import torch.nn.functional as F

Policy = Literal["predicted_magnitude", "random"]


def sync_device(device: str) -> None:
    """Wait for queued work on *device* to finish; do nothing for other devices.

    Accepts plain backend names ("cuda", "mps") as well as indexed ones
    ("cuda:0"). The call is skipped when the backend is not actually usable on
    this machine, so passing an unavailable device never raises here.
    """
    if device.startswith("cuda"):
        if torch.cuda.is_available():
            torch.cuda.synchronize(torch.device(device))
    elif device.startswith("mps"):
        mps_backend = getattr(torch.backends, "mps", None)
        if (
            mps_backend is not None
            and mps_backend.is_available()
            and hasattr(torch, "mps")
        ):
            torch.mps.synchronize()


def set_seed(seed: int) -> None:
    """Seed Python's `random` module and torch's default generator."""
    random.seed(seed)
    torch.manual_seed(seed)


def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a fresh CPU `torch.Generator` seeded with *seed*."""
    gen = torch.Generator(device="cpu")
    gen.manual_seed(seed)
    return gen


# -----------------------------
# Data
# -----------------------------


def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a deterministic toy corpus of random-word sentences.

    Each sentence is 4 to 10 words drawn with replacement from a fixed
    vocabulary, terminated by "." and joined with newlines.
    """
    rng = random.Random(seed)
    words = [
        "ada",
        "turing",
        "grace",
        "lovelace",
        "gradients",
        "tokens",
        "circuits",
        "features",
        "boldly",
        "strangely",
        "matrix",
        "attention",
        "kernel",
        "entropy",
        "signal",
    ]
    return "\n".join(
        " ".join(rng.choices(words, k=rng.randint(4, 10))) + "."
        for _ in range(n_sentences)
    )


class CharCorpus:
    """Character-level dataset with a 90/10 train/validation split."""

    def __init__(self, text: str, block_size: int, device: str):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.device = device

        data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split_at = int(0.9 * len(data))
        self.train_data = data[:split_at]
        self.val_data = data[split_at:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample `batch_size` windows of length `block_size`.

        Returns `(x, y)` on `self.device`, where `y` is `x` shifted one
        character to the right. Any `split` other than "train" reads the
        validation data.
        """
        data = self.train_data if split == "train" else self.val_data
        ix = torch.randint(
            len(data) - self.block_size - 1, (batch_size,), generator=generator
        )
        x = torch.stack([data[i : i + self.block_size] for i in ix])
        y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix])
        return x.to(self.device), y.to(self.device)


# -----------------------------
# Model
# -----------------------------


class SparseLinear(nn.Linear):
    """Name retained for compatibility with earlier experiments.

    In this diagnostic script, backward is dense. We only use chunk masks
    analytically after gradients are computed.
    """


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(
                1, 1, block_size, block_size
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)

        # [B, T, C] -> [B, n_head, T, head_dim]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


class FeedForward(nn.Module):
    """Two-layer GELU MLP with a 4x hidden expansion."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))


class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class MiniGPT(nn.Module):
    """Small GPT-style language model with learned positional embeddings."""

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return `(logits, loss)`; `loss` is None when `targets` is None."""
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )
        return logits, loss


def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Return every `SparseLinear` in *model*, in `model.modules()` order."""
    return [m for m in model.modules() if isinstance(m, SparseLinear)]


# -----------------------------
# Chunk geometry and diagnostics
# -----------------------------


class ChunkMap:
    """Global indexing of output-row chunks across all `SparseLinear` layers.

    Each layer's `out_features` rows are split into consecutive chunks of
    `chunk_size` rows; chunks get a global id in layer order. The map also
    holds the predictor state: an EMA of each chunk's gradient norm
    (`predicted_mass`), an EMA of its unit gradient direction
    (`direction_ema`), and a bounded history of per-step chunk norms
    (`mass_history`).
    """

    def __init__(self, model: nn.Module, chunk_size: int, device: str):
        self.model = model
        self.chunk_size = chunk_size
        self.device = device
        self.linears = get_sparse_linears(model)

        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        self.chunk_to_module_local: List[Tuple[nn.Module, int]] = []

        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, (
                f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
            )
            n_chunks = m.out_features // chunk_size
            ids = torch.arange(offset, offset + n_chunks, device=device)
            self.module_to_chunk_ids[m] = ids
            for local_c in range(n_chunks):
                self.chunk_to_module_local.append((m, local_c))
            offset += n_chunks

        self.n_chunks = offset
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        self.direction_ema: List[Optional[torch.Tensor]] = [
            None for _ in range(self.n_chunks)
        ]

        # Histories for correlation and graph similarities.
        self.mass_history: List[torch.Tensor] = []

    def choose_active(
        self,
        step: int,
        warmup_steps: int,
        active_fraction: float,
        policy: Policy,
    ) -> torch.Tensor:
        """Return a boolean mask of shape `[n_chunks]` marking active chunks.

        During warmup (`step < warmup_steps`) every chunk is active. After
        that, `k = int(active_fraction * n_chunks)` chunks are selected,
        clamped to `[1, n_chunks]`: uniformly at random for the "random"
        policy, otherwise the `k` largest `predicted_mass` entries.
        """
        if step < warmup_steps:
            return torch.ones(self.n_chunks, dtype=torch.bool, device=self.device)

        # Clamp on both sides: at least one chunk, and never more than exist
        # (topk raises if k exceeds the number of elements).
        k = min(self.n_chunks, max(1, int(active_fraction * self.n_chunks)))
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)

        if policy == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]
        else:
            # Tiny random jitter breaks ties between equal scores.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices

        mask[idx] = True
        return mask

    @torch.no_grad()
    def chunk_gradient_vectors(self) -> List[torch.Tensor]:
        """Return one flat gradient vector per chunk, in global-id order.

        Each vector is the chunk's weight-gradient rows flattened, followed by
        its bias-gradient slice when the layer has a bias. Missing gradients
        are represented by zeros.
        """
        vecs: List[torch.Tensor] = []
        for m, local_c in self.chunk_to_module_local:
            start = local_c * self.chunk_size
            end = (local_c + 1) * self.chunk_size

            parts = []
            if m.weight.grad is None:
                parts.append(torch.zeros_like(m.weight[start:end]).flatten())
            else:
                parts.append(m.weight.grad[start:end].detach().flatten())

            if m.bias is not None:
                if m.bias.grad is None:
                    parts.append(torch.zeros_like(m.bias[start:end]).flatten())
                else:
                    parts.append(m.bias.grad[start:end].detach().flatten())

            vecs.append(torch.cat(parts))
        return vecs

    @torch.no_grad()
    def chunk_masses_from_vecs(self, vecs: List[torch.Tensor]) -> torch.Tensor:
        """Return the L2 norm of every chunk vector as one tensor on `self.device`."""
        return torch.stack([v.norm() for v in vecs]).to(self.device)

    @torch.no_grad()
    def update_predictor(
        self,
        active_mask: torch.Tensor,
        vecs: List[torch.Tensor],
        mass_beta: float = 0.95,
        dir_beta: float = 0.95,
        store_history: bool = True,
    ) -> torch.Tensor:
        """Update the mass and direction EMAs from the active chunks.

        Only chunks marked in `active_mask` update `predicted_mass` and
        `direction_ema`. When `store_history` is true, the norms of ALL chunks
        (active and inactive) are appended to `mass_history`, which keeps at
        most the 128 most recent entries.

        Returns the per-chunk norms computed from `vecs`.
        """
        masses = self.chunk_masses_from_vecs(vecs)
        observed = active_mask

        # First observation should initialize directly, not get shrunk by beta.
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen

        self.predicted_mass[never_seen] = masses[never_seen]
        self.predicted_mass[already_seen] = (
            mass_beta * self.predicted_mass[already_seen]
            + (1.0 - mass_beta) * masses[already_seen]
        )

        for i, is_active in enumerate(observed.tolist()):
            if not is_active:
                continue
            v = vecs[i]
            n = v.norm()
            if n <= 1e-12:
                # A (near-)zero gradient has no usable direction.
                continue
            unit = v / n
            if self.direction_ema[i] is None:
                self.direction_ema[i] = unit.detach().clone()
            else:
                self.direction_ema[i] = (
                    dir_beta * self.direction_ema[i] + (1.0 - dir_beta) * unit
                )
                # Renormalize so the stored direction stays unit length.
                self.direction_ema[i] = self.direction_ema[i] / (
                    self.direction_ema[i].norm() + 1e-12
                )

        if store_history:
            self.mass_history.append(masses.detach().clone())
            max_hist = 128
            if len(self.mass_history) > max_hist:
                self.mass_history = self.mass_history[-max_hist:]

        return masses

    def layer_aware_masks(self) -> List[torch.Tensor]:
        """Return one `[n_chunks]` boolean mask per layer marking its chunks."""
        masks = []
        for ids in self.module_to_chunk_ids.values():
            mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
            mask[ids] = True
            masks.append(mask)
        return masks


def dense_cosine_from_vecs(a: List[torch.Tensor], b: List[torch.Tensor]) -> float:
    """Cosine similarity between the concatenations of two chunk-vector lists."""
    va = torch.cat([x.flatten() for x in a])
    vb = torch.cat([x.flatten() for x in b])
    return float(F.cosine_similarity(va, vb, dim=0).item())


def mse_reduction_vs_zero(
    true_vecs: List[torch.Tensor],
    pred_vecs: List[torch.Tensor],
    mask: torch.Tensor,
) -> float:
    """Relative MSE improvement of `pred_vecs` over predicting all zeros.

    Computed over the chunks selected by `mask` as
    `1 - mse(pred) / mse(zero)`: 1.0 is a perfect prediction, 0.0 is no better
    than zeros, negative is worse. Returns NaN when `mask` selects nothing.
    """
    idxs = torch.nonzero(mask, as_tuple=False).flatten().tolist()
    if not idxs:
        return float("nan")

    true = torch.cat([true_vecs[i].flatten() for i in idxs])
    pred = torch.cat([pred_vecs[i].flatten() for i in idxs])
    zero_mse = torch.mean(true.square())
    pred_mse = torch.mean((true - pred).square())
    return float((1.0 - pred_mse / (zero_mse + 1e-12)).item())


def finite_mean(values: Iterable[float]) -> float:
    """Mean of *values* with NaN entries skipped; NaN when nothing remains."""
    kept = [v for v in values if not math.isnan(v)]
    if not kept:
        return float("nan")
    return sum(kept) / len(kept)


def active_only_prediction(
    true_vecs: List[torch.Tensor], active_mask: torch.Tensor
) -> List[torch.Tensor]:
    """Keep true gradients for active chunks and zero out inactive ones."""
    # One host transfer instead of one device read per chunk.
    flags = active_mask.tolist()
    return [
        v.clone() if is_active else torch.zeros_like(v)
        for v, is_active in zip(true_vecs, flags)
    ]


def ema_direction_prediction(
    cmap: ChunkMap,
    true_vecs: List[torch.Tensor],
    active_mask: torch.Tensor,
    inactive_magnitudes: torch.Tensor,
) -> List[torch.Tensor]:
    """Predict inactive chunks as `direction_ema * predicted magnitude`.

    Active chunks keep their true gradient. An inactive chunk with no recorded
    direction is predicted as zeros.
    """
    # One host transfer instead of one device read per chunk.
    flags = active_mask.tolist()
    out = []
    for i, v in enumerate(true_vecs):
        if flags[i]:
            out.append(v.clone())
            continue
        direction = cmap.direction_ema[i]
        if direction is None:
            out.append(torch.zeros_like(v))
        else:
            out.append(direction.to(v.device, v.dtype) * inactive_magnitudes[i])
    return out


def build_mass_similarity(cmap: ChunkMap, min_history: int = 8) -> Optional[torch.Tensor]:
    """Build a `[chunks, chunks]` similarity matrix from the mass history.

    Entries are the correlation of standardized per-chunk norm histories,
    clamped to be non-negative, with the diagonal and all cross-layer pairs
    set to zero. Returns None until `min_history` steps have been recorded.
    """
    if len(cmap.mass_history) < min_history:
        return None

    H = torch.stack(cmap.mass_history, dim=0)  # [history, chunks]
    H = H - H.mean(dim=0, keepdim=True)
    H = H / (H.std(dim=0, keepdim=True) + 1e-6)
    S = (H.T @ H) / max(1, H.shape[0] - 1)
    S = torch.clamp(S, min=0.0)

    # Remove self similarity.
    S.fill_diagonal_(0.0)

    # Layer-aware block diagonal: avoid mixing unrelated layers by default.
    layer_masks = cmap.layer_aware_masks()
    layer_allowed = torch.zeros_like(S, dtype=torch.bool)
    for mask in layer_masks:
        layer_allowed |= mask[:, None] & mask[None, :]
    S = torch.where(layer_allowed, S, torch.zeros_like(S))
    return S


def knn_magnitude_prediction(
    cmap: ChunkMap,
    active_mask: torch.Tensor,
    true_masses: torch.Tensor,
    k_neighbors: int = 3,
) -> torch.Tensor:
    """Predict inactive magnitudes as weighted average of correlated active magnitudes.

    Active chunks always take their true mass. An inactive chunk uses its
    `k_neighbors` most similar active chunks, falling back to
    `cmap.predicted_mass` when it has no positively correlated active
    neighbor or when no similarity matrix is available yet.
    """
    S = build_mass_similarity(cmap)
    if S is None:
        pred = cmap.predicted_mass.clone()
        pred[active_mask] = true_masses[active_mask]
        return pred

    pred = torch.zeros_like(true_masses)
    pred[active_mask] = true_masses[active_mask]

    active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
    inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()
    if active_idx.numel() == 0:
        return pred

    for i in inactive_idx.tolist():
        weights = S[i, active_idx]
        if weights.sum() <= 1e-12:
            pred[i] = cmap.predicted_mass[i]
            continue
        kk = min(k_neighbors, weights.numel())
        top = torch.topk(weights, k=kk)
        w = top.values
        aidx = active_idx[top.indices]
        pred[i] = (w * true_masses[aidx]).sum() / (w.sum() + 1e-12)

    return pred


def graph_diffusion_magnitude_prediction(
    cmap: ChunkMap,
    active_mask: torch.Tensor,
    true_masses: torch.Tensor,
    diffusion_steps: int = 8,
    alpha: float = 0.7,
) -> torch.Tensor:
    """Boundary-value style magnitude interpolation over a learned similarity graph.

    Active nodes are clamped to observed true magnitudes.
    Inactive nodes diffuse toward graph-neighbor values.
    """
    S = build_mass_similarity(cmap)
    if S is None:
        pred = cmap.predicted_mass.clone()
        pred[active_mask] = true_masses[active_mask]
        return pred

    # Row-normalize so each step is a weighted average over neighbors.
    W = S / (S.sum(dim=1, keepdim=True) + 1e-12)
    pred = cmap.predicted_mass.clone()
    pred[active_mask] = true_masses[active_mask]

    for _ in range(diffusion_steps):
        proposal = W @ pred
        pred = alpha * proposal + (1.0 - alpha) * pred
        # Re-impose the boundary condition after every step.
        pred[active_mask] = true_masses[active_mask]

    return torch.clamp(pred, min=0.0)


# -----------------------------
# Optimizer
# -----------------------------


class SimpleAdam:
    """Small Adam-like optimizer for diagnostics.

    This is intentionally simple and consistent across runs.
    It is not trying to be production AdamW. The moment estimates are used
    without bias correction and there is no weight decay.
    """

    def __init__(self, model: nn.Module, lr: float = 3e-4):
        self.model = model
        self.lr = lr
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}

    def zero_grad(self):
        """Drop all parameter gradients by setting them to None."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if p not in self.state:
                self.state[p] = {
                    "m": torch.zeros_like(p),
                    "v": torch.zeros_like(p),
                }
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            m.mul_(0.9).add_(p.grad, alpha=0.1)
            v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
            p.sub_(m / (torch.sqrt(v) + 1e-8), alpha=self.lr)


# -----------------------------
# Apply chunk-gradient predictions
# -----------------------------


@torch.no_grad()
def install_chunk_prediction_as_grads(
    cmap: ChunkMap,
    pred_vecs: List[torch.Tensor],
):
    """Overwrite SparseLinear weight/bias grads from predicted chunk vectors.

    Non-SparseLinear parameters keep their dense gradients. Layers whose
    weight has no gradient are skipped entirely.
    """
    for m, ids in cmap.module_to_chunk_ids.items():
        if m.weight.grad is None:
            continue
        m.weight.grad.zero_()
        has_bias_grad = m.bias is not None and m.bias.grad is not None
        if has_bias_grad:
            m.bias.grad.zero_()

        in_features = m.weight.shape[1]
        w_numel = cmap.chunk_size * in_features
        for local_c, global_id in enumerate(ids.tolist()):
            start = local_c * cmap.chunk_size
            end = (local_c + 1) * cmap.chunk_size
            v = pred_vecs[global_id]

            # Each chunk vector is laid out as [weight rows..., bias slice...].
            w_flat = v[:w_numel]
            m.weight.grad[start:end] = w_flat.view(cmap.chunk_size, in_features)

            if has_bias_grad:
                b_flat = v[w_numel:]
                if b_flat.numel() > 0:
                    m.bias.grad[start:end] = b_flat.view(cmap.chunk_size)


# -----------------------------
# Training / diagnostics
# -----------------------------


def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Return the loss on one seeded validation batch; leaves the model in train mode."""
    model.eval()
    with torch.no_grad():
        x, y = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
        _, loss = model(x, y)
    model.train()
    return float(loss.item())


def run_experiment(
    mode: str,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    policy: Policy,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train a fresh MiniGPT for `steps` steps under one prediction mode.

    Dense gradients are computed every step. After warmup, the SparseLinear
    gradients are replaced by the mode's prediction before the optimizer step
    ("dense" keeps the true gradients). Diagnostics are sampled every 25
    steps.

    Returns a dict with:
        "val":     final validation loss,
        "ms":      mean wall-clock milliseconds per step (includes the
                   periodic evaluation),
        "cos":     mean cosine similarity between true and installed chunk
                   gradients over the sampled steps,
        "mse_red": mean inactive-chunk MSE reduction over the sampled steps
                   that had at least one inactive chunk.

    Raises:
        ValueError: if `mode` is not a known mode (raised on the first
            post-warmup step).
    """
    set_seed(42)
    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)
    opt = SimpleAdam(model, lr=3e-4)
    cmap = ChunkMap(model, chunk_size=chunk_size, device=device)

    metric_rows = []

    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()

    for step in range(steps):
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))
        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        true_vecs = cmap.chunk_gradient_vectors()
        true_masses = cmap.chunk_masses_from_vecs(true_vecs)
        active_mask = cmap.choose_active(
            step=step,
            warmup_steps=warmup_steps,
            active_fraction=active_fraction,
            policy=policy,
        )

        if step < warmup_steps or mode == "dense":
            pred_vecs = [v.clone() for v in true_vecs]
        elif mode == "active_only":
            pred_vecs = active_only_prediction(true_vecs, active_mask)
        elif mode == "knn_magnitude":
            pred_masses = knn_magnitude_prediction(cmap, active_mask, true_masses)
            pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)
        elif mode == "graph_diffusion":
            pred_masses = graph_diffusion_magnitude_prediction(cmap, active_mask, true_masses)
            pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)
        elif mode == "ema_inactive":
            pred_masses = cmap.predicted_mass.clone()
            pred_masses[active_mask] = true_masses[active_mask]
            pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)
        else:
            raise ValueError(f"Unknown mode: {mode}")

        install_chunk_prediction_as_grads(cmap, pred_vecs)

        if step % 25 == 0:
            inactive_mask = ~active_mask
            row = {
                "cosine_full": dense_cosine_from_vecs(true_vecs, pred_vecs),
                "inactive_mse_reduction": mse_reduction_vs_zero(
                    true_vecs, pred_vecs, inactive_mask
                ),
                "active_frac": float(active_mask.float().mean().item()),
                "val": evaluate(model, corpus, batch_size, seed=999 + step),
            }
            metric_rows.append(row)

        # Update predictor after measuring and installing predicted grads.
        # The mass/direction EMAs only see active chunks, mimicking sparse
        # observation; the mass history records the dense norms.
        cmap.update_predictor(active_mask, true_vecs, store_history=True)
        opt.step()

    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0

    val_loss = evaluate(model, corpus, batch_size, seed=12345)

    # Sampled warmup steps have no inactive chunks and report NaN for the
    # inactive-MSE metric; skip those rows so they do not poison the average.
    avg_cos = finite_mean(r["cosine_full"] for r in metric_rows)
    avg_mse_red = finite_mean(r["inactive_mse_reduction"] for r in metric_rows)

    return {
        "val": val_loss,
        "ms": 1000.0 * elapsed / max(1, steps),
        "cos": avg_cos,
        "mse_red": avg_mse_red,
    }


def main():
    """Parse CLI arguments, run every mode, and print one result row per mode."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=300)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument(
        "--policy",
        type=str,
        default="predicted_magnitude",
        choices=["predicted_magnitude", "random"],
    )
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    modes = [
        "dense",
        "active_only",
        "ema_inactive",
        "knn_magnitude",
        "graph_diffusion",
    ]

    print(f"\nInactive-update prediction diagnostic")
    print(f"device={args.device} steps={args.steps} d={args.n_embd} chunks={args.chunk_size}")
    print(f"active_fraction={args.active_fraction} warmup={args.warmup_steps} policy={args.policy}\n")
    print(f"{'mode':>18s} | {'val':>8s} | {'ms/step':>8s} | {'grad_cos':>8s} | {'inactive_mse+':>13s}")
    print("-" * 70)

    for mode in modes:
        result = run_experiment(
            mode=mode,
            device=args.device,
            steps=args.steps,
            batch_size=args.batch_size,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
            n_embd=args.n_embd,
            chunk_size=args.chunk_size,
            active_fraction=args.active_fraction,
            warmup_steps=args.warmup_steps,
            policy=args.policy,
            benchmark_sync=args.benchmark_sync,
        )
        print(
            f"{mode:>18s} | "
            f"{result['val']:8.4f} | "
            f"{result['ms']:8.2f} | "
            f"{result['cos']:8.3f} | "
            f"{result['mse_red']:13.3f}"
        )


if __name__ == "__main__":
    main()