File size: 6,077 Bytes
0d72706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
Muon optimizer for İvme — MomentUm Orthogonalized by Newton-schulz.

Muon orthogonalizes each 2D weight's momentum-smoothed gradient via a quintic
Newton-Schulz iteration before the update. It consistently beats AdamW on
transformer bodies, especially on reasoning-heavy benchmarks, at negligible
extra cost.

Standard practice (and what we use here): Muon for the 2D transformer matrices,
AdamW for everything else — embeddings, the LM head, and all 1D params (norms).
Since İvme ties its embeddings, the shared embed/head table goes to AdamW.

Reference: Keller Jordan's Muon (github.com/KellerJordan/Muon).
"""

from __future__ import annotations

import torch
from torch.optim import AdamW


# --------------------------------------------------------------------------- #
# Newton-Schulz orthogonalization
# --------------------------------------------------------------------------- #
@torch.no_grad()
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize G with a quintic Newton-Schulz iteration.

    The coefficients (a, b, c) are tuned so the iteration pushes the singular
    values of G toward 1 without ever computing an SVD. Convergence is only
    approximate: the resulting singular values end up near 1 rather than
    exactly at 1, which is accepted in exchange for very few iterations.
    The iteration runs in bfloat16 for speed.

    Args:
        G: a matrix of shape ``(m, n)``, or a batch of matrices of shape
            ``(..., m, n)``; each trailing ``(m, n)`` slice is orthogonalized
            independently.
        steps: number of Newton-Schulz iterations to run.

    Returns:
        A bfloat16 tensor with the same shape as ``G``.

    Raises:
        ValueError: if ``G`` has fewer than 2 dimensions.
    """
    if G.ndim < 2:
        raise ValueError(
            f"expected a tensor with ndim >= 2, got shape {tuple(G.shape)}"
        )
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Iterate on the wide orientation so the Gram matrix X @ X^T is the
    # smaller of the two possible products.
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.transpose(-2, -1)
    # The Frobenius norm upper-bounds the spectral norm, so dividing by it
    # (per matrix) brings every singular value to <= 1 before iterating.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.transpose(-2, -1)
    return X


# --------------------------------------------------------------------------- #
# Muon
# --------------------------------------------------------------------------- #
class Muon(torch.optim.Optimizer):
    """SGD with momentum whose update is orthogonalized by Newton-Schulz.

    Only 2D weight matrices are supported; route embeddings, the LM head and
    1D parameters (norms) to another optimizer such as AdamW.

    Args:
        params: iterable of 2D tensors or param-group dicts.
        lr: learning rate.
        momentum: decay factor for the running gradient buffer.
        nesterov: if True, the update direction is ``grad + momentum * buf``
            (Nesterov-style look-ahead); otherwise it is ``buf`` itself.
        ns_steps: number of Newton-Schulz iterations per update.
    """

    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, following the ``torch.optim.Optimizer.step``
                contract. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            ValueError: if a parameter that has a gradient is not 2D.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                if g.ndim != 2:
                    # Orthogonalization and the shape-based scale below are only
                    # defined for matrices; fail before touching optimizer state.
                    raise ValueError(
                        f"Muon only supports 2D parameters, got shape {tuple(p.shape)}"
                    )
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(g)
                g = g.add(buf, alpha=momentum) if nesterov else buf
                g = zeropower_via_newtonschulz5(g, steps=ns_steps)
                # An orthogonalized matrix has spectral norm ~1 regardless of
                # size; tall matrices (rows > cols) get a sqrt(rows/cols) boost.
                scale = max(1.0, g.size(0) / g.size(1)) ** 0.5
                p.add_(g.to(p.dtype), alpha=-lr * scale)
        return loss


