Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
"""
Sparse Transformer v16: Sensor-Based Mask Scheduling.

v15 showed that directly hallucinating gradient vectors for inactive chunks was harmful.
v16 tests a safer follow-up idea:

    Use active chunks as sensors to choose which chunks receive real gradients next.

No inactive gradient is invented: in sparse modes, inactive chunks get a zero gradient.
The only question is whether observations from the active chunks improve the selection
of future masks. (Note: SimpleAdam's momentum can still move the parameters of a chunk
whose current gradient has been zeroed.)

Schedulers:
    dense
        Dense baseline.
    ema_topk
        Select the top chunks by each chunk's own EMA of gradient mass.
    knn_scheduler
        Use active chunks as sensors. Predict the next-step mass of each inactive chunk
        from historically correlated active chunks, then select the next mask from
        those scores.
    graph_scheduler
        Boundary-value-style magnitude diffusion over a chunk similarity graph.
        Active chunks are clamped to their observed magnitudes; inactive magnitudes
        are interpolated and used to choose the next mask.
    random
        Random sparse-support control.

This is still a diagnostic/simulation script: it computes dense gradients so that it
can measure oracle Jaccard/cosine, then keeps only the selected active chunks'
gradients for sparse training.

Run:
    python3 sparse_transformer_v16_sensor_scheduler.py --device mps --benchmark_sync

Useful:
    python3 sparse_transformer_v16_sensor_scheduler.py --device mps --steps 500 --n_embd 512
    python3 sparse_transformer_v16_sensor_scheduler.py --device mps --steps 500 --n_embd 1024
"""
| from __future__ import annotations | |
| import argparse | |
| import math | |
| import random | |
| import time | |
| from typing import Dict, List, Literal, Optional, Tuple | |
| import torch | |
| torch.set_num_threads(1) | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# Names accepted by ChunkScheduler / run_experiment for the mask-selection policy.
Scheduler = Literal[
    "dense",
    "ema_topk",
    "knn_scheduler",
    "graph_scheduler",
    "random",
]
def sync_device(device: str) -> None:
    """Wait for queued kernels on ``device`` to finish; no-op for other devices.

    Accepts plain device types ("cuda", "mps", "cpu") as well as indexed strings
    such as "cuda:1". Synchronization only happens when the backend is actually
    available, so asking to sync an absent accelerator is harmless.
    """
    kind = str(device).split(":", 1)[0]
    if kind == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize(device)
    elif kind == "mps" and hasattr(torch, "mps") and torch.backends.mps.is_available():
        # Merely having the torch.mps attribute does not mean the backend is usable.
        torch.mps.synchronize()
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module, then PyTorch's global RNG, with ``seed``."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a fresh CPU ``torch.Generator`` seeded with ``seed``.

    Using an independent generator keeps batch sampling off the global RNG.
    """
    return torch.Generator(device="cpu").manual_seed(seed)
| # ----------------------------- | |
| # Data | |
| # ----------------------------- | |
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a deterministic toy corpus of newline-separated "sentences".

    Each sentence is 4-10 words drawn with replacement from a fixed vocabulary,
    space-joined and terminated by a period. The same ``seed`` always yields the
    same text.
    """
    rng = random.Random(seed)
    vocabulary = [
        "ada", "turing", "grace", "lovelace", "gradients",
        "tokens", "circuits", "features", "boldly", "strangely",
        "matrix", "attention", "kernel", "entropy", "signal",
    ]
    sentences = []
    for _ in range(n_sentences):
        # Draw the length first, then the words: same RNG call order as before.
        length = rng.randint(4, 10)
        picked = rng.choices(vocabulary, k=length)
        sentences.append(" ".join(picked) + ".")
    return "\n".join(sentences)
