#!/usr/bin/env python3 """Benchmark: AdamW vs Adafactor vs Muon on WrinkleBrane. All runs use FP32 baseline config (same as the 10K step run). 500 steps each, same data, same seed for model init. Optimizers: 1. AdamW — current baseline (lr=3e-4, betas=(0.9, 0.95), wd=0.01) 2. Adafactor — memory-efficient adaptive, no external scheduler (PyTorch 2.8) 3. Muon — Momentum + Newton-Schulz orthogonalization, lr=0.05, clip=2.0 DESIGNED FOR JUPYTERLAB TERMINAL — output is tee'd to both stdout and a timestamped log file under logs/ (or --log_dir) so browser crashes don't lose results. Run as: cd /data/WrinkleBrane-Research/09_standalone_model OMP_NUM_THREADS=8 MKL_NUM_THREADS=8 PYTHONPATH=src \\ nohup python3 -u benchmark_optimizers.py & tail -f logs/benchmark_.log """ from __future__ import annotations import copy import gc import math import os import subprocess import sys import time from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import argparse import torch from torch import nn, Tensor sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src")) from wrinklebrane.standalone_model import WrinkleBraneConfig, WrinkleBraneModel from wrinklebrane.data import load_train_val, VOCAB_SIZE # ============================================================================ # Tee: write to stdout AND a log file simultaneously # ============================================================================ class Tee: """Duplicate stdout writes to a log file. Replaces sys.stdout so all print() calls go to both the terminal and a persistent log file — survives browser/VSCode crashes. """ def __init__(self, log_path: str): self.terminal = sys.__stdout__ os.makedirs(os.path.dirname(log_path), exist_ok=True) self.log = open(log_path, "w", buffering=1) # line-buffered print(f" Logging to: {log_path}", file=self.terminal) def write(self, message): self.terminal.write(message) self.log.write(message) def flush(self): self.terminal.flush() self.log.flush() def close(self): self.log.close() sys.stdout = self.terminal # ============================================================================ # Configuration # ============================================================================ BENCHMARK_STEPS = 500 BATCH_SIZE = 16 SEQ_LEN = 128 WARMUP = 50 LOG_EVERY = 50 EVAL_EVERY = 100 SEED = 42 def make_config() -> WrinkleBraneConfig: """Same optimal config as the 10K training run.""" return WrinkleBraneConfig( vocab_size=VOCAB_SIZE, d_model=128, n_layers=6, n_heads=4, L=16, K=SEQ_LEN, code_init="hadamard", learnable_codes=True, temperature=0.5, ffn_expansion=4, use_gated_ffn=True, max_seq_len=SEQ_LEN, dropout=0.1, ortho_lambda=0.01, persistence_lambda=0.99, weight_tying=True, ) # ============================================================================ # Muon Optimizer — Momentum with Newton-Schulz Orthogonalization # ============================================================================ class Muon(torch.optim.Optimizer): """Muon: Momentum + Orthogonalized Updates via Newton-Schulz. For matrix-shaped parameters (2D), Muon orthogonalizes the momentum buffer using Newton-Schulz iterations before applying the update. This naturally preserves orthogonal structure in weight matrices. For non-matrix parameters (1D scalars, biases, embeddings), falls back to standard momentum SGD. Particularly suited for WrinkleBrane because: - Codebook parameters benefit from orthogonality preservation - Projection matrices (W_v, W_q, W_o) stay well-conditioned - The Newton-Schulz step acts like a natural preconditioner Reference: Keller Jordan et al., "Muon: An optimizer for hidden layers" Parameters ---------- params : iterable Model parameters. lr : float Learning rate (default: 0.02, typically higher than Adam). momentum : float Momentum coefficient (default: 0.95). ns_steps : int Number of Newton-Schulz orthogonalization iterations (default: 5). weight_decay : float Decoupled weight decay (default: 0.0). """ def __init__( self, params, lr: float = 0.02, momentum: float = 0.95, ns_steps: int = 5, weight_decay: float = 0.0, ): defaults = dict( lr=lr, momentum=momentum, ns_steps=ns_steps, weight_decay=weight_decay, ) super().__init__(params, defaults) @staticmethod def newton_schulz_orthogonalize(M: Tensor, steps: int = 5) -> Tensor: """Orthogonalize matrix M using Newton-Schulz iteration. Computes the polar factor of M: the nearest orthogonal matrix. Uses the iteration: X_{k+1} = 1.5 * X_k - 0.5 * X_k @ X_k^T @ X_k For efficiency, operates on the smaller dimension. Parameters ---------- M : Tensor [m, n] steps : int Number of NS iterations (5 is usually sufficient). Returns ------- Tensor [m, n] Orthogonalized matrix. """ m, n = M.shape transpose = False # Work with the smaller dimension for efficiency if m < n: M = M.T transpose = True # Normalize to unit spectral norm (approximate) X = M / (M.norm() + 1e-7) # Newton-Schulz iterations: X = 1.5*X - 0.5*X@X^T@X for _ in range(steps): A = X @ X.T X = 1.5 * X - 0.5 * A @ X if transpose: X = X.T return X @torch.no_grad() def step(self, closure=None): """Perform a single optimization step.""" loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: lr = group["lr"] momentum = group["momentum"] ns_steps = group["ns_steps"] wd = group["weight_decay"] for p in group["params"]: if p.grad is None: continue grad = p.grad # Decoupled weight decay if wd > 0: p.mul_(1 - lr * wd) # Get or create momentum buffer state = self.state[p] if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(p) buf = state["momentum_buffer"] buf.mul_(momentum).add_(grad) # For 2D parameters: apply Newton-Schulz orthogonalization if p.dim() == 2 and min(p.shape) > 1: update = self.newton_schulz_orthogonalize(buf, steps=ns_steps) # Scale update to match the gradient magnitude update = update * (buf.norm() / (update.norm() + 1e-7)) else: # For 1D/scalar params: standard momentum update = buf p.add_(update, alpha=-lr) return loss # ============================================================================ # Training infrastructure # ============================================================================ @dataclass class BenchmarkResult: name: str steps: int total_time: float losses: List[float] eval_losses: List[Tuple[int, float]] tok_per_sec: List[float] grad_norms: List[float] @property def avg_tok_per_sec(self) -> float: return sum(self.tok_per_sec) / len(self.tok_per_sec) if self.tok_per_sec else 0 @property def final_loss(self) -> float: return self.losses[-1] if self.losses else float("nan") @property def final_eval_loss(self) -> float: return self.eval_losses[-1][1] if self.eval_losses else float("nan") def evaluate(model, val_corpus, seq_len, batch_size=16, n_batches=10, device=None): """Evaluate on validation data.""" if device is None: device = next(model.parameters()).device model.eval() total_loss = 0.0 total_tokens = 0 with torch.no_grad(): for _ in range(n_batches): inp, tgt = val_corpus.get_batch(batch_size, seq_len) inp, tgt = inp.to(device), tgt.to(device) logits = model(inp) B, T, V = logits.shape loss = nn.functional.cross_entropy( logits.reshape(B * T, V), tgt.reshape(B * T), reduction="sum", ) total_loss += loss.item() total_tokens += B * T return total_loss / total_tokens def run_benchmark( name: str, model: WrinkleBraneModel, config: WrinkleBraneConfig, optimizer: torch.optim.Optimizer, scheduler, train_corpus, val_corpus, steps: int = BENCHMARK_STEPS, grad_clip: float = 1.0, device: torch.device = None, ) -> BenchmarkResult: """Run a training benchmark for the given number of steps.""" if device is None: device = next(model.parameters()).device print(f"\n{'='*70}") print(f" BENCHMARK: {name}") print(f" {steps} steps, batch_size={BATCH_SIZE}, seq_len={SEQ_LEN}") print(f"{'='*70}") param_count = sum(p.numel() for p in model.parameters()) print(f" Parameters: {param_count:,}") # Initial eval init_eval = evaluate(model, val_corpus, SEQ_LEN) print(f" Initial eval loss: {init_eval:.4f} (PPL {math.exp(min(init_eval, 20)):.2f})") # Tracking losses = [] eval_losses = [(0, init_eval)] tok_per_sec_list = [] grad_norms_list = [] running_loss = 0.0 running_tokens = 0 interval_start = time.time() total_start = time.time() print(f"\n Training started at {time.strftime('%H:%M:%S')}") print(f" {'─'*64}") for step in range(1, steps + 1): model.train() optimizer.zero_grad() inp, tgt = train_corpus.get_batch(BATCH_SIZE, SEQ_LEN) inp, tgt = inp.to(device), tgt.to(device) # Forward logits = model(inp) B, T, V = logits.shape task_loss = nn.functional.cross_entropy( logits.reshape(B * T, V), tgt.reshape(B * T), ) # Ortho regularization ortho = config.ortho_lambda * model.ortho_loss() total_loss = task_loss + ortho # Backward total_loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) optimizer.step() if scheduler is not None: scheduler.step() running_loss += task_loss.item() * B * T running_tokens += B * T grad_norms_list.append(grad_norm.item()) # Log if step % LOG_EVERY == 0: now = time.time() elapsed = now - interval_start avg_loss = running_loss / running_tokens tps = running_tokens / elapsed losses.append(avg_loss) tok_per_sec_list.append(tps) # Get current LR if scheduler is not None: lr = scheduler.get_last_lr()[0] else: lr = optimizer.param_groups[0]["lr"] print(f" step {step:4d}/{steps} | " f"loss={avg_loss:.4f} ppl={math.exp(min(avg_loss, 20)):7.2f} | " f"lr={lr:.2e} gnorm={grad_norm:.2f} | " f"{tps:6.0f} tok/s", flush=True) running_loss = 0.0 running_tokens = 0 interval_start = time.time() # Eval if step % EVAL_EVERY == 0: val_loss = evaluate(model, val_corpus, SEQ_LEN) eval_losses.append((step, val_loss)) print(f" >>> EVAL step {step}: loss={val_loss:.4f}, " f"ppl={math.exp(min(val_loss, 20)):.2f}", flush=True) total_time = time.time() - total_start # Final eval final_eval = evaluate(model, val_corpus, SEQ_LEN) eval_losses.append((steps, final_eval)) print(f"\n {'─'*64}") print(f" DONE: {name}") print(f" Total time: {total_time:.1f}s ({total_time/60:.1f} min)") print(f" Final train loss: {losses[-1]:.4f}" if losses else " No losses recorded") print(f" Final eval loss: {final_eval:.4f} (PPL {math.exp(min(final_eval, 20)):.2f})") print(f" Avg throughput: {sum(tok_per_sec_list)/len(tok_per_sec_list):.0f} tok/s" if tok_per_sec_list else "") return BenchmarkResult( name=name, steps=steps, total_time=total_time, losses=losses, eval_losses=eval_losses, tok_per_sec=tok_per_sec_list, grad_norms=grad_norms_list, ) def make_cosine_scheduler(optimizer, warmup, total_steps): """Cosine LR schedule with linear warmup.""" def lr_schedule(step): if step < warmup: return step / warmup progress = (step - warmup) / max(total_steps - warmup, 1) return 0.5 * (1.0 + math.cos(math.pi * progress)) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule) # ============================================================================ # Main benchmark # ============================================================================ def main(): parser = argparse.ArgumentParser(description="WrinkleBrane optimizer benchmark") parser.add_argument( "--log_dir", default=os.path.join(os.path.dirname(__file__), "logs"), help="Directory for log files (default: ./logs/)", ) parser.add_argument( "--steps", type=int, default=BENCHMARK_STEPS, help=f"Training steps per optimizer (default: {BENCHMARK_STEPS})", ) parser.add_argument( "--device", type=str, default=None, help="Device: 'cuda', 'cuda:0', 'cpu'. Auto-detects if not set.", ) args = parser.parse_args() # Redirect stdout to Tee (stdout + log file) — survives browser crashes timestamp = time.strftime("%Y%m%d_%H%M%S") log_path = os.path.join(args.log_dir, f"benchmark_optimizers_{timestamp}.log") tee = Tee(log_path) sys.stdout = tee # Device setup if args.device: device = torch.device(args.device) else: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("=" * 70) print(" WrinkleBrane Optimizer Benchmark") print(" AdamW vs Adafactor vs Muon — 500 steps each (FP32) [v2: fixed configs]") print("=" * 70) print() # System info print(" Hardware:") cpu_model = subprocess.check_output( "grep 'model name' /proc/cpuinfo | head -1 | cut -d: -f2", shell=True, text=True, ).strip() print(f" CPU: {cpu_model}") print(f" Device: {device}") if device.type == "cuda": print(f" GPU: {torch.cuda.get_device_name(device)}") print(f" VRAM: {torch.cuda.get_device_properties(device).total_memory / 1e9:.1f} GB") # Set thread count properly (CPU only; on GPU, threads matter less) n_threads = int(subprocess.check_output("nproc", text=True).strip()) torch.set_num_threads(n_threads) os.environ["OMP_NUM_THREADS"] = str(n_threads) os.environ["MKL_NUM_THREADS"] = str(n_threads) print(f" Torch threads: {n_threads}") print(f" PyTorch: {torch.__version__}") print() # Load data once print(" Loading data...") train_corpus, val_corpus = load_train_val( "/data/WrinkleBrane-Research/raw" ) print() results = [] # ──────────────────────────────────────────────────────────────────── # Benchmark 1: AdamW (baseline) # ──────────────────────────────────────────────────────────────────── print(" ┌─────────────────────────────────────┐") print(" │ 1/3: AdamW (baseline) │") print(" └─────────────────────────────────────┘") torch.manual_seed(SEED) config = make_config() model = WrinkleBraneModel(config).to(device) optimizer = torch.optim.AdamW( model.parameters(), lr=3e-4, weight_decay=0.01, betas=(0.9, 0.95), ) scheduler = make_cosine_scheduler(optimizer, WARMUP, BENCHMARK_STEPS) result = run_benchmark( name="AdamW (lr=3e-4, β=(0.9,0.95), wd=0.01)", model=model, config=config, optimizer=optimizer, scheduler=scheduler, train_corpus=train_corpus, val_corpus=val_corpus, device=device, ) results.append(result) del model, optimizer, scheduler gc.collect() # ──────────────────────────────────────────────────────────────────── # Benchmark 2: Adafactor # ──────────────────────────────────────────────────────────────────── print("\n ┌─────────────────────────────────────┐") print(" │ 2/3: Adafactor │") print(" └─────────────────────────────────────┘") torch.manual_seed(SEED) config = make_config() model = WrinkleBraneModel(config).to(device) # Adafactor config (v2 — fixed): # - lr=1e-3 constant, NO external scheduler # - Adafactor manages its own adaptive second moment via beta2_decay # - Layering a cosine schedule on top (v1) was double-scheduling it, # fighting the internal adaptive rate and preventing convergence # - eps=(None, 1e-3): no row-factor (None), column epsilon=1e-3 # - weight_decay=0.01 (match AdamW) optimizer = torch.optim.Adafactor( model.parameters(), lr=1e-3, beta2_decay=-0.8, eps=(None, 1e-3), weight_decay=0.01, ) scheduler = None # Adafactor manages its own schedule internally result = run_benchmark( name="Adafactor (lr=1e-3, no sched, β2d=-0.8, wd=0.01)", model=model, config=config, optimizer=optimizer, scheduler=scheduler, train_corpus=train_corpus, val_corpus=val_corpus, device=device, ) results.append(result) del model, optimizer, scheduler gc.collect() # ──────────────────────────────────────────────────────────────────── # Benchmark 3: Muon # ──────────────────────────────────────────────────────────────────── print("\n ┌─────────────────────────────────────┐") print(" │ 3/3: Muon │") print(" └─────────────────────────────────────┘") torch.manual_seed(SEED) config = make_config() model = WrinkleBraneModel(config).to(device) # Muon config (v2 — fixed): # - lr=0.05 (v1's 0.02 plateaued hard after step 150; orthogonalized # updates are naturally bounded so Muon can push much harder) # - momentum=0.95 (standard heavy momentum) # - ns_steps=5 (Newton-Schulz iterations, standard) # - weight_decay=0.01 (decoupled, match others) # - grad_clip=2.0 (v1's avg grad norm was 0.474 — well under 1.0, # so the clip was never helping; raise ceiling to let Muon breathe) optimizer = Muon( model.parameters(), lr=0.05, momentum=0.95, ns_steps=5, weight_decay=0.01, ) scheduler = make_cosine_scheduler(optimizer, WARMUP, BENCHMARK_STEPS) result = run_benchmark( name="Muon (lr=0.05, mom=0.95, ns=5, wd=0.01, clip=2.0)", model=model, config=config, optimizer=optimizer, scheduler=scheduler, train_corpus=train_corpus, val_corpus=val_corpus, grad_clip=2.0, device=device, ) results.append(result) del model, optimizer, scheduler gc.collect() # ──────────────────────────────────────────────────────────────────── # Results comparison # ──────────────────────────────────────────────────────────────────── print("\n\n") print("=" * 70) print(" OPTIMIZER BENCHMARK RESULTS") print("=" * 70) print() # Summary table header = f" {'Metric':<28}" sep = f" {'─'*28}" for r in results: short_name = r.name.split("(")[0].strip() header += f" │ {short_name:<26}" sep += f" │{'─'*26}" print(header) print(sep) # Total time row = f" {'Total time':<28}" for r in results: row += f" │ {r.total_time:>7.1f}s ({r.total_time/60:>4.1f}m) " print(row) # Time per step row = f" {'Time per step':<28}" for r in results: row += f" │ {r.total_time/r.steps*1000:>7.1f} ms/step " print(row) # Throughput row = f" {'Avg throughput':<28}" for r in results: row += f" │ {r.avg_tok_per_sec:>7.0f} tok/s " print(row) # Final train loss row = f" {'Final train loss':<28}" for r in results: row += f" │ {r.final_loss:>7.4f} " print(row) # Final eval loss row = f" {'Final eval loss':<28}" for r in results: row += f" │ {r.final_eval_loss:>7.4f} " print(row) # Final eval PPL row = f" {'Final eval PPL':<28}" for r in results: ppl = math.exp(min(r.final_eval_loss, 20)) row += f" │ {ppl:>7.2f} " print(row) # Avg gradient norm row = f" {'Avg gradient norm':<28}" for r in results: avg_gn = sum(r.grad_norms) / len(r.grad_norms) if r.grad_norms else 0 row += f" │ {avg_gn:>7.3f} " print(row) # Peak gradient norm row = f" {'Peak gradient norm':<28}" for r in results: peak_gn = max(r.grad_norms) if r.grad_norms else 0 row += f" │ {peak_gn:>7.3f} " print(row) # Best eval loss (with step) row = f" {'Best eval loss (step)':<28}" for r in results: best_step, best_loss = min(r.eval_losses, key=lambda x: x[1]) row += f" │ {best_loss:.4f} @ {best_step:<4} " print(row) # Relative comparisons print(f"\n{sep}") base_eval = results[0].final_eval_loss base_time = results[0].total_time row = f" {'Eval loss vs AdamW':<28}" for r in results: delta = r.final_eval_loss - base_eval pct = (delta / base_eval) * 100 sign = "+" if delta >= 0 else "" row += f" │ {sign}{delta:>6.4f} ({sign}{pct:.1f}%) " print(row) row = f" {'Time vs AdamW':<28}" for r in results: ratio = r.total_time / base_time row += f" │ {ratio:>6.2f}x " print(row) # ──────────────────────────────────────────────────────────────────── # Loss curves # ──────────────────────────────────────────────────────────────────── print(f"\n\n EVAL LOSS CURVES") header = f" {'Step':<8}" sep2 = f" {'─'*8}" for r in results: short_name = r.name.split("(")[0].strip() header += f" │ {short_name:<26}" sep2 += f" │{'─'*26}" print(header) print(sep2) all_steps = sorted(set(s for r in results for s, _ in r.eval_losses)) for step in all_steps: row = f" {step:<8}" for r in results: val = next((l for s, l in r.eval_losses if s == step), None) if val is not None: ppl = math.exp(min(val, 20)) row += f" │ {val:>6.4f} (PPL {ppl:>6.2f}) " else: row += f" │ {'—':>24} " print(row) # ──────────────────────────────────────────────────────────────────── # Training loss curves (per LOG_EVERY) # ──────────────────────────────────────────────────────────────────── print(f"\n\n TRAINING LOSS CURVES") header = f" {'Step':<8}" for r in results: short_name = r.name.split("(")[0].strip() header += f" │ {short_name:<26}" print(header) print(sep2) max_entries = max(len(r.losses) for r in results) for i in range(max_entries): step = (i + 1) * LOG_EVERY row = f" {step:<8}" for r in results: if i < len(r.losses): ppl = math.exp(min(r.losses[i], 20)) row += f" │ {r.losses[i]:>6.4f} (PPL {ppl:>6.2f}) " else: row += f" │ {'—':>24} " print(row) # Winner announcement print(f"\n\n{'='*70}") best_result = min(results, key=lambda r: r.final_eval_loss) best_ppl = math.exp(min(best_result.final_eval_loss, 20)) print(f" WINNER: {best_result.name}") print(f" Final eval loss: {best_result.final_eval_loss:.4f} (PPL {best_ppl:.2f})") improvement = (base_eval - best_result.final_eval_loss) / base_eval * 100 if best_result.name != results[0].name: print(f" Improvement over AdamW: {improvement:.2f}%") print(f"{'='*70}") print(f"\n Log saved to: {log_path}", flush=True) tee.close() if __name__ == "__main__": main()