Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
bc1b8eb | """ | |
| Sparse Transformer v15: Inactive-Update Prediction Diagnostics. | |
| Tests two simple ideas: | |
| 1. Correlated-neighbor prediction: | |
| Use active chunks as sensors. For each inactive chunk, find historically | |
| correlated active chunks and predict its update magnitude from them. | |
| 2. Graph / boundary interpolation: | |
| Treat chunks as nodes in a learned similarity graph. Active chunks are | |
| boundary values. Inactive chunk magnitudes are filled in by diffusion. | |
| This is intentionally a diagnostic script, not a speed benchmark. | |
| It computes dense gradients every step so we can measure whether inactive | |
| updates are predictable. | |
| Run: | |
| python3 sparse_transformer_v15_inactive_prediction.py --device mps --benchmark_sync | |
| Good first runs: | |
| python3 sparse_transformer_v15_inactive_prediction.py --device mps --steps 300 --n_embd 512 | |
| python3 sparse_transformer_v15_inactive_prediction.py --device mps --steps 300 --n_embd 1024 | |
| """ | |
| import argparse | |
| import math | |
| import random | |
| import time | |
| from typing import Dict, List, Literal, Optional, Tuple | |
| import torch | |
| torch.set_num_threads(1) | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# Names of the chunk-selection strategies accepted wherever a `Policy` is annotated.
Policy = Literal["predicted_magnitude", "random"]
def sync_device(device: str) -> None:
    """Block until work queued on ``device`` has finished, when that is possible.

    Used around the timed region so wall-clock measurements include
    asynchronously launched accelerator work.

    Args:
        device: Device string such as ``"cpu"``, ``"cuda"``, ``"cuda:1"`` or
            ``"mps"``. Only the part before an optional ``":index"`` decides
            which backend is synchronized, so indexed device strings are
            handled as well as bare ones.

    Returns:
        None. Unknown or unavailable backends are silently ignored.
    """
    device_type = str(device).split(":", 1)[0]
    if device_type == "cuda":
        if torch.cuda.is_available():
            # Passing the device keeps "cuda:1" from synchronizing device 0.
            torch.cuda.synchronize(torch.device(device))
    elif device_type == "mps":
        mps_module = getattr(torch, "mps", None)
        mps_backend = getattr(torch.backends, "mps", None)
        # Module presence alone does not mean an MPS device exists.
        if (
            mps_module is not None
            and mps_backend is not None
            and mps_backend.is_available()
        ):
            mps_module.synchronize()
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch's default generator."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a fresh CPU ``torch.Generator`` seeded with ``seed``."""
    return torch.Generator(device="cpu").manual_seed(seed)
| # ----------------------------- | |
| # Data | |
| # ----------------------------- | |
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a reproducible toy corpus of random-word sentences.

    Each sentence is 4 to 10 words drawn with replacement from a fixed
    15-word vocabulary, joined by single spaces and terminated with ".".
    Sentences are separated by newlines.

    Args:
        n_sentences: Number of sentences to generate.
        seed: Seed for the private ``random.Random`` instance; the global
            random state is not touched.

    Returns:
        The corpus as a single string.
    """
    rng = random.Random(seed)
    vocabulary = [
        "ada", "turing", "grace", "lovelace", "gradients",
        "tokens", "circuits", "features", "boldly", "strangely",
        "matrix", "attention", "kernel", "entropy", "signal",
    ]
    sentences = []
    for _ in range(n_sentences):
        # Draw the length first, then the words, to keep the RNG stream stable.
        length = rng.randint(4, 10)
        picked = rng.choices(vocabulary, k=length)
        sentences.append(" ".join(picked) + ".")
    return "\n".join(sentences)
class CharCorpus:
    """Character-level dataset with a 90/10 train/validation split.

    Attributes set by the constructor:
        stoi / itos: character <-> integer id lookup tables, ids assigned in
            sorted character order.
        vocab_size: number of distinct characters in ``text``.
        train_data / val_data: 1-D ``torch.long`` tensors holding the first
            90% and the remaining part of the encoded text.
    """

    def __init__(self, text: str, block_size: int, device: str):
        alphabet = sorted(set(text))
        self.stoi = {}
        self.itos = {}
        for index, symbol in enumerate(alphabet):
            self.stoi[symbol] = index
            self.itos[index] = symbol
        self.vocab_size = len(alphabet)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[symbol] for symbol in text], dtype=torch.long)
        split_at = int(0.9 * len(encoded))
        self.train_data = encoded[:split_at]
        self.val_data = encoded[split_at:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` random windows and their next-character targets.

        Args:
            split: ``"train"`` selects the training data; any other value
                selects the validation data.
            batch_size: Number of windows to draw.
            generator: Optional generator passed to ``torch.randint`` for
                reproducible sampling.

        Returns:
            ``(x, y)`` moved to ``self.device``; each row of ``y`` is the
            corresponding row of ``x`` shifted one position to the right in
            the source data.
        """
        source = self.train_data if split == "train" else self.val_data
        span = self.block_size
        starts = torch.randint(
            len(source) - span - 1, (batch_size,), generator=generator
        ).tolist()
        inputs = torch.stack([source[s : s + span] for s in starts])
        targets = torch.stack([source[s + 1 : s + span + 1] for s in starts])
        return inputs.to(self.device), targets.to(self.device)
