"""
Muon optimizer for İvme — MomentUm Orthogonalized by Newton-schulz.
Muon orthogonalizes each 2D weight's momentum-smoothed gradient via a quintic
Newton-Schulz iteration before the update. It consistently beats AdamW on
transformer bodies, especially on reasoning-heavy benchmarks, at negligible
extra cost.
Standard practice (and what we use here): Muon for the 2D transformer matrices,
AdamW for everything else — embeddings, the LM head, and all 1D params (norms).
Since İvme ties its embeddings, the shared embed/head table goes to AdamW.
Reference: Keller Jordan's Muon (github.com/KellerJordan/Muon).
"""
from __future__ import annotations
import torch
from torch.optim import AdamW
# --------------------------------------------------------------------------- #
# Newton-Schulz orthogonalization
# --------------------------------------------------------------------------- #
@torch.no_grad()
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a 2D tensor with a quintic Newton-Schulz loop.

    The input is cast to bfloat16, divided by its norm (plus a small epsilon),
    and then refined `steps` times with the fixed polynomial
    ``X <- a*X + (b*A + c*A@A) @ X`` where ``A = X @ X.T``. No SVD is computed.

    Args:
        G: 2D tensor to orthogonalize. It is not modified.
        steps: Number of Newton-Schulz iterations to run.

    Returns:
        A bfloat16 tensor with the same shape as `G`.
    """
    assert G.ndim == 2
    coef_x, coef_gram, coef_gram_sq = (3.4445, -4.7750, 2.0315)

    # Iterate on the wide orientation so the Gram matrix X @ X.T is the
    # smaller of the two possible products; undo the transpose on the way out.
    is_tall = G.size(0) > G.size(1)
    X = G.bfloat16()
    if is_tall:
        X = X.T

    # Dividing by the norm bounds the spectral norm by 1 before iterating.
    X = X / (X.norm() + 1e-7)

    for _ in range(steps):
        gram = X @ X.T
        poly = coef_gram * gram + coef_gram_sq * (gram @ gram)
        X = coef_x * X + poly @ X

    return X.T if is_tall else X
# --------------------------------------------------------------------------- #
# Muon
# --------------------------------------------------------------------------- #
class Muon(torch.optim.Optimizer):
    """Momentum optimizer whose update is orthogonalized by Newton-Schulz.

    For every parameter with a gradient, a momentum buffer is accumulated,
    optionally combined with the current gradient (Nesterov), passed through
    `zeropower_via_newtonschulz5`, and applied with a shape-dependent scale.
    That helper asserts a 2D input, so every parameter handed to this
    optimizer must be a 2D tensor.

    Args:
        params: Iterable of parameters (or param-group dicts) to optimize.
        lr: Step size applied to the orthogonalized update.
        momentum: Decay factor of the momentum buffer.
        nesterov: If True, use ``grad + momentum * buf`` as the pre-update;
            otherwise use the buffer itself.
        ns_steps: Number of Newton-Schulz iterations per update.
    """

    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, matching the ``torch.optim.Optimizer.step`` contract.

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure typically needs autograd
            # to recompute gradients, so re-enable it just for that call.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(grad)
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(grad)
                update = grad.add(buf, alpha=momentum) if nesterov else buf
                update = zeropower_via_newtonschulz5(update, steps=ns_steps)
                # Scale so the update RMS roughly matches the parameter shape;
                # an orthogonalized matrix has spectral norm ~1 regardless of size.
                scale = max(1.0, update.size(0) / update.size(1)) ** 0.5
                p.add_(update.to(p.dtype), alpha=-lr * scale)

        return loss
# --------------------------------------------------------------------------- #
# Hybrid optimizer builder
# --------------------------------------------------------------------------- #
def build_optimizers(model, muon_lr=0.02, adamw_lr=3e-4, weight_decay=0.1,
                     betas=(0.9, 0.95)):
    """Partition the model's trainable parameters between Muon and AdamW.

    A parameter goes to Muon when it is 2D and its name contains neither
    "embed" nor "lm_head"; every other trainable parameter goes to AdamW.
    Parameters with ``requires_grad=False`` are ignored. A two-line summary
    of the split is printed.

    Args:
        model: Object exposing ``named_parameters()``.
        muon_lr: Learning rate for the Muon optimizer.
        adamw_lr: Learning rate for the AdamW optimizer.
        weight_decay: Weight decay passed to AdamW.
        betas: Beta coefficients passed to AdamW.

    Returns:
        Tuple ``(muon, adamw)``. Step both each iteration and schedule them
        together.
    """
    non_body_tags = ("embed", "lm_head")
    muon_params = []
    adamw_params = []

    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Only 2D weights outside the embedding / LM-head tables use Muon;
        # embeddings, the head, and 1D params (norms) stay on AdamW.
        if p.ndim == 2 and not any(tag in name for tag in non_body_tags):
            muon_params.append(p)
        else:
            adamw_params.append(p)

    muon = Muon(muon_params, lr=muon_lr)
    adamw = AdamW(adamw_params, lr=adamw_lr, betas=betas, weight_decay=weight_decay)

    n_muon = sum(p.numel() for p in muon_params)
    n_adamw = sum(p.numel() for p in adamw_params)
    print(f"[optim] Muon : {len(muon_params)} tensors, {n_muon:,} params")
    print(f"[optim] AdamW : {len(adamw_params)} tensors, {n_adamw:,} params")
    return muon, adamw
# --------------------------------------------------------------------------- #
# WSD learning-rate schedule
# --------------------------------------------------------------------------- #
def wsd_lr_multiplier(step: int, total_steps: int, warmup: int = 100,
                      decay_frac: float = 0.2) -> float:
    """Return the Warmup-Stable-Decay learning-rate multiplier for `step`.

    The schedule has three phases: a linear ramp ``step / warmup`` while
    ``step < warmup``, a constant 1.0 until the decay phase begins, and a
    linear ramp down to 0 over the last `decay_frac` of `total_steps`
    (clamped at 0 beyond the end). Multiply each optimizer's base lr by the
    returned value.

    Args:
        step: Current training step.
        total_steps: Total number of training steps.
        warmup: Number of warmup steps.
        decay_frac: Fraction of `total_steps` spent in the decay phase.

    Returns:
        The multiplier for the given step.
    """
    stable_end = int(total_steps * (1 - decay_frac))

    in_warmup = step < warmup
    if in_warmup:
        return step / max(1, warmup)

    in_stable = step < stable_end
    if in_stable:
        return 1.0

    # Fraction of the decay window already consumed; the floor of 1 on the
    # window length avoids dividing by zero when the window is empty.
    decay_len = max(1, total_steps - stable_end)
    consumed = (step - stable_end) / decay_len
    return max(0.0, 1.0 - consumed)
# --------------------------------------------------------------------------- #
# Self-test
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    # Check 1: the Newton-Schulz output for a wide random matrix should have
    # a Gram matrix (rows against rows) close to the identity.
    torch.manual_seed(0)
    wide = torch.randn(384, 1024)
    ortho = zeropower_via_newtonschulz5(wide).float()
    gram = ortho @ ortho.T
    identity = torch.eye(ortho.size(0))
    err = (gram - identity).abs().mean().item()
    print(f"[ns] orthogonality error (lower=better): {err:.4f}")

    # Check 2: print the WSD multiplier at steps covering the warmup,
    # stable, and decay phases of a 6000-step run.
    total = 6000
    print("[wsd] step -> lr_mult")
    for s in (0, 50, 100, 1000, 4800, 5400, 5999):
        print(f" {s:>5} -> {wsd_lr_multiplier(s, total):.3f}")
|