ereniko's picture
Upload muon.py with huggingface_hub
0d72706 verified
"""
Muon optimizer for İvme — MomentUm Orthogonalized by Newton-Schulz.
Muon orthogonalizes each 2D weight's momentum-smoothed gradient via a quintic
Newton-Schulz iteration before the update. It consistently beats AdamW on
transformer bodies, especially on reasoning-heavy benchmarks, at negligible
extra cost.
Standard practice (and what we use here): Muon for the 2D transformer matrices,
AdamW for everything else — embeddings, the LM head, and all 1D params (norms).
Since İvme ties its embeddings, the shared embed/head table goes to AdamW.
Reference: Keller Jordan's Muon (github.com/KellerJordan/Muon).
"""
from __future__ import annotations
import torch
from torch.optim import AdamW
# --------------------------------------------------------------------------- #
# Newton-Schulz orthogonalization
# --------------------------------------------------------------------------- #
@torch.no_grad()
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize ``G`` with a quintic Newton-Schulz iteration.

    The coefficients (a, b, c) are tuned so each iteration pushes the singular
    values of ``G`` toward 1 without ever computing an SVD. The result is only
    approximately orthogonal (singular values end up near 1, not exactly at 1).
    The iteration runs in bfloat16 for speed.

    Args:
        G: A matrix of shape ``(m, n)``, or a batch of matrices of shape
            ``(..., m, n)``; each trailing matrix is orthogonalized
            independently.
        steps: Number of Newton-Schulz iterations.

    Returns:
        A bfloat16 tensor with the same shape as ``G``.

    Raises:
        ValueError: If ``G`` has fewer than two dimensions.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if G.ndim < 2:
        raise ValueError(
            "zeropower_via_newtonschulz5 expects a tensor with ndim >= 2, "
            f"got ndim={G.ndim}"
        )
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Work in the wide orientation (rows <= cols) so X @ X^T is the smaller
    # Gram matrix; transpose back at the end.
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT
    # The Frobenius norm upper-bounds the spectral norm, so dividing by it puts
    # every singular value in [0, 1] before iterating (per matrix in a batch).
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X
# --------------------------------------------------------------------------- #
# Muon
# --------------------------------------------------------------------------- #
class Muon(torch.optim.Optimizer):
    """Muon: momentum SGD whose update is orthogonalized by Newton-Schulz.

    Meant for the weight matrices of the transformer body; embeddings, the LM
    head and 1D parameters should go to a different optimizer (see
    ``build_optimizers``). Parameters with more than two dimensions are
    flattened to ``(shape[0], -1)`` before orthogonalization.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate (must be >= 0).
        momentum: Momentum coefficient (must be >= 0).
        nesterov: If True, use Nesterov-style momentum.
        ns_steps: Number of Newton-Schulz iterations per update.
        weight_decay: Decoupled weight decay; each step multiplies the
            parameter by ``1 - lr * weight_decay`` before the update.
            Defaults to 0 (no decay), matching the original behavior.

    Raises:
        ValueError: On a negative ``lr``, ``momentum`` or ``weight_decay``.
    """

    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5,
                 weight_decay=0.0):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov,
                        ns_steps=ns_steps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, following the ``torch.optim.Optimizer.step`` contract.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            ValueError: If a parameter with a gradient has fewer than 2 dims.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd for backward.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]
            # .get: groups restored from older state_dicts lack this key.
            weight_decay = group.get("weight_decay", 0.0)
            for p in group["params"]:
                if p.grad is None:
                    continue
                if p.ndim < 2:
                    raise ValueError(
                        "Muon only supports parameters with ndim >= 2; got a "
                        f"parameter of shape {tuple(p.shape)}. Route 1D "
                        "parameters to AdamW instead."
                    )
                g = p.grad
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(g)
                g = g.add(buf, alpha=momentum) if nesterov else buf
                # Treat >2D weights (e.g. conv kernels) as (out, everything else).
                if g.ndim > 2:
                    g = g.reshape(g.size(0), -1)
                update = zeropower_via_newtonschulz5(g, steps=ns_steps)
                # An orthogonalized matrix has spectral norm ~1 regardless of
                # size; boost tall matrices so the update RMS tracks the shape.
                scale = max(1.0, update.size(0) / update.size(1)) ** 0.5
                if weight_decay != 0:
                    p.mul_(1.0 - lr * weight_decay)
                p.add_(update.reshape(p.shape).to(p.dtype), alpha=-lr * scale)
        return loss
# --------------------------------------------------------------------------- #
# Hybrid optimizer builder
# --------------------------------------------------------------------------- #
def build_optimizers(model, muon_lr=0.02, adamw_lr=3e-4, weight_decay=0.1,
                     betas=(0.9, 0.95)):
    """Split İvme's params into Muon (2D transformer matrices) and AdamW (rest).

    Routing, for every parameter with ``requires_grad``:
      * 2D tensors whose name contains neither "embed" nor "lm_head" -> Muon.
      * Remaining tensors with ndim >= 2 (embeddings, LM head, ...) -> AdamW
        with ``weight_decay``.
      * Tensors with ndim < 2 (norm gains, biases, scalars) -> AdamW with no
        weight decay, so decay does not pull normalization gains toward zero.

    Args:
        model: Module whose ``named_parameters()`` are split.
        muon_lr: Learning rate for Muon.
        adamw_lr: Learning rate for AdamW.
        weight_decay: AdamW weight decay for its ndim >= 2 tensors.
        betas: AdamW betas.

    Returns:
        (muon, adamw). Step both each iteration; schedule both together.
    """
    muon_params, adamw_decay, adamw_no_decay = [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        is_body_matrix = (
            p.ndim == 2
            and "embed" not in name
            and "lm_head" not in name
        )
        if is_body_matrix:
            muon_params.append(p)
        elif p.ndim >= 2:
            adamw_decay.append(p)
        else:
            adamw_no_decay.append(p)
    adamw_params = adamw_decay + adamw_no_decay

    muon = Muon(muon_params, lr=muon_lr)
    groups = [
        {"params": adamw_decay, "weight_decay": weight_decay},
        {"params": adamw_no_decay, "weight_decay": 0.0},
    ]
    # Drop empty groups; if nothing is left AdamW raises on the empty list,
    # just as the single-group version did.
    adamw = AdamW([g for g in groups if g["params"]], lr=adamw_lr, betas=betas,
                  weight_decay=weight_decay)

    n_muon = sum(p.numel() for p in muon_params)
    n_adamw = sum(p.numel() for p in adamw_params)
    print(f"[optim] Muon : {len(muon_params)} tensors, {n_muon:,} params")
    print(f"[optim] AdamW : {len(adamw_params)} tensors, {n_adamw:,} params")
    return muon, adamw
# --------------------------------------------------------------------------- #
# WSD learning-rate schedule
# --------------------------------------------------------------------------- #
def wsd_lr_multiplier(step: int, total_steps: int, warmup: int = 100,
                      decay_frac: float = 0.2) -> float:
    """Warmup-Stable-Decay learning-rate multiplier in [0, 1].

    Linear warmup over the first ``warmup`` steps -> constant 1.0 -> linear
    decay to 0 over the final ``decay_frac`` of ``total_steps``. Multiply each
    optimizer's base lr by this value.

    The warmup and decay ramps are combined with ``min``. On runs too short for
    the phases to be disjoint (warmup extending past the decay start), the
    decay therefore still applies instead of being masked by warmup. When the
    phases do not overlap, the result equals the plain piecewise schedule.

    Args:
        step: Current optimizer step.
        total_steps: Planned total number of steps.
        warmup: Number of linear-warmup steps.
        decay_frac: Fraction of ``total_steps`` spent in the final decay.

    Returns:
        The multiplier, clamped to [0, 1].
    """
    decay_start = int(total_steps * (1 - decay_frac))
    warm = step / max(1, warmup) if step < warmup else 1.0
    if step < decay_start:
        decay = 1.0
    else:
        # Linear decay over the final decay_frac of steps.
        progress = (step - decay_start) / max(1, total_steps - decay_start)
        decay = 1.0 - progress
    # Clamp: negative steps or steps past the end must not leave [0, 1].
    return max(0.0, min(warm, decay))
# --------------------------------------------------------------------------- #
# Self-test
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    # Check 1: the Newton-Schulz output should be close to orthogonal.
    torch.manual_seed(0)
    wide = torch.randn(384, 1024)
    ortho = zeropower_via_newtonschulz5(wide).float()
    # Rows of a wide input come out (approximately) orthonormal,
    # so ortho @ ortho.T should sit near the identity.
    gram = ortho @ ortho.T
    identity = torch.eye(ortho.size(0))
    err = (gram - identity).abs().mean().item()
    print(f"[ns] orthogonality error (lower=better): {err:.4f}")

    # Check 2: print the WSD multiplier at a few representative steps.
    total = 6000
    print("[wsd] step -> lr_mult")
    for s in (0, 50, 100, 1000, 4800, 5400, 5999):
        print(f" {s:>5} -> {wsd_lr_multiplier(s, total):.3f}")