| # ----------------------------- | |
| # Model | |
| # ----------------------------- | |
class SparseLinear(nn.Linear):
    """Marker subclass of ``nn.Linear`` kept under its historical name.

    It adds no behavior of its own: forward and backward are the ordinary
    dense ``nn.Linear`` ones. In this diagnostic script chunk masks are only
    applied analytically, after the dense gradients have been computed.
    """
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention restricted by a lower-triangular mask.

    Args:
        n_embd: Embedding width; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        block_size: Largest sequence length the registered mask covers.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        lower_triangle = torch.ones(block_size, block_size).tril()
        self.register_buffer(
            "mask", lower_triangle.view(1, 1, block_size, block_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply masked attention to a rank-3 input and return the same shape."""
        batch, seq_len, width = x.shape
        queries, keys, values = (
            part.view(batch, seq_len, self.n_head, self.head_dim).transpose(1, 2)
            for part in self.c_attn(x).split(width, dim=2)
        )
        scores = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Positions where the mask is zero lie in the future and are excluded.
        blocked = self.mask[:, :, :seq_len, :seq_len] == 0
        scores = scores.masked_fill(blocked, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        merged = (weights @ values).transpose(1, 2).contiguous()
        return self.c_proj(merged.view(batch, seq_len, width))
class FeedForward(nn.Module):
    """Two-layer MLP: expand to ``4 * n_embd``, GELU, project back, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden_width = 4 * n_embd
        self.c_fc = SparseLinear(n_embd, hidden_width)
        self.c_proj = SparseLinear(hidden_width, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the MLP output for ``x``."""
        hidden = F.gelu(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.dropout(projected)
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the attention branch, then the MLP branch, to ``x``."""
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.mlp(self.ln2(after_attention))
class MiniGPT(nn.Module):
    """Small GPT-style language model built from ``Block`` layers.

    Args:
        vocab_size: Number of token ids.
        block_size: Number of rows in the learned position-embedding table.
        n_layer: Number of transformer blocks.
        n_head: Attention heads per block.
        n_embd: Embedding width.
        dropout: Dropout probability passed to every block.
    """

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        layers = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: Rank-2 tensor of token ids.
            targets: Optional token ids to score the logits against.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.tok_emb(idx) + self.pos_emb(positions).unsqueeze(0)
        logits = self.lm_head(self.ln_f(self.blocks(hidden)))
        if targets is None:
            return logits, None
        flat_logits = logits.view(-1, logits.size(-1))
        return logits, F.cross_entropy(flat_logits, targets.view(-1))
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Collect every ``SparseLinear`` in ``model``, in ``model.modules()`` order."""
    found: List[SparseLinear] = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
| # ----------------------------- | |
| # Chunk geometry and diagnostics | |
| # ----------------------------- | |
class ChunkMap:
    """Index the output rows of every SparseLinear as fixed-size chunks.

    Each SparseLinear found in ``model`` is split along ``out_features`` into
    ``out_features // chunk_size`` chunks. Chunks receive consecutive global
    ids across layers, in the order ``get_sparse_linears`` returns the layers.

    The object also carries the running predictor state:
        predicted_mass: per-chunk EMA of observed gradient norms.
        direction_ema: per-chunk unit-norm EMA of observed gradient
            directions (``None`` until the chunk is first observed with a
            non-negligible gradient).
        mass_history: the most recent per-step mass vectors.
    """

    def __init__(self, model: nn.Module, chunk_size: int, device: str):
        self.model = model
        self.chunk_size = chunk_size
        self.device = device
        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        self.chunk_to_module_local: List[Tuple[nn.Module, int]] = []
        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, (
                f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
            )
            n_chunks = m.out_features // chunk_size
            ids = torch.arange(offset, offset + n_chunks, device=device)
            self.module_to_chunk_ids[m] = ids
            self.chunk_to_module_local.extend(
                (m, local_c) for local_c in range(n_chunks)
            )
            offset += n_chunks
        self.n_chunks = offset
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        self.direction_ema: List[Optional[torch.Tensor]] = [None] * self.n_chunks
        # Histories for correlation and graph similarities.
        self.mass_history: List[torch.Tensor] = []

    def choose_active(
        self,
        step: int,
        warmup_steps: int,
        active_fraction: float,
        policy: Policy,
    ) -> torch.Tensor:
        """Return a boolean mask over chunks marking the ones treated as active.

        Args:
            step: Current training step.
            warmup_steps: While ``step < warmup_steps`` every chunk is active.
            active_fraction: Fraction of chunks to activate after warm-up. At
                least one chunk is selected, and never more than exist.
            policy: ``"random"`` draws a uniform subset; any other value picks
                the chunks with the largest ``predicted_mass``.

        Returns:
            Bool tensor of length ``n_chunks`` on ``self.device``.
        """
        if step < warmup_steps:
            return torch.ones(self.n_chunks, dtype=torch.bool, device=self.device)
        # Clamp so active_fraction > 1 selects everything instead of making
        # torch.topk fail with k larger than the number of chunks.
        k = min(self.n_chunks, max(1, int(active_fraction * self.n_chunks)))
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
        if policy == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]
        else:
            # Tiny noise breaks ties between chunks with identical scores.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices
        mask[idx] = True
        return mask

    def chunk_gradient_vectors(self) -> List[torch.Tensor]:
        """Return one flat gradient vector per chunk, in global chunk order.

        Each vector is the chunk's weight-gradient rows followed, when the
        layer has a bias, by the chunk's bias-gradient entries. Missing
        gradients are represented by zeros.
        """
        vecs: List[torch.Tensor] = []
        for m, local_c in self.chunk_to_module_local:
            start = local_c * self.chunk_size
            end = start + self.chunk_size
            if m.weight.grad is None:
                parts = [torch.zeros_like(m.weight[start:end]).flatten()]
            else:
                parts = [m.weight.grad[start:end].detach().flatten()]
            if m.bias is not None:
                if m.bias.grad is None:
                    parts.append(torch.zeros_like(m.bias[start:end]).flatten())
                else:
                    parts.append(m.bias.grad[start:end].detach().flatten())
            vecs.append(torch.cat(parts))
        return vecs

    def chunk_masses_from_vecs(self, vecs: List[torch.Tensor]) -> torch.Tensor:
        """Return the norm of every chunk vector as a 1-D tensor on ``self.device``."""
        return torch.stack([v.norm() for v in vecs]).to(self.device)

    def update_predictor(
        self,
        active_mask: torch.Tensor,
        vecs: List[torch.Tensor],
        mass_beta: float = 0.95,
        dir_beta: float = 0.95,
        store_history: bool = True,
        max_history: int = 128,
    ) -> torch.Tensor:
        """Fold the chunks selected by ``active_mask`` into the running state.

        Args:
            active_mask: Bool tensor; only chunks marked True update
                ``predicted_mass`` and ``direction_ema``.
            vecs: Per-chunk flat gradient vectors.
            mass_beta: EMA decay for the magnitude estimate.
            dir_beta: EMA decay for the direction estimate.
            store_history: When True, append this step's masses (for all
                chunks) to ``mass_history``.
            max_history: Maximum number of history entries kept.

        Returns:
            The per-chunk masses computed from ``vecs``.
        """
        masses = self.chunk_masses_from_vecs(vecs)
        observed = active_mask
        # First observation should initialize directly, not get shrunk by beta.
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen
        self.predicted_mass[never_seen] = masses[never_seen]
        self.predicted_mass[already_seen] = (
            mass_beta * self.predicted_mass[already_seen]
            + (1.0 - mass_beta) * masses[already_seen]
        )
        for i, is_active in enumerate(observed.tolist()):
            if not is_active:
                continue
            v = vecs[i]
            n = v.norm()
            if n <= 1e-12:
                # A (near-)zero gradient carries no usable direction.
                continue
            unit = v / n
            previous = self.direction_ema[i]
            if previous is None:
                self.direction_ema[i] = unit.detach().clone()
            else:
                blended = dir_beta * previous + (1.0 - dir_beta) * unit
                # Re-normalize so the stored direction stays unit length.
                self.direction_ema[i] = blended / (blended.norm() + 1e-12)
        if store_history:
            self.mass_history.append(masses.detach().clone())
            excess = len(self.mass_history) - max_history
            if excess > 0:
                self.mass_history = self.mass_history[excess:]
        return masses

    def layer_aware_masks(self) -> List[torch.Tensor]:
        """Return one bool mask per layer marking that layer's global chunk ids."""
        masks = []
        for ids in self.module_to_chunk_ids.values():
            mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
            mask[ids] = True
            masks.append(mask)
        return masks
def dense_cosine_from_vecs(a: List[torch.Tensor], b: List[torch.Tensor]) -> float:
    """Cosine similarity between two tensor lists, each flattened and concatenated."""

    def _concatenate(tensors: List[torch.Tensor]) -> torch.Tensor:
        return torch.cat([t.flatten() for t in tensors])

    similarity = F.cosine_similarity(_concatenate(a), _concatenate(b), dim=0)
    return float(similarity.item())
def mse_reduction_vs_zero(true_vecs: List[torch.Tensor], pred_vecs: List[torch.Tensor], mask: torch.Tensor) -> float:
    """Relative MSE improvement of ``pred_vecs`` over predicting all zeros.

    Only chunks selected by ``mask`` are scored. The result is
    ``1 - mse(pred) / mse(zero)``: 1.0 for a perfect prediction, about 0.0
    for a zero prediction, negative when worse than zero.

    Returns:
        The reduction as a float, or NaN when ``mask`` selects no chunk.
    """
    selected = mask.nonzero(as_tuple=False).flatten().tolist()
    if len(selected) == 0:
        return float("nan")
    target = torch.cat([true_vecs[j].flatten() for j in selected])
    estimate = torch.cat([pred_vecs[j].flatten() for j in selected])
    baseline_error = target.square().mean()
    prediction_error = (target - estimate).square().mean()
    ratio = prediction_error / (baseline_error + 1e-12)
    return float((1.0 - ratio).item())
def active_only_prediction(true_vecs: List[torch.Tensor], active_mask: torch.Tensor) -> List[torch.Tensor]:
    """Keep active chunks' true gradients and predict zero for the rest.

    Args:
        true_vecs: Per-chunk flat gradient vectors.
        active_mask: Bool tensor with one flag per chunk.

    Returns:
        A new list: a clone of the true vector for each active chunk, and a
        zero tensor of matching shape for each inactive chunk.
    """
    # Convert the mask once; indexing it per chunk would force one
    # device-to-host transfer per element on accelerator tensors.
    flags = active_mask.tolist()
    return [
        vec.clone() if flags[i] else torch.zeros_like(vec)
        for i, vec in enumerate(true_vecs)
    ]
def ema_direction_prediction(
    cmap: ChunkMap,
    true_vecs: List[torch.Tensor],
    active_mask: torch.Tensor,
    inactive_magnitudes: torch.Tensor,
) -> List[torch.Tensor]:
    """Predict chunk gradients from stored directions and given magnitudes.

    Args:
        cmap: Object whose ``direction_ema[i]`` holds the stored direction
            for chunk ``i`` (or ``None`` when there is none yet).
        true_vecs: Per-chunk flat gradient vectors.
        active_mask: Bool tensor with one flag per chunk.
        inactive_magnitudes: Per-chunk magnitudes; only entries of inactive
            chunks are read.

    Returns:
        A new list with, per chunk: a clone of the true vector when active;
        otherwise ``direction * magnitude``, or zeros when no direction has
        been stored.
    """
    # Convert the mask once; indexing it per chunk would force one
    # device-to-host transfer per element on accelerator tensors.
    flags = active_mask.tolist()
    out = []
    for i, v in enumerate(true_vecs):
        if flags[i]:
            out.append(v.clone())
            continue
        direction = cmap.direction_ema[i]
        if direction is None:
            out.append(torch.zeros_like(v))
        else:
            out.append(direction.to(v.device, v.dtype) * inactive_magnitudes[i])
    return out
def build_mass_similarity(cmap: ChunkMap, min_history: int = 8) -> Optional[torch.Tensor]:
    """Build a non-negative, within-layer similarity matrix between chunks.

    The similarity is the correlation of the chunks' standardized mass
    histories, with negative values clamped to zero, the diagonal zeroed,
    and every pair of chunks from different layers zeroed.

    Args:
        cmap: Provides ``mass_history`` and ``layer_aware_masks()``.
        min_history: Minimum number of stored history entries required.

    Returns:
        A ``[chunks, chunks]`` tensor, or ``None`` when fewer than
        ``min_history`` entries are available.
    """
    if len(cmap.mass_history) < min_history:
        return None

    history = torch.stack(cmap.mass_history, dim=0)  # [history, chunks]
    centered = history - history.mean(dim=0, keepdim=True)
    standardized = centered / (centered.std(dim=0, keepdim=True) + 1e-6)
    denominator = max(1, standardized.shape[0] - 1)
    similarity = ((standardized.T @ standardized) / denominator).clamp(min=0.0)
    # A chunk must not count as its own neighbour.
    similarity.fill_diagonal_(0.0)

    # Keep only pairs inside the same layer, so unrelated layers do not mix.
    same_layer = torch.zeros_like(similarity, dtype=torch.bool)
    for layer_mask in cmap.layer_aware_masks():
        same_layer |= layer_mask.unsqueeze(1) & layer_mask.unsqueeze(0)
    return similarity.masked_fill(~same_layer, 0.0)
def knn_magnitude_prediction(
    cmap: ChunkMap,
    active_mask: torch.Tensor,
    true_masses: torch.Tensor,
    k_neighbors: int = 3,
) -> torch.Tensor:
    """Estimate inactive chunk magnitudes from their most similar active chunks.

    Active chunks keep their observed mass. Each inactive chunk gets the
    similarity-weighted mean of the masses of its ``k_neighbors`` most
    similar active chunks. It falls back to ``cmap.predicted_mass`` when it
    has no positive similarity to any active chunk, or when no similarity
    matrix is available yet.

    Returns:
        Per-chunk magnitude tensor with the same shape as ``true_masses``.
    """
    similarity = build_mass_similarity(cmap)
    if similarity is None:
        fallback = cmap.predicted_mass.clone()
        fallback[active_mask] = true_masses[active_mask]
        return fallback

    estimate = torch.zeros_like(true_masses)
    estimate[active_mask] = true_masses[active_mask]
    sensors = active_mask.nonzero(as_tuple=False).flatten()
    if sensors.numel() == 0:
        return estimate

    targets = (~active_mask).nonzero(as_tuple=False).flatten().tolist()
    for node in targets:
        row = similarity[node, sensors]
        if row.sum() <= 1e-12:
            # No correlated sensor: keep the chunk's own running estimate.
            estimate[node] = cmap.predicted_mass[node]
            continue
        n_top = min(k_neighbors, row.numel())
        best = torch.topk(row, k=n_top)
        neighbours = sensors[best.indices]
        weighted_sum = (best.values * true_masses[neighbours]).sum()
        estimate[node] = weighted_sum / (best.values.sum() + 1e-12)
    return estimate
def graph_diffusion_magnitude_prediction(
    cmap: ChunkMap,
    active_mask: torch.Tensor,
    true_masses: torch.Tensor,
    diffusion_steps: int = 8,
    alpha: float = 0.7,
) -> torch.Tensor:
    """Boundary-value style magnitude interpolation over a learned similarity graph.

    Active nodes are clamped to their observed true magnitudes. Inactive
    nodes start from ``cmap.predicted_mass`` and are repeatedly pulled toward
    the similarity-weighted mean of their neighbours.

    Nodes without any neighbour (an all-zero similarity row) keep their
    current value. Without this, their proposal would be zero and the prior
    would shrink by ``(1 - alpha)`` per step. This matches the fallback used
    by ``knn_magnitude_prediction``.

    Args:
        cmap: Provides ``predicted_mass`` and the data used to build the
            similarity graph.
        active_mask: Bool tensor marking observed chunks.
        true_masses: Per-chunk observed masses; only active entries are read.
        diffusion_steps: Number of diffusion iterations.
        alpha: Mixing weight of the neighbour proposal at each iteration.

    Returns:
        Per-chunk magnitude tensor. When no similarity graph is available yet
        the prior with active entries overwritten is returned unchanged;
        otherwise the diffused values are clamped to be non-negative.
    """
    S = build_mass_similarity(cmap)
    pred = cmap.predicted_mass.clone()
    pred[active_mask] = true_masses[active_mask]
    if S is None:
        return pred

    row_sums = S.sum(dim=1, keepdim=True)
    W = S / (row_sums + 1e-12)
    isolated = row_sums.squeeze(1) <= 1e-12
    for _ in range(diffusion_steps):
        # Isolated nodes propose their own value so they are left untouched.
        proposal = torch.where(isolated, pred, W @ pred)
        pred = alpha * proposal + (1.0 - alpha) * pred
        pred[active_mask] = true_masses[active_mask]
    return torch.clamp(pred, min=0.0)
| # ----------------------------- | |
| # Optimizer | |
| # ----------------------------- | |
class SimpleAdam:
    """Small Adam-like optimizer for diagnostics.

    This is intentionally simple and consistent across runs. It is not trying
    to be production AdamW: there is no bias correction and no weight decay,
    and the moment decay rates (0.9 / 0.999) and epsilon (1e-8) are fixed.

    Args:
        model: Module whose parameters are updated.
        lr: Step size.
    """

    def __init__(self, model: nn.Module, lr: float = 3e-4):
        self.model = model
        self.lr = lr
        # Per-parameter first ("m") and second ("v") moment estimates,
        # created lazily the first time a parameter has a gradient.
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}

    def zero_grad(self):
        """Drop every parameter's gradient by setting it to ``None``."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient.

        Runs under ``torch.no_grad()`` because the parameters are leaf
        tensors that require grad: updating them in place while autograd is
        recording raises a RuntimeError.
        """
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if p not in self.state:
                self.state[p] = {
                    "m": torch.zeros_like(p),
                    "v": torch.zeros_like(p),
                }
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            m.mul_(0.9).add_(p.grad, alpha=0.1)
            v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
            p.sub_(m / (torch.sqrt(v) + 1e-8), alpha=self.lr)
| # ----------------------------- | |
| # Apply chunk-gradient predictions | |
| # ----------------------------- | |
def install_chunk_prediction_as_grads(
    cmap: ChunkMap,
    pred_vecs: List[torch.Tensor],
):
    """Replace each mapped layer's weight/bias gradients with predicted chunks.

    For every layer in ``cmap.module_to_chunk_ids`` that has a weight
    gradient, the gradient is zeroed and refilled chunk by chunk from
    ``pred_vecs``. Each predicted vector is read as the chunk's weight rows
    followed by its bias entries. Layers without a weight gradient are
    skipped, and parameters outside the map are not touched.

    Args:
        cmap: Provides ``module_to_chunk_ids`` and ``chunk_size``.
        pred_vecs: Flat predicted vectors indexed by global chunk id.
    """
    rows = cmap.chunk_size
    for layer, chunk_ids in cmap.module_to_chunk_ids.items():
        weight_grad = layer.weight.grad
        if weight_grad is None:
            continue
        has_bias_grad = layer.bias is not None and layer.bias.grad is not None

        weight_grad.zero_()
        if has_bias_grad:
            layer.bias.grad.zero_()

        cols = layer.weight.shape[1]
        weight_len = rows * cols
        for local_index, chunk_id in enumerate(chunk_ids.tolist()):
            lo = local_index * rows
            hi = lo + rows
            flat = pred_vecs[chunk_id]
            weight_grad[lo:hi] = flat[:weight_len].view(rows, cols)
            if not has_bias_grad:
                continue
            bias_part = flat[weight_len:]
            if bias_part.numel() > 0:
                layer.bias.grad[lo:hi] = bias_part.view(rows)
| # ----------------------------- | |
| # Training / diagnostics | |
| # ----------------------------- | |
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Return the loss on one seeded validation batch.

    The model is put in eval mode for the forward pass and switched back to
    train mode afterwards. The batch is drawn with a CPU generator seeded by
    ``seed``, so the same seed always scores the same batch.
    """
    batch_generator = make_cpu_generator(seed)
    model.eval()
    with torch.no_grad():
        inputs, targets = corpus.get_batch("val", batch_size, generator=batch_generator)
        _, loss = model(inputs, targets)
    model.train()
    return float(loss.item())
def _mean_ignoring_nan(values: List[float]) -> float:
    """Average the non-NaN entries of ``values``; NaN when there are none."""
    finite = [v for v in values if not math.isnan(v)]
    return sum(finite) / len(finite) if finite else float("nan")


def run_experiment(
    mode: str,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    policy: Policy,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train a fresh MiniGPT with one gradient-prediction mode and summarize it.

    Every step computes the dense gradients, picks the active chunks, builds
    the predicted per-chunk gradients for ``mode``, installs them as the
    SparseLinear gradients, updates the predictor from the active chunks,
    and takes an optimizer step. Diagnostics are recorded every 25 steps.

    Args:
        mode: ``"dense"``, ``"active_only"``, ``"ema_inactive"``,
            ``"knn_magnitude"`` or ``"graph_diffusion"``.
        device: Device string for the model and batches.
        steps: Number of training steps.
        batch_size, block_size, n_layer, n_head, n_embd: Data and model sizes.
        chunk_size: Rows per chunk in every SparseLinear.
        active_fraction: Fraction of chunks observed after warm-up.
        warmup_steps: Steps during which every chunk is active and the true
            gradients are used regardless of ``mode``.
        policy: Chunk-selection policy passed to ``ChunkMap.choose_active``.
        benchmark_sync: Synchronize the device around the timed loop.

    Returns:
        Dict with ``"val"`` (final validation loss), ``"ms"`` (wall-clock
        milliseconds per step), ``"cos"`` (mean cosine between true and
        installed gradients over the recorded rows) and ``"mse_red"`` (mean
        inactive-chunk MSE reduction over the recorded rows that had at
        least one inactive chunk).

    Raises:
        ValueError: If ``mode`` is not recognized once warm-up is over.
    """
    set_seed(42)
    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)
    opt = SimpleAdam(model, lr=3e-4)
    cmap = ChunkMap(model, chunk_size=chunk_size, device=device)
    metric_rows = []
    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()
    for step in range(steps):
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))
        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()
        true_vecs = cmap.chunk_gradient_vectors()
        true_masses = cmap.chunk_masses_from_vecs(true_vecs)
        active_mask = cmap.choose_active(
            step=step,
            warmup_steps=warmup_steps,
            active_fraction=active_fraction,
            policy=policy,
        )
        if step < warmup_steps or mode == "dense":
            pred_vecs = [v.clone() for v in true_vecs]
        else:
            active_only_vecs = active_only_prediction(true_vecs, active_mask)
            if mode == "active_only":
                pred_vecs = active_only_vecs
            elif mode == "knn_magnitude":
                pred_masses = knn_magnitude_prediction(cmap, active_mask, true_masses)
                pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)
            elif mode == "graph_diffusion":
                pred_masses = graph_diffusion_magnitude_prediction(cmap, active_mask, true_masses)
                pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)
            elif mode == "ema_inactive":
                pred_masses = cmap.predicted_mass.clone()
                pred_masses[active_mask] = true_masses[active_mask]
                pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)
            else:
                raise ValueError(f"Unknown mode: {mode}")
        install_chunk_prediction_as_grads(cmap, pred_vecs)
        if step % 25 == 0:
            inactive_mask = ~active_mask
            row = {
                "cosine_full": dense_cosine_from_vecs(true_vecs, pred_vecs),
                "inactive_mse_reduction": mse_reduction_vs_zero(true_vecs, pred_vecs, inactive_mask),
                "active_frac": float(active_mask.float().mean().item()),
                "val": evaluate(model, corpus, batch_size, seed=999 + step),
            }
            metric_rows.append(row)
        # Update predictor after measuring and installing predicted grads.
        # Only chunks marked in active_mask update the mass/direction EMAs.
        cmap.update_predictor(active_mask, true_vecs, store_history=True)
        opt.step()
    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0
    val_loss = evaluate(model, corpus, batch_size, seed=12345)
    if metric_rows:
        avg_cos = sum(r["cosine_full"] for r in metric_rows) / len(metric_rows)
        # Rows recorded while every chunk is active (warm-up, e.g. step 0)
        # have no inactive chunk and report NaN; a plain mean would turn the
        # whole summary into NaN, so those rows are left out.
        avg_mse_red = _mean_ignoring_nan(
            [r["inactive_mse_reduction"] for r in metric_rows]
        )
    else:
        avg_cos = float("nan")
        avg_mse_red = float("nan")
    return {
        "val": val_loss,
        # max() guards the steps == 0 case against a division by zero.
        "ms": 1000.0 * elapsed / max(1, steps),
        "cos": avg_cos,
        "mse_red": avg_mse_red,
    }
