File size: 6,357 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""Muon — Momentum-Updated Newton-Schulz orthogonalised optimiser.

Jordan, Bernstein et al. (Oct 2024). Used to train Kimi K2 (1T MoE,
15.5T tokens, zero instabilities) — but Kimi K2 used MuonClip (the
QK-rescaling stability fix) on top. This implementation omits QK-Clip
since at sub-frontier scale plain Muon is empirically stable.

The core idea: SGD's momentum update (m = mu * m + g; W <- W - lr * m)
is fine, except it can leave m anisotropic — concentrated on the top
singular directions. Muon orthogonalises m via a few Newton-Schulz
iterations before applying it, so each step contributes equally across
all singular directions.

Algorithm (per 2D weight matrix, applied only to weights with ndim >= 2):
    1. m_t = momentum * m_{t-1} + g_t
    2. u_t = NewtonSchulz5(m_t)          # orthogonalise: u_t ≈ m_t @ (m_t^T m_t)^{-1/2}
    3. W_t = W_{t-1} - lr * sqrt(max(d_in, d_out) / min(d_in, d_out)) * u_t

For 1D parameters (biases, norm scales) and for embedding / unembedding
matrices Muon is *not* recommended — fall back to AdamW for those. The
convention in the Muon papers is to declare two parameter groups: hidden
2D weights -> Muon, everything-else -> AdamW. We follow that here.

Reference: https://kellerjordan.github.io/posts/muon/
"""
from __future__ import annotations

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer


@torch.no_grad()
def _newton_schulz5(g: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor:
    """Approximately orthogonalise ``g`` with a quintic Newton-Schulz iteration.

    The result approximates ``g @ (g^T g)^{-1/2}``: the singular vectors of
    ``g`` are kept while its singular values are driven to roughly 1.

    Constants from the Muon reference implementation; tuned so that the
    iteration converges to the correct orthogonalisation in <=5 steps for
    typical weight-matrix singular-value distributions.

    Parameters
    ----------
    g : tensor of shape ``(..., rows, cols)``. Any leading dimensions are
        treated as a batch; every matrix is normalised and orthogonalised
        independently of the others.
    steps : int, default 5. Number of Newton-Schulz iterations.
    eps : float, default 1e-7. Added to the Frobenius norm so an all-zero
        input does not divide by zero.

    Returns
    -------
    Tensor with the same shape and dtype as ``g``. The iteration itself is
    carried out in float32.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)
    x = g.float()

    # Work on the "wide" orientation (rows <= cols) so the Gram matrix
    # x @ x^T is the smaller of the two possible products; undo at the end.
    transposed = g.size(-2) > g.size(-1)
    if transposed:
        x = x.transpose(-2, -1)

    # Per-matrix Frobenius normalisation: bounds every singular value by 1,
    # which the iteration needs. Reducing over the last two dims only keeps
    # batched matrices from sharing one global norm.
    x = x / (x.norm(dim=(-2, -1), keepdim=True) + eps)

    for _ in range(steps):
        gram = x @ x.transpose(-2, -1)
        # a*x + b*(gram @ x) + c*(gram @ gram @ x), factored so that each
        # step needs three matmuls instead of four.
        poly = b * gram + c * (gram @ gram)
        x = a * x + poly @ x

    if transposed:
        x = x.transpose(-2, -1)
    return x.to(g.dtype)


class Muon(Optimizer):
    """Muon optimiser for 2D+ parameters; pair with AdamW for 1D params.

    Each step accumulates a momentum buffer, orthogonalises the (optionally
    Nesterov-corrected) update with Newton-Schulz iterations, and applies it
    with a shape-aware step size and decoupled weight decay.

    Parameters
    ----------
    params : iterable of 2D+ tensors only.
    lr : float, default 0.02. Larger than AdamW because the orthogonalised
         update has unit operator-norm, not unit element-norm.
    momentum : float, default 0.95.
    weight_decay : float, default 0.0. Decoupled decay; must be non-negative.
    nesterov : bool, default True. Nesterov-flavoured momentum lookahead.
    ns_steps : int, default 5. Number of Newton-Schulz iterations.

    Raises
    ------
    ValueError
        If ``lr`` is not positive, ``momentum`` is outside ``[0, 1)``,
        ``weight_decay`` is negative, or any parameter has fewer than two
        dimensions.
    """

    def __init__(
        self,
        params,
        lr: float = 0.02,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        nesterov: bool = True,
        ns_steps: int = 5,
    ) -> None:
        if lr <= 0.0:
            raise ValueError(f"lr must be positive, got {lr}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"momentum must be in [0, 1), got {momentum}")
        if weight_decay < 0.0:
            raise ValueError(
                f"weight_decay must be non-negative, got {weight_decay}"
            )
        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
            ns_steps=ns_steps,
        )
        super().__init__(params, defaults)
        # Orthogonalisation is only defined for matrices, so reject 1D
        # tensors up front instead of failing inside step().
        for group in self.param_groups:
            for p in group["params"]:
                if p.dim() < 2:
                    raise ValueError(
                        f"Muon expects 2D+ parameters; got shape {tuple(p.shape)}. "
                        "Pair Muon with AdamW for 1D params (biases, norms)."
                    )

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimisation step.

        Parameters
        ----------
        closure : optional callable that re-evaluates the model and returns
            the loss. It is invoked with gradients enabled.

        Returns
        -------
        The value returned by ``closure``, or ``None`` when no closure is
        given. Parameters whose ``grad`` is ``None`` are skipped.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            mom = group["momentum"]
            wd = group["weight_decay"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad

                state = self.state[p]
                if "m" not in state:
                    state["m"] = torch.zeros_like(p)
                m = state["m"]
                m.mul_(mom).add_(g)

                # Nesterov look-ahead is g + mom * m (gradient plus the decayed
                # buffer), matching torch.optim.SGD and the Muon reference.
                # Out-of-place add so neither p.grad nor the buffer is changed.
                update = g.add(m, alpha=mom) if nesterov else m

                # Newton-Schulz orthogonalisation; flatten any 3D+ into 2D first.
                if update.dim() > 2:
                    update_2d = update.reshape(update.size(0), -1)
                else:
                    update_2d = update
                u = _newton_schulz5(update_2d, steps=ns_steps)
                u = u.reshape(update.shape)

                # Shape-aware LR scaling: sqrt(max(rows, cols) / min(rows, cols)).
                # Uses the dimensions of the matrix that was actually
                # orthogonalised, so flattened 3D+ weights (e.g. conv kernels)
                # are scaled by their true 2D shape rather than by p.size(-1).
                rows, cols = update_2d.shape
                shape_scale = (max(rows, cols) / min(rows, cols)) ** 0.5

                # Decoupled weight decay: shrink the weights directly rather
                # than adding wd * p to the gradient.
                if wd != 0.0:
                    p.mul_(1 - lr * wd)
                p.add_(u, alpha=-lr * shape_scale)

        return loss


def split_params_for_muon(model: torch.nn.Module
                          ) -> tuple[list[torch.nn.Parameter], list[torch.nn.Parameter]]:
    """Split a model's trainable parameters into (muon_params, adamw_params).

    Convention from the Muon paper: 2D+ weights -> Muon; biases, norm scales,
    embeddings, unembed -> AdamW. We treat embeddings and unembed (lm_head) as
    AdamW-managed because their geometry (token-shaped, sparse gradients) is
    poorly suited to orthogonalisation.

    Embedding-like parameters are recognised by name: any parameter whose
    qualified name contains "embed" (which also covers "unembed" and
    "tok_embed") or "lm_head", compared case-insensitively. Parameters with
    ``requires_grad=False`` appear in neither list.

    Parameters
    ----------
    model : module whose ``named_parameters()`` are partitioned.

    Returns
    -------
    (muon_params, adamw_params) : two lists of parameters, each in
        ``named_parameters()`` order.
    """
    # "lm_head" is listed explicitly: it is the unembedding matrix but its
    # name does not contain "embed", so a plain "embed" test would miss it.
    adamw_name_hints = ("embed", "lm_head")

    muon_params: list[torch.nn.Parameter] = []
    adamw_params: list[torch.nn.Parameter] = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        lowered = name.lower()
        token_shaped = any(hint in lowered for hint in adamw_name_hints)
        if p.dim() >= 2 and not token_shaped:
            muon_params.append(p)
        else:
            adamw_params.append(p)
    return muon_params, adamw_params