class CharCorpus:
    """Character-level corpus with a 90/10 train/val split and random-window batching.

    Attributes:
        stoi / itos: char -> id and id -> char maps over the sorted set of characters.
        vocab_size: number of distinct characters.
        train_data / val_data: first 90% / last 10% of the encoded text (1-D long tensors).
    """

    def __init__(self, text: str, block_size: int, device: str):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.device = device
        data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split_at = int(0.9 * len(data))
        self.train_data = data[:split_at]
        self.val_data = data[split_at:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` random windows from ``split`` ("train" or "val").

        Returns ``(x, y)``, both (batch_size, block_size) long tensors on
        ``self.device``, where ``y`` is ``x`` shifted one character to the right.

        Raises:
            ValueError: unknown ``split``, or the split is too short to hold one
                input window plus its next-character target.
        """
        if split == "train":
            data = self.train_data
        elif split == "val":
            data = self.val_data
        else:
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        # A start index i is valid while i + block_size + 1 <= len(data),
        # i.e. there are len(data) - block_size valid windows.
        n_windows = len(data) - self.block_size
        if n_windows < 1:
            raise ValueError(
                f"{split} split has {len(data)} tokens; need at least "
                f"block_size + 1 = {self.block_size + 1}"
            )
        ix = torch.randint(n_windows, (batch_size,), generator=generator)
        rows = ix[:, None] + torch.arange(self.block_size)
        x = data[rows]
        y = data[rows + 1]
        return x.to(self.device), y.to(self.device)
| # ----------------------------- | |
| # Model | |
| # ----------------------------- | |
class SparseLinear(nn.Linear):
    """Plain ``nn.Linear`` subclass used purely as a type tag.

    It adds no behaviour. ``get_sparse_linears`` locates layers by this type, and
    ChunkScheduler splits their output rows into gradient chunks.
    """
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with SparseLinear QKV and output projections.

    ``mask`` is a lower-triangular buffer of shape (1, 1, block_size, block_size):
    position t may only attend to positions <= t.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Registration order (c_attn, then c_proj) determines module iteration order,
        # which get_sparse_linears / ChunkScheduler use to number chunks.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        lower = torch.ones(block_size, block_size).tril()
        self.register_buffer("mask", lower.view(1, 1, block_size, block_size))

    def _to_heads(self, t: torch.Tensor, batch: int, seq: int) -> torch.Tensor:
        """Reshape (batch, seq, n_embd) -> (batch, n_head, seq, head_dim)."""
        return t.view(batch, seq, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq, width = x.shape
        query, key, value = (
            self._to_heads(part, batch, seq) for part in self.c_attn(x).split(width, dim=2)
        )
        logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        future = self.mask[:, :, :seq, :seq] == 0
        probs = self.dropout(F.softmax(logits.masked_fill(future, float("-inf")), dim=-1))
        merged = torch.matmul(probs, value).transpose(1, 2).contiguous().view(batch, seq, width)
        return self.c_proj(merged)
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, GELU, project back, then dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden = 4 * n_embd
        self.c_fc = SparseLinear(n_embd, hidden)
        self.c_proj = SparseLinear(hidden, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.gelu(self.c_fc(x))
        projected = self.c_proj(activated)
        return self.dropout(projected)
class Block(nn.Module):
    """Pre-LayerNorm transformer block: x += attn(ln1(x)); then x += mlp(ln2(x))."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        # Construction order is kept (ln1, attn, ln2, mlp) so parameter init and
        # module iteration order are unchanged.
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for norm, sublayer in ((self.ln1, self.attn), (self.ln2, self.mlp)):
            x = x + sublayer(norm(x))
        return x
class MiniGPT(nn.Module):
    """Small GPT-style character model: token + learned position embeddings,
    ``n_layer`` transformer Blocks, final LayerNorm and a linear LM head."""

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return ``(logits, loss)``.

        ``idx`` is a 2-D (batch, T) tensor of token ids. ``logits`` has shape
        (batch, T, vocab_size). ``loss`` is the mean cross-entropy against
        ``targets``, or None when ``targets`` is None.

        Raises:
            ValueError: if T exceeds ``block_size`` (the position table and the
                attention mask only cover ``block_size`` positions).
        """
        _, T = idx.shape
        if T > self.block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {self.block_size}")
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted too.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Collect every SparseLinear inside ``model``, in ``model.modules()`` order.

    ChunkScheduler assigns global chunk ids by walking this list in order.
    """
    found: List[SparseLinear] = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
| # ----------------------------- | |
| # Chunk map and scheduler | |
| # ----------------------------- | |
class ChunkScheduler:
    """Split every SparseLinear's output rows into fixed-size chunks and pick which
    chunks receive gradients each step.

    Chunks are numbered globally in ``get_sparse_linears(model)`` order. Each module
    owns a contiguous id range (``module_to_chunk_ids``), and ``chunk_to_module_local``
    maps a global id back to ``(module, local_chunk_index)``. Local chunk ``c`` covers
    weight rows (and bias entries) ``[c * chunk_size, (c + 1) * chunk_size)``.

    State:
        predicted_mass: per-chunk EMA of observed gradient norms. A value of 0 is
            treated as "never observed", so the first observation is copied in.
        mass_history: dense per-chunk mass vectors recorded during warmup only
            (last 128 kept).
        similarity: within-module, non-negative correlation matrix built from
            mass_history once it has >= 8 rows; None before that.
        next_scores: scores the knn/graph schedulers use to choose the next mask.
        current_mask / prev_mask: bool masks over all chunks.
    """

    def __init__(
        self,
        model: nn.Module,
        chunk_size: int,
        active_fraction: float,
        device: str,
        scheduler: Scheduler,
        mass_beta: float = 0.95,
    ):
        self.model = model
        self.chunk_size = chunk_size
        self.active_fraction = active_fraction
        self.device = device
        self.scheduler = scheduler
        self.mass_beta = mass_beta
        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        self.chunk_to_module_local: List[Tuple[nn.Module, int]] = []
        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, (
                f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
            )
            n_chunks = m.out_features // chunk_size
            ids = torch.arange(offset, offset + n_chunks, device=device)
            self.module_to_chunk_ids[m] = ids
            for local_c in range(n_chunks):
                self.chunk_to_module_local.append((m, local_c))
            offset += n_chunks
        self.n_chunks = offset
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        self.mass_history: List[torch.Tensor] = []
        self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=device)
        self.next_scores = torch.zeros(self.n_chunks, device=device)
        self.prev_mask: Optional[torch.Tensor] = None
        self.similarity: Optional[torch.Tensor] = None

    def k_active(self) -> int:
        """Number of chunks to activate: ``int(active_fraction * n_chunks)``,
        clamped to ``[1, n_chunks]`` so ``topk`` never requests more chunks than
        exist (e.g. when active_fraction > 1)."""
        return min(self.n_chunks, max(1, int(self.active_fraction * self.n_chunks)))

    def choose_mask(self, step: int, warmup_steps: int) -> torch.Tensor:
        """Pick the bool mask of chunks that will keep their gradient this step.

        Dense scheduler or warmup: all chunks. Otherwise k_active() chunks chosen by
        randperm (random), by top predicted_mass (ema_topk), or by top next_scores
        (knn/graph; falls back to predicted_mass while next_scores is all zero).
        A 1e-9 random jitter breaks ties in the top-k selections.
        """
        if self.scheduler == "dense" or step < warmup_steps:
            self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=self.device)
            return self.current_mask
        k = self.k_active()
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
        if self.scheduler == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]
        elif self.scheduler == "ema_topk":
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices
        elif self.scheduler in ("knn_scheduler", "graph_scheduler"):
            # next_scores are computed from the previous step's active sensors.
            # If unavailable, fall back to EMA.
            base = self.next_scores
            if torch.count_nonzero(base).item() == 0:
                base = self.predicted_mass
            scores = base + 1e-9 * torch.rand_like(base)
            idx = torch.topk(scores, k=k).indices
        else:
            raise ValueError(f"Unknown scheduler: {self.scheduler}")
        mask[idx] = True
        self.current_mask = mask
        return mask

    def chunk_gradient_vectors(self) -> List[torch.Tensor]:
        """Per global chunk, the flattened weight-grad rows concatenated with the
        matching bias-grad slice (zeros where a grad is None). Grads are detached."""
        vecs: List[torch.Tensor] = []
        for m, local_c in self.chunk_to_module_local:
            start = local_c * self.chunk_size
            end = (local_c + 1) * self.chunk_size
            parts = []
            if m.weight.grad is None:
                parts.append(torch.zeros_like(m.weight[start:end]).flatten())
            else:
                parts.append(m.weight.grad[start:end].detach().flatten())
            if m.bias is not None:
                if m.bias.grad is None:
                    parts.append(torch.zeros_like(m.bias[start:end]).flatten())
                else:
                    parts.append(m.bias.grad[start:end].detach().flatten())
            vecs.append(torch.cat(parts))
        return vecs

    def chunk_masses_from_vecs(self, vecs: List[torch.Tensor]) -> torch.Tensor:
        """L2 norm of each chunk gradient vector, as one tensor on ``self.device``."""
        return torch.stack([v.norm() for v in vecs]).to(self.device)

    def update_from_observed(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        step: int,
        warmup_steps: int,
    ) -> None:
        """Fold observed masses into the EMA, record warmup history, refresh the
        similarity graph during warmup, and recompute ``next_scores``.

        Only chunks where ``active_mask`` is True update ``predicted_mass``; however
        during warmup the full ``true_masses`` vector is appended to the history.
        """
        observed = active_mask
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen
        self.predicted_mass[never_seen] = true_masses[never_seen]
        self.predicted_mass[already_seen] = (
            self.mass_beta * self.predicted_mass[already_seen]
            + (1.0 - self.mass_beta) * true_masses[already_seen]
        )
        # During warmup we store dense mass histories to learn the similarity graph.
        if step < warmup_steps:
            self.mass_history.append(true_masses.detach().clone())
            max_hist = 128
            if len(self.mass_history) > max_hist:
                self.mass_history = self.mass_history[-max_hist:]
            if len(self.mass_history) >= 8:
                self.similarity = self.build_similarity()
        # Compute next_scores from current active observations.
        if self.scheduler == "knn_scheduler":
            self.next_scores = self.knn_scores(active_mask, true_masses)
        elif self.scheduler == "graph_scheduler":
            self.next_scores = self.diffusion_scores(active_mask, true_masses)
        else:
            self.next_scores = self.predicted_mass.clone()

    def layer_allowed_mask(self) -> torch.Tensor:
        """Bool [n_chunks, n_chunks] matrix, True only where both chunks belong to
        the same module."""
        allowed = torch.zeros((self.n_chunks, self.n_chunks), dtype=torch.bool, device=self.device)
        for ids in self.module_to_chunk_ids.values():
            allowed[ids[:, None], ids[None, :]] = True
        return allowed

    def build_similarity(self) -> torch.Tensor:
        """Correlation of chunk mass histories, clipped at 0, zero diagonal, and
        restricted to within-module pairs."""
        H = torch.stack(self.mass_history, dim=0)  # [history, chunks]
        H = H - H.mean(dim=0, keepdim=True)
        H = H / (H.std(dim=0, keepdim=True) + 1e-6)
        S = (H.T @ H) / max(1, H.shape[0] - 1)
        S = torch.clamp(S, min=0.0)
        S.fill_diagonal_(0.0)
        # Keep only within-layer similarities. Cross-layer correlation is too easy
        # to overfit in this tiny diagnostic.
        S = torch.where(self.layer_allowed_mask(), S, torch.zeros_like(S))
        return S

    def knn_scores(self, active_mask: torch.Tensor, true_masses: torch.Tensor, k_neighbors: int = 3) -> torch.Tensor:
        """Estimate each inactive chunk's mass from its most similar active chunks.

        Active chunks take their observed mass. Each inactive chunk takes the
        similarity-weighted mean of its ``k_neighbors`` most similar active chunks'
        masses. Chunks whose total similarity to all active chunks is <= 1e-12
        keep their EMA value. Without a similarity graph, returns a copy of the EMA.
        """
        if self.similarity is None:
            return self.predicted_mass.clone()
        scores = self.predicted_mass.clone()
        scores[active_mask] = true_masses[active_mask]
        active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
        inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()
        if active_idx.numel() == 0 or inactive_idx.numel() == 0:
            return scores
        # [n_inactive, n_active]: similarity of every inactive chunk to every sensor.
        # One batched top-k replaces a Python loop with a host sync per chunk.
        weights = self.similarity[inactive_idx[:, None], active_idx[None, :]]
        kk = min(k_neighbors, active_idx.numel())
        top = torch.topk(weights, k=kk, dim=1)
        w = top.values
        neighbour_mass = true_masses[active_idx[top.indices]]
        estimate = (w * neighbour_mass).sum(dim=1) / (w.sum(dim=1) + 1e-12)
        has_signal = weights.sum(dim=1) > 1e-12
        scores[inactive_idx[has_signal]] = estimate[has_signal]
        return scores

    def diffusion_scores(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        diffusion_steps: int = 8,
        alpha: float = 0.7,
    ) -> torch.Tensor:
        """Diffuse observed masses over the row-normalized similarity graph.

        Iterates ``scores = alpha * W @ scores + (1 - alpha) * scores`` and
        re-clamps active chunks to their observed mass after every iteration.
        Negative results are clipped to 0. Without a graph, returns the EMA.
        """
        if self.similarity is None:
            return self.predicted_mass.clone()
        S = self.similarity
        W = S / (S.sum(dim=1, keepdim=True) + 1e-12)
        scores = self.predicted_mass.clone()
        scores[active_mask] = true_masses[active_mask]
        for _ in range(diffusion_steps):
            proposal = W @ scores
            scores = alpha * proposal + (1.0 - alpha) * scores
            scores[active_mask] = true_masses[active_mask]
        return torch.clamp(scores, min=0.0)

    def oracle_topk_mask(self, true_masses: torch.Tensor) -> torch.Tensor:
        """Bool mask of the k_active() chunks with the largest true gradient mass."""
        k = self.k_active()
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
        mask[torch.topk(true_masses, k=k).indices] = True
        return mask
| # ----------------------------- | |
| # Gradient installation and metrics | |
| # ----------------------------- | |
def install_active_only_grads(sched: ChunkScheduler, active_mask: torch.Tensor) -> None:
    """Zero the weight/bias gradients of every chunk that is not in ``active_mask``.

    Chunk ``c`` of a module covers output rows ``[c * chunk_size, (c + 1) * chunk_size)``
    of its weight and the same slice of its bias. Gradients that are None are left
    alone. This is a no-op for the "dense" scheduler. Inactive entries are set to
    exactly 0, even if they held NaN or inf.
    """
    if sched.scheduler == "dense":
        return
    for m, ids in sched.module_to_chunk_ids.items():
        # Expand per-chunk flags to one flag per output row: one masked fill per
        # tensor instead of one kernel launch per inactive chunk.
        inactive_rows = ~active_mask[ids].repeat_interleave(sched.chunk_size)
        if m.weight.grad is not None:
            rows = inactive_rows.to(m.weight.grad.device)
            m.weight.grad.masked_fill_(rows[:, None], 0.0)
        if m.bias is not None and m.bias.grad is not None:
            m.bias.grad.masked_fill_(inactive_rows.to(m.bias.grad.device), 0.0)
def dense_cosine_active_only(vecs: List[torch.Tensor], active_mask: torch.Tensor) -> float:
    """Cosine similarity between the full gradient and its active-chunks-only version.

    ``vecs[i]`` is chunk i's gradient vector. The approximation keeps ``vecs[i]``
    where ``active_mask[i]`` is True and zeros it otherwise.

    Raises:
        ValueError: if ``vecs`` and ``active_mask`` differ in length.
    """
    # One device->host transfer instead of a sync per chunk via bool(mask[i]).
    keep = active_mask.tolist()
    if len(keep) != len(vecs):
        raise ValueError(f"got {len(vecs)} chunk vectors but a mask of length {len(keep)}")
    true = torch.cat([v.flatten() for v in vecs])
    approx = torch.cat(
        [v.flatten() if k else torch.zeros_like(v).flatten() for v, k in zip(vecs, keep)]
    )
    return float(F.cosine_similarity(true, approx, dim=0).item())
def jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
    """Jaccard index |a AND b| / |a OR b| of two bool masks; 0.0 when both are empty."""
    overlap = torch.logical_and(a, b).sum().float()
    combined = torch.logical_or(a, b).sum().float()
    ratio = overlap / combined.clamp(min=1.0)
    return float(ratio.item())
class SimpleAdam:
    """Minimal Adam-style optimizer over ``model.parameters()``.

    Fixed betas (0.9, 0.999) and eps 1e-8. Parameters whose ``.grad`` is None are
    skipped entirely. Parameters whose grad is a zero tensor (e.g. inactive chunks
    after ``install_active_only_grads``) still have their moments decayed, and so
    keep moving while the first moment is non-zero.

    ``bias_correction`` defaults to False to keep the original update rule. Without
    it, the first step is about 3.2x a bias-corrected Adam step
    (0.1 / sqrt(0.001)), shrinking toward 1x as the moments warm up.
    """

    def __init__(self, model: nn.Module, lr: float = 3e-4, bias_correction: bool = False):
        self.model = model
        self.lr = lr
        self.bias_correction = bias_correction
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Per-parameter update counts; only used when bias_correction is True.
        self.step_count: Dict[torch.nn.Parameter, int] = {}

    def zero_grad(self):
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        # no_grad is required: in-place updates of leaf Parameters that require
        # grad raise a RuntimeError under autograd.
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            m.mul_(0.9).add_(p.grad, alpha=0.1)
            v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
            if self.bias_correction:
                t = self.step_count.get(p, 0) + 1
                self.step_count[p] = t
                denom = (v / (1.0 - 0.999 ** t)).sqrt_().add_(1e-8)
                p.sub_(m / denom, alpha=self.lr / (1.0 - 0.9 ** t))
            else:
                p.sub_(m / (torch.sqrt(v) + 1e-8), alpha=self.lr)
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Return the loss of ``model`` on one validation batch, without autograd.

    The batch is drawn with a CPU generator seeded from ``seed``, so equal seeds
    give equal batches. The model is put in eval mode for the forward pass, and
    its previous train/eval mode is restored afterwards (even on error).
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x, y = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
            _, loss = model(x, y)
    finally:
        model.train(was_training)
    return float(loss.item())
def _mean_metric(rows: List[Dict[str, float]], key: str) -> float:
    """Arithmetic mean of ``row[key]`` over ``rows``; NaN when ``rows`` is empty."""
    if not rows:
        return float("nan")
    return sum(row[key] for row in rows) / len(rows)


def run_experiment(
    scheduler_name: Scheduler,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train a fresh MiniGPT for ``steps`` steps under one chunk scheduler.

    Every run reseeds with 42 and draws the batch for step s from a generator
    seeded with s, so all schedulers see identical data.

    Returns a dict with:
        "val":    final loss on one fixed validation batch (seed 12345).
        "ms":     wall-clock ms per step (NaN when steps == 0). This includes the
                  dense-gradient diagnostics and periodic evals, not just training.
        "cos":    mean cosine between the dense and active-only gradient.
        "jacc":   mean Jaccard between the chosen mask and the oracle top-k mask.
        "stable": mean Jaccard between consecutive masks.
    The three means cover post-warmup steps and are NaN for "dense".
    """
    set_seed(42)
    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)
    opt = SimpleAdam(model, lr=3e-4)
    sched = ChunkScheduler(
        model=model,
        chunk_size=chunk_size,
        active_fraction=active_fraction,
        device=device,
        scheduler=scheduler_name,
    )
    metric_rows: List[Dict[str, float]] = []
    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()
    for step in range(steps):
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))
        # The mask is chosen before this step's gradients exist, from past observations.
        active_mask = sched.choose_mask(step=step, warmup_steps=warmup_steps)
        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()
        vecs = sched.chunk_gradient_vectors()
        masses = sched.chunk_masses_from_vecs(vecs)
        if step >= warmup_steps and scheduler_name != "dense":
            oracle = sched.oracle_topk_mask(masses)
            # NOTE(review): the periodic "val" entries are recorded but never
            # aggregated or returned; they still cost time inside the timed loop.
            row = {
                "cos": dense_cosine_active_only(vecs, active_mask),
                "jacc": jaccard(active_mask, oracle),
                "stable": jaccard(active_mask, sched.prev_mask) if sched.prev_mask is not None else 0.0,
                "val": evaluate(model, corpus, batch_size, seed=10_000 + step) if step % 50 == 0 else float("nan"),
            }
            metric_rows.append(row)
        install_active_only_grads(sched, active_mask)
        # Important: update scheduler from the active observations only.
        # Dense gradients exist for diagnostics, but unselected chunks should not
        # teach the sparse scheduler after warmup.
        observed_for_scheduler = active_mask if step >= warmup_steps else torch.ones_like(active_mask)
        sched.update_from_observed(
            active_mask=observed_for_scheduler,
            true_masses=masses,
            step=step,
            warmup_steps=warmup_steps,
        )
        sched.prev_mask = active_mask.clone()
        opt.step()
    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0
    val_loss = evaluate(model, corpus, batch_size, seed=12345)
    return {
        "val": val_loss,
        "ms": 1000.0 * elapsed / steps if steps > 0 else float("nan"),
        "cos": _mean_metric(metric_rows, "cos"),
        "jacc": _mean_metric(metric_rows, "jacc"),
        "stable": _mean_metric(metric_rows, "stable"),
    }
def _default_device() -> str:
    """Return "cuda" if available, else "mps" if available, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def main() -> None:
    """Parse CLI flags, run each requested scheduler, and print one table row per run."""
    all_schedulers: List[Scheduler] = [
        "dense",
        "ema_topk",
        "knn_scheduler",
        "graph_scheduler",
        "random",
    ]
    parser = argparse.ArgumentParser(description="Sensor-based mask scheduling diagnostic")
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="torch device (default: cuda if available, else mps, else cpu)",
    )
    parser.add_argument("--benchmark_sync", action="store_true")
    parser.add_argument(
        "--schedulers",
        nargs="+",
        choices=all_schedulers,
        default=all_schedulers,
        help="subset of schedulers to run (default: all)",
    )
    args = parser.parse_args()
    if args.steps < 1:
        parser.error("--steps must be at least 1")
    if not 0.0 <= args.active_fraction <= 1.0:
        parser.error("--active_fraction must be in [0, 1]")
    device = args.device if args.device is not None else _default_device()

    print("\nSensor-based mask scheduling diagnostic")
    print(f"device={device} steps={args.steps} d={args.n_embd} chunk_size={args.chunk_size}")
    print(f"active_fraction={args.active_fraction} warmup={args.warmup_steps}\n")
    header = (
        f"{'scheduler':>18s} | {'val':>8s} | {'ms/step':>8s} | "
        f"{'grad_cos':>8s} | {'jacc':>8s} | {'stable':>8s}"
    )
    print(header)
    print("-" * len(header))
    for sched_name in args.schedulers:
        result = run_experiment(
            scheduler_name=sched_name,
            device=device,
            steps=args.steps,
            batch_size=args.batch_size,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
            n_embd=args.n_embd,
            chunk_size=args.chunk_size,
            active_fraction=args.active_fraction,
            warmup_steps=args.warmup_steps,
            benchmark_sync=args.benchmark_sync,
        )
        print(
            f"{sched_name:>18s} | "
            f"{result['val']:8.4f} | "
            f"{result['ms']:8.2f} | "
            f"{result['cos']:8.3f} | "
            f"{result['jacc']:8.3f} | "
            f"{result['stable']:8.3f}"
        )
if __name__ == "__main__":
    # Script entry point: importing this module runs no experiments.
    main()