| """ |
| Muon optimizer for İvme — MomentUm Orthogonalized by Newton-schulz. |
| |
| Muon orthogonalizes each 2D weight's momentum-smoothed gradient via a quintic |
| Newton-Schulz iteration before the update. It consistently beats AdamW on |
| transformer bodies, especially on reasoning-heavy benchmarks, at negligible |
| extra cost. |
| |
| Standard practice (and what we use here): Muon for the 2D transformer matrices, |
| AdamW for everything else — embeddings, the LM head, and all 1D params (norms). |
| Since İvme ties its embeddings, the shared embed/head table goes to AdamW. |
| |
| Reference: Keller Jordan's Muon (github.com/KellerJordan/Muon). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| from torch.optim import AdamW |
|
|
|
|
| |
| |
| |
@torch.no_grad()
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize ``G`` with a quintic Newton-Schulz iteration.

    The coefficients ``(a, b, c)`` are tuned so the iteration pushes the
    singular values of ``G`` toward 1 without ever computing an SVD. The
    result is only approximately orthogonal. The whole iteration runs in
    bf16 for speed.

    Args:
        G: Tensor with at least two dimensions. The last two dimensions form
            the matrix; any leading dimensions are treated as a batch and
            every matrix in the batch is orthogonalized independently.
        steps: Number of Newton-Schulz iterations to run.

    Returns:
        A bf16 tensor with the same shape as ``G``.
    """
    assert G.ndim >= 2, "expected a matrix or a batch of matrices"
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()

    # Iterate on the wide orientation so the Gram matrix X @ X^T is the
    # smaller of the two possible products.
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT

    # Frobenius-normalize each matrix (this bounds its spectral norm by 1);
    # the epsilon keeps an all-zero input from dividing by zero.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X

    if transposed:
        X = X.mT
    return X
|
|
|
|
| |
| |
| |
class Muon(torch.optim.Optimizer):
    """Momentum optimizer whose update is orthogonalized by Newton-Schulz.

    For every parameter with a gradient, the gradient is accumulated into a
    momentum buffer, optionally combined with it Nesterov-style, passed
    through ``zeropower_via_newtonschulz5`` and then subtracted from the
    parameter with a shape-dependent scale. No weight decay is applied.

    Intended for 2D weight matrices; the per-parameter scale reads
    ``size(0)`` and ``size(1)`` of the orthogonalized update.

    Args:
        params: Iterable of parameters (or parameter-group dicts) to optimize.
        lr: Learning rate; must be non-negative.
        momentum: Momentum coefficient; must be non-negative.
        nesterov: If true, use ``grad + momentum * buffer`` as the update
            direction instead of the buffer itself.
        ns_steps: Number of Newton-Schulz iterations per update.

    Raises:
        ValueError: If ``lr`` or ``momentum`` is negative.
    """

    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so re-enable autograd for
            # the closure's forward/backward pass.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(g)
                g = g.add(buf, alpha=momentum) if nesterov else buf
                g = zeropower_via_newtonschulz5(g, steps=ns_steps)

                # Tall matrices (more rows than columns) get a larger step:
                # scale by sqrt(rows / cols), never below 1.
                scale = max(1.0, g.size(0) / g.size(1)) ** 0.5
                p.add_(g.to(p.dtype), alpha=-lr * scale)

        return loss
|
|
|
|
| |
| |
| |
def build_optimizers(model, muon_lr=0.02, adamw_lr=3e-4, weight_decay=0.1,
                     betas=(0.9, 0.95)):
    """Create the Muon / AdamW optimizer pair for a model.

    Trainable parameters are routed by shape and name: a 2D parameter whose
    name contains neither "embed" nor "lm_head" goes to Muon; every other
    trainable parameter goes to AdamW. Frozen parameters are skipped. A short
    summary of the split is printed.

    Args:
        model: Object exposing ``named_parameters()``.
        muon_lr: Learning rate for the Muon optimizer.
        adamw_lr: Learning rate for the AdamW optimizer.
        weight_decay: Weight decay passed to AdamW.
        betas: Beta coefficients passed to AdamW.

    Returns:
        A ``(muon, adamw)`` tuple. Step both each iteration and schedule
        them together.
    """
    muon_params = []
    adamw_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Embedding and LM-head tables stay on AdamW even though they are 2D.
        if param.ndim == 2 and "embed" not in name and "lm_head" not in name:
            muon_params.append(param)
        else:
            adamw_params.append(param)

    muon = Muon(muon_params, lr=muon_lr)
    adamw = AdamW(adamw_params, lr=adamw_lr, betas=betas, weight_decay=weight_decay)

    # Report how many tensors / scalar parameters each optimizer owns.
    n_muon = sum(param.numel() for param in muon_params)
    n_adamw = sum(param.numel() for param in adamw_params)
    print(f"[optim] Muon : {len(muon_params)} tensors, {n_muon:,} params")
    print(f"[optim] AdamW : {len(adamw_params)} tensors, {n_adamw:,} params")
    return muon, adamw
|
|
|
|
| |
| |
| |
def wsd_lr_multiplier(step: int, total_steps: int, warmup: int = 100,
                      decay_frac: float = 0.2) -> float:
    """Return the Warmup-Stable-Decay learning-rate multiplier for ``step``.

    The schedule has three phases:

    * warmup (``step < warmup``): rises linearly as ``step / warmup``;
    * stable: holds at 1.0 until the decay phase begins;
    * decay: falls linearly over the final ``decay_frac`` of
      ``total_steps`` and is clamped at 0.0.

    Multiply each optimizer's base learning rate by the returned value.
    """
    decay_start = int(total_steps * (1 - decay_frac))

    if step < warmup:
        multiplier = step / max(1, warmup)
    elif step < decay_start:
        multiplier = 1.0
    else:
        # max(1, ...) avoids dividing by zero when the decay phase is empty.
        decay_len = max(1, total_steps - decay_start)
        elapsed = step - decay_start
        multiplier = max(0.0, 1.0 - elapsed / decay_len)
    return multiplier
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Smoke test 1: orthogonalize a random wide matrix and report how far
    # its Gram matrix is from the identity (mean absolute deviation).
    torch.manual_seed(0)
    sample = torch.randn(384, 1024)
    ortho = zeropower_via_newtonschulz5(sample).float()
    gram = ortho @ ortho.T
    identity = torch.eye(ortho.size(0))
    err = (gram - identity).abs().mean().item()
    print(f"[ns] orthogonality error (lower=better): {err:.4f}")

    # Smoke test 2: print the WSD multiplier at a few points of a 6000-step
    # run, covering the warmup, stable and decay phases.
    total = 6000
    print("[wsd] step -> lr_mult")
    for s in (0, 50, 100, 1000, 4800, 5400, 5999):
        print(f" {s:>5} -> {wsd_lr_multiplier(s, total):.3f}")
|
|