# --------------------------------------------------------------------------- #
# Hybrid optimizer builder
# --------------------------------------------------------------------------- #
def build_optimizers(model, muon_lr=0.02, adamw_lr=3e-4, weight_decay=0.1,
                     betas=(0.9, 0.95)):
    """Partition İvme's trainable params between Muon and AdamW.

    A parameter goes to Muon when it is 2D and its name contains neither
    "embed" nor "lm_head"; every other trainable parameter (embeddings, LM
    head, 1D norms, anything non-2D) goes to AdamW. Parameters with
    ``requires_grad=False`` are skipped. ``weight_decay`` is applied to every
    AdamW parameter, including the 1D ones. A one-line summary per optimizer
    is printed.

    NOTE(review): with tied embeddings, ``named_parameters()`` yields the
    shared tensor once, under a single name; this assumes that name contains
    "embed" or "lm_head" — confirm against the model.

    Returns (muon, adamw). Step both each iteration; schedule both together.
    """
    def _routes_to_muon(param_name, tensor):
        # Body matrices only: 2D and not part of the embedding / output head.
        if tensor.ndim != 2:
            return False
        return not any(tag in param_name for tag in ("embed", "lm_head"))

    trainable = [(n, t) for n, t in model.named_parameters() if t.requires_grad]
    muon_params = [t for n, t in trainable if _routes_to_muon(n, t)]
    adamw_params = [t for n, t in trainable if not _routes_to_muon(n, t)]

    muon = Muon(muon_params, lr=muon_lr)
    adamw = AdamW(adamw_params, lr=adamw_lr, betas=betas, weight_decay=weight_decay)

    for label, bucket in (("Muon  ", muon_params), ("AdamW ", adamw_params)):
        total = sum(t.numel() for t in bucket)
        print(f"[optim] {label}: {len(bucket)} tensors, {total:,} params")
    return muon, adamw


# --------------------------------------------------------------------------- #
# WSD learning-rate schedule
# --------------------------------------------------------------------------- #
def wsd_lr_multiplier(step: int, total_steps: int, warmup: int = 100,
                      decay_frac: float = 0.2) -> float:
    """Warmup-Stable-Decay multiplier in [0, 1].

    Linear warmup over the first `warmup` steps -> constant stable phase ->
    linear decay to 0 over the final `decay_frac` of `total_steps` (reaching
    exactly 0 at ``step == total_steps``). Multiply each optimizer's base lr
    by this value.

    When the warmup and decay windows overlap (short runs), the smaller of the
    two ramps is used. The decay therefore still happens and the schedule
    stays continuous.

    Args:
        step: current optimizer step (0-based).
        total_steps: total number of training steps.
        warmup: number of linear-warmup steps; <= 0 disables warmup.
        decay_frac: fraction of `total_steps` spent in the final decay.

    Returns:
        The lr multiplier, clamped to [0, 1].
    """
    decay_start = int(total_steps * (1 - decay_frac))
    # Warmup ramp: 0 -> 1 across the first `warmup` steps, then held at 1.
    warm = min(1.0, step / warmup) if warmup > 0 else 1.0
    # Decay ramp: 1 until decay_start, then linearly down to 0 at total_steps.
    if step < decay_start:
        decay = 1.0
    else:
        progress = (step - decay_start) / max(1, total_steps - decay_start)
        decay = 1.0 - progress
    return max(0.0, min(warm, decay))


# --------------------------------------------------------------------------- #
# Self-test
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    # Check 1: Newton-Schulz output should be close to (semi-)orthogonal.
    torch.manual_seed(0)
    wide = torch.randn(384, 1024)
    ortho = zeropower_via_newtonschulz5(wide).float()
    # A wide matrix with orthonormal rows satisfies ortho @ ortho.T ~= I.
    gram = ortho @ ortho.T
    identity = torch.eye(ortho.size(0))
    err = (gram - identity).abs().mean().item()
    print(f"[ns] orthogonality error (lower=better): {err:.4f}")

    # Check 2: print the WSD multiplier at a few representative steps.
    total = 6000
    print("[wsd] step -> lr_mult")
    for s in (0, 50, 100, 1000, 4800, 5400, 5999):
        print(f"      {s:>5} -> {wsd_lr_multiplier(s, total):.3f}")