def main():
    """Parse command-line options, run every mode once, and print a results table."""
    parser = argparse.ArgumentParser()
    leading_int_options = (
        ("--steps", 300),
        ("--batch_size", 8),
        ("--block_size", 128),
        ("--n_layer", 4),
        ("--n_head", 8),
        ("--n_embd", 512),
        ("--chunk_size", 64),
    )
    for flag, default in leading_int_options:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument("--policy", type=str, default="predicted_magnitude", choices=["predicted_magnitude", "random"])
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    modes = ["dense", "active_only", "ema_inactive", "knn_magnitude", "graph_diffusion"]

    print("\nInactive-update prediction diagnostic")
    print(f"device={args.device} steps={args.steps} d={args.n_embd} chunks={args.chunk_size}")
    print(f"active_fraction={args.active_fraction} warmup={args.warmup_steps} policy={args.policy}\n")
    header_cells = [
        f"{'mode':>18s}",
        f"{'val':>8s}",
        f"{'ms/step':>8s}",
        f"{'grad_cos':>8s}",
        f"{'inactive_mse+':>13s}",
    ]
    print(" | ".join(header_cells))
    print("-" * 70)

    # Every mode is run with the same settings; only `mode` changes.
    shared_settings = {
        name: getattr(args, name)
        for name in (
            "device",
            "steps",
            "batch_size",
            "block_size",
            "n_layer",
            "n_head",
            "n_embd",
            "chunk_size",
            "active_fraction",
            "warmup_steps",
            "policy",
            "benchmark_sync",
        )
    }
    for mode in modes:
        result = run_experiment(mode=mode, **shared_settings)
        row_cells = [
            f"{mode:>18s}",
            f"{result['val']:8.4f}",
            f"{result['ms']:8.2f}",
            f"{result['cos']:8.3f}",
            f"{result['mse_red']:13.3f}",
        ]
        print(" | ".join(row_cells))
# Run the diagnostic sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()