"""Muon — MomentUm Orthogonalized by Newton-Schulz optimiser.
Jordan, Bernstein et al. (Oct 2024). Used to train Kimi K2 (1T MoE,
15.5T tokens, zero instabilities) — but Kimi K2 used MuonClip (the
QK-rescaling stability fix) on top. This implementation omits QK-Clip
since at sub-frontier scale plain Muon is empirically stable.
The core idea: SGD's momentum update (m = mu * m + g; W <- W - lr * m)
is fine, except it can leave m anisotropic — concentrated on the top
singular directions. Muon orthogonalises m via a few Newton-Schulz
iterations before applying it, so each step contributes equally across
all singular directions.
Algorithm (per 2D weight matrix, applied only to weights with ndim >= 2):
1. m_t = momentum * m_{t-1} + g_t
2. u_t = NewtonSchulz5(m_t) # orthogonalise: u_t ≈ m_t @ (m_t^T m_t)^{-1/2}
3. W_t = W_{t-1} - lr * sqrt(max(d_in, d_out) / min(d_in, d_out)) * u_t
Muon is *not* recommended for 1D parameters (biases, norm scales), nor for
embedding / unembedding matrices even though those are 2D — fall back to
AdamW for all of them. The convention in the Muon papers is to declare two
parameter groups: hidden 2D weights -> Muon, everything else -> AdamW. We
follow that here.
Reference: https://kellerjordan.github.io/posts/muon/
"""
from __future__ import annotations
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
@torch.no_grad()
def _newton_schulz5(g: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor:
"""Approximate g @ (g^T g)^{-1/2} via 5 Newton-Schulz iterations.
Constants from the Muon reference implementation; tuned so that the
iteration converges to the correct orthogonalisation in <=5 steps for
typical weight-matrix singular-value distributions.
"""
a, b, c = (3.4445, -4.7750, 2.0315)
x = g.float()
if g.size(-2) > g.size(-1):
# Newton-Schulz expects "tall" matrix; transpose then transpose back.
x = x.transpose(-2, -1)
transposed = True
else:
transposed = False
x = x / (x.norm() + eps) # ||x|| = 1 entering the iteration
for _ in range(steps):
y = x @ x.transpose(-2, -1)
x = a * x + b * y @ x + c * y @ y @ x
if transposed:
x = x.transpose(-2, -1)
return x.to(g.dtype)
class Muon(Optimizer):
    """Muon optimiser for 2D+ parameters; pair with AdamW for 1D params.

    Parameters
    ----------
    params : iterable of 2D+ tensors (or param-group dicts) only.
    lr : float, default 0.02. Larger than AdamW because the orthogonalised
        update has unit operator-norm, not unit element-norm.
    momentum : float, default 0.95. Must lie in [0, 1).
    weight_decay : float, default 0.0. Decoupled (AdamW-style) decay applied
        as ``p *= 1 - lr * weight_decay`` before the update. Must be >= 0.
    nesterov : bool, default True. If True the orthogonalised direction is
        ``g + momentum * m`` (Nesterov lookahead); otherwise it is ``m``.
    ns_steps : int, default 5. Number of Newton-Schulz iterations (>= 1).

    Raises
    ------
    ValueError
        On an invalid hyper-parameter, or if any parameter has fewer than
        two dimensions.
    """

    def __init__(
        self,
        params,
        lr: float = 0.02,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        nesterov: bool = True,
        ns_steps: int = 5,
    ) -> None:
        if lr <= 0.0:
            raise ValueError(f"lr must be positive, got {lr}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"momentum must be in [0, 1), got {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"weight_decay must be non-negative, got {weight_decay}")
        if ns_steps < 1:
            raise ValueError(f"ns_steps must be >= 1, got {ns_steps}")
        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
            ns_steps=ns_steps,
        )
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group["params"]:
                if p.dim() < 2:
                    raise ValueError(
                        f"Muon expects 2D+ parameters; got shape {tuple(p.shape)}. "
                        "Pair Muon with AdamW for 1D params (biases, norms)."
                    )

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimisation step.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss; it is run with
            gradients enabled.

        Returns
        -------
        The value returned by ``closure``, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group["lr"]
            mom = group["momentum"]
            wd = group["weight_decay"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if "m" not in state:
                    state["m"] = torch.zeros_like(p)
                m = state["m"]
                # m_t = momentum * m_{t-1} + g_t
                m.mul_(mom).add_(g)
                # Nesterov lookahead is g_t + momentum * m_t (same form as
                # torch.optim.SGD and the reference Muon implementation).
                update = g.add(m, alpha=mom) if nesterov else m
                # Orthogonalise a 2D view: e.g. (out, in, kh, kw) conv
                # weights become (out, in * kh * kw). No-op for 2D params.
                update_2d = update.reshape(update.size(0), -1)
                u = _newton_schulz5(update_2d, steps=ns_steps)
                # Shape-aware LR scaling sqrt(max(rows, cols) / min(rows, cols)),
                # computed from the matrix that was actually orthogonalised so
                # that 3D+ weights are scaled by their flattened fan dims.
                rows, cols = u.shape
                shape_scale = (max(rows, cols) / min(rows, cols)) ** 0.5
                if wd != 0.0:
                    # Decoupled weight decay, applied before the update.
                    p.mul_(1 - lr * wd)
                p.add_(u.reshape(p.shape), alpha=-lr * shape_scale)
        return loss
def split_params_for_muon(
    model: torch.nn.Module,
) -> tuple[list[torch.nn.Parameter], list[torch.nn.Parameter]]:
    """Split a model's trainable parameters into (muon_params, adamw_params).

    Convention from the Muon paper: hidden 2D+ weights -> Muon; biases, norm
    scales, embeddings and the unembed (lm_head) -> AdamW. Embeddings and the
    unembed are kept on AdamW because their geometry (token-shaped, sparse
    gradients) is poorly suited to orthogonalisation.

    A parameter is treated as an embedding / unembed when it belongs to a
    ``torch.nn.Embedding`` module, or when its qualified name contains
    ``"embed"`` (covers ``tok_embed``, ``unembed``, ...) or ``"lm_head"``
    (case-insensitive). Parameters with ``requires_grad=False`` are skipped.

    Returns
    -------
    (muon_params, adamw_params), each in ``model.named_parameters()`` order.
    """
    embedding_param_ids = {
        id(p)
        for module in model.modules()
        if isinstance(module, torch.nn.Embedding)
        for p in module.parameters(recurse=False)
    }
    muon_params: list[torch.nn.Parameter] = []
    adamw_params: list[torch.nn.Parameter] = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        lowered = name.lower()
        is_embedding = (
            id(p) in embedding_param_ids
            or "embed" in lowered
            or "lm_head" in lowered
        )
        if p.dim() >= 2 and not is_embedding:
            muon_params.append(p)
        else:
            adamw_params.append(p)
    return muon_params, adamw_params