"""Muon optimizer for İvme — MomentUm Orthogonalized by Newton-Schulz.

Muon orthogonalizes each 2D weight's momentum-smoothed gradient via a quintic
Newton-Schulz iteration before the update. It consistently beats AdamW on
transformer bodies, especially on reasoning-heavy benchmarks, at negligible
extra cost.

Standard practice (and what we use here): Muon for the 2D transformer
matrices, AdamW for everything else — embeddings, the LM head, and all 1D
params (norms). Since İvme ties its embeddings, the shared embed/head table
goes to AdamW.

Reference: Keller Jordan's Muon (github.com/KellerJordan/Muon).
"""

from __future__ import annotations

import torch
from torch.optim import AdamW


# --------------------------------------------------------------------------- #
# Newton-Schulz orthogonalization
# --------------------------------------------------------------------------- #
@torch.no_grad()
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a 2D matrix with a quintic Newton-Schulz iteration.

    The coefficients (a, b, c) are tuned so the iteration pushes the singular
    values of G toward 1 without ever computing an SVD. Runs in bf16 for speed.

    Args:
        G: 2D tensor to orthogonalize.
        steps: Number of Newton-Schulz iterations to run.

    Returns:
        A bfloat16 tensor with the same shape as ``G``.

    Raises:
        ValueError: If ``G`` is not 2D.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if G.ndim != 2:
        raise ValueError(f"expected a 2D tensor, got {G.ndim}D")
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Work in the wide orientation so X @ X.T is the smaller of the two Gram
    # matrices; transpose back at the end.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    # Normalize so the spectral norm is <= 1 before iterating (the Frobenius
    # norm is an upper bound on the spectral norm).
    X = X / (X.norm() + 1e-7)
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X


# --------------------------------------------------------------------------- #
# Muon
# --------------------------------------------------------------------------- #
class Muon(torch.optim.Optimizer):
    """Momentum SGD whose updates are orthogonalized by Newton-Schulz.

    Every parameter handed to Muon must be 2D (``zeropower_via_newtonschulz5``
    rejects anything else); route all other params to AdamW, as
    ``build_optimizers`` does.

    Args:
        params: Iterable of 2D parameters or param-group dicts.
        lr: Learning rate.
        momentum: Momentum coefficient for the gradient buffer.
        nesterov: If True, apply a Nesterov-style lookahead
            (``grad + momentum * buf``) instead of using ``buf`` directly.
        ns_steps: Newton-Schulz iterations per update.
    """

    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, following the ``torch.optim.Optimizer.step`` contract.

        Returns:
            The loss returned by ``closure``, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd to backprop.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(g)
                g = g.add(buf, alpha=momentum) if nesterov else buf

                g = zeropower_via_newtonschulz5(g, steps=ns_steps)
                # Scale so the update RMS roughly matches the parameter shape;
                # an orthogonalized matrix has spectral norm ~1 regardless of size.
                scale = max(1.0, g.size(0) / g.size(1)) ** 0.5
                p.add_(g.to(p.dtype), alpha=-lr * scale)

        return loss


# --------------------------------------------------------------------------- #
# Hybrid optimizer builder
# --------------------------------------------------------------------------- #
def build_optimizers(model, muon_lr=0.02, adamw_lr=3e-4, weight_decay=0.1, betas=(0.9, 0.95)):
    """Split İvme's params into Muon (2D transformer matrices) and AdamW (rest).

    Routing:
      * 2D params whose name contains neither "embed" nor "lm_head" -> Muon.
      * 1D (and 0D) params, e.g. norm gains and biases -> AdamW with
        ``weight_decay=0``, following the common convention of not decaying
        normalization/bias parameters.
      * Everything else (embedding / LM-head tables, any >2D params) -> AdamW
        with ``weight_decay``.

    AdamW may therefore hold two param groups; when scheduling, set ``lr`` on
    every entry of ``adamw.param_groups``.

    Returns (muon, adamw). Step both each iteration; schedule both together.
    """
    muon_params, decay_params, no_decay_params = [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim < 2:
            # Norm gains / biases -> AdamW, no weight decay.
            no_decay_params.append(p)
        elif p.ndim == 2 and "embed" not in name and "lm_head" not in name:
            # 2D weights inside transformer blocks -> Muon.
            muon_params.append(p)
        else:
            # Embeddings, LM head, and any non-2D tensors -> AdamW with decay.
            decay_params.append(p)

    muon = Muon(muon_params, lr=muon_lr)
    # Drop empty groups so a model lacking one category still builds.
    adamw_groups = [
        group
        for group in (
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        )
        if group["params"]
    ]
    adamw = AdamW(adamw_groups, lr=adamw_lr, betas=betas, weight_decay=weight_decay)

    adamw_params = decay_params + no_decay_params
    n_muon = sum(p.numel() for p in muon_params)
    n_adamw = sum(p.numel() for p in adamw_params)
    print(f"[optim] Muon : {len(muon_params)} tensors, {n_muon:,} params")
    print(f"[optim] AdamW : {len(adamw_params)} tensors, {n_adamw:,} params")
    return muon, adamw


# --------------------------------------------------------------------------- #
# WSD learning-rate schedule
# --------------------------------------------------------------------------- #
def wsd_lr_multiplier(step: int, total_steps: int, warmup: int = 100,
                      decay_frac: float = 0.2) -> float:
    """Warmup-Stable-Decay multiplier in [0, 1].

    Linear warmup -> constant stable phase -> linear decay to ~0 over the final
    `decay_frac` of training. Multiply each optimizer's base lr by this value.

    The result is the minimum of the warmup ramp and the decay ramp. For normal
    runs (``warmup <= decay_start``) the phases do not overlap and this is the
    plain three-phase schedule. For short runs where warmup would extend into
    the decay window, it keeps the decay from being skipped.
    """
    decay_start = int(total_steps * (1 - decay_frac))

    warm = step / max(1, warmup) if step < warmup else 1.0

    if step < decay_start:
        decay = 1.0
    else:
        # Linear decay over the final decay_frac of steps.
        progress = (step - decay_start) / max(1, total_steps - decay_start)
        decay = max(0.0, 1.0 - progress)

    return min(warm, decay)


# --------------------------------------------------------------------------- #
# Self-test
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    # 1) Newton-Schulz should produce a near-orthogonal matrix.
    torch.manual_seed(0)
    G = torch.randn(384, 1024)
    Q = zeropower_via_newtonschulz5(G).float()
    # For a wide matrix, Q @ Q.T should be close to identity. The tuned
    # coefficients only push singular values *toward* 1, so a small nonzero
    # error is expected.
    gram = Q @ Q.T
    err = (gram - torch.eye(Q.size(0))).abs().mean().item()
    print(f"[ns] orthogonality error (lower=better): {err:.4f}")

    # 2) WSD schedule shape.
    total = 6000
    pts = [0, 50, 100, 1000, 4800, 5400, 5999]
    print("[wsd] step -> lr_mult")
    for s in pts:
        print(f" {s:>5} -> {wsd_lr_multiplier(s, total):.3f}")