ereniko commited on
Commit
0d72706
·
verified ·
1 Parent(s): b792941

Upload muon.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. muon.py +151 -0
muon.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Muon optimizer for İvme — MomentUm Orthogonalized by Newton-schulz.
3
+
4
+ Muon orthogonalizes each 2D weight's momentum-smoothed gradient via a quintic
5
+ Newton-Schulz iteration before the update. It consistently beats AdamW on
6
+ transformer bodies, especially on reasoning-heavy benchmarks, at negligible
7
+ extra cost.
8
+
9
+ Standard practice (and what we use here): Muon for the 2D transformer matrices,
10
+ AdamW for everything else — embeddings, the LM head, and all 1D params (norms).
11
+ Since İvme ties its embeddings, the shared embed/head table goes to AdamW.
12
+
13
+ Reference: Keller Jordan's Muon (github.com/KellerJordan/Muon).
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import torch
19
+ from torch.optim import AdamW
20
+
21
+
22
+ # --------------------------------------------------------------------------- #
23
+ # Newton-Schulz orthogonalization
24
+ # --------------------------------------------------------------------------- #
25
+ @torch.no_grad()
26
+ def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
27
+ """Compute an approximate orthogonalization of G via a quintic NS iteration.
28
+
29
+ The coefficients (a, b, c) are tuned so the iteration pushes the singular
30
+ values of G toward 1 without ever computing an SVD. Runs in bf16 for speed.
31
+ """
32
+ assert G.ndim == 2
33
+ a, b, c = (3.4445, -4.7750, 2.0315)
34
+ X = G.bfloat16()
35
+ transposed = G.size(0) > G.size(1)
36
+ if transposed:
37
+ X = X.T
38
+ # Normalize so the spectral norm is <= 1 before iterating.
39
+ X = X / (X.norm() + 1e-7)
40
+ for _ in range(steps):
41
+ A = X @ X.T
42
+ B = b * A + c * (A @ A)
43
+ X = a * X + B @ X
44
+ if transposed:
45
+ X = X.T
46
+ return X
47
+
48
+
49
+ # --------------------------------------------------------------------------- #
50
+ # Muon
51
+ # --------------------------------------------------------------------------- #
52
+ class Muon(torch.optim.Optimizer):
53
+ def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
54
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
55
+ super().__init__(params, defaults)
56
+
57
+ @torch.no_grad()
58
+ def step(self):
59
+ for group in self.param_groups:
60
+ lr = group["lr"]
61
+ momentum = group["momentum"]
62
+ nesterov = group["nesterov"]
63
+ ns_steps = group["ns_steps"]
64
+ for p in group["params"]:
65
+ if p.grad is None:
66
+ continue
67
+ g = p.grad
68
+ state = self.state[p]
69
+ if "momentum_buffer" not in state:
70
+ state["momentum_buffer"] = torch.zeros_like(g)
71
+ buf = state["momentum_buffer"]
72
+ buf.mul_(momentum).add_(g)
73
+ g = g.add(buf, alpha=momentum) if nesterov else buf
74
+ g = zeropower_via_newtonschulz5(g, steps=ns_steps)
75
+ # Scale so the update RMS roughly matches the parameter shape;
76
+ # an orthogonalized matrix has spectral norm ~1 regardless of size.
77
+ scale = max(1.0, g.size(0) / g.size(1)) ** 0.5
78
+ p.add_(g.to(p.dtype), alpha=-lr * scale)
79
+
80
+
81
+ # --------------------------------------------------------------------------- #
82
+ # Hybrid optimizer builder
83
+ # --------------------------------------------------------------------------- #
84
+ def build_optimizers(model, muon_lr=0.02, adamw_lr=3e-4, weight_decay=0.1,
85
+ betas=(0.9, 0.95)):
86
+ """Split İvme's params into Muon (2D transformer matrices) and AdamW (rest).
87
+
88
+ Returns (muon, adamw). Step both each iteration; schedule both together.
89
+ """
90
+ muon_params, adamw_params = [], []
91
+ for name, p in model.named_parameters():
92
+ if not p.requires_grad:
93
+ continue
94
+ # 2D weights inside transformer blocks -> Muon.
95
+ # Embeddings, LM head, and all 1D params (norms) -> AdamW.
96
+ is_body_matrix = (
97
+ p.ndim == 2
98
+ and "embed" not in name
99
+ and "lm_head" not in name
100
+ )
101
+ (muon_params if is_body_matrix else adamw_params).append(p)
102
+
103
+ muon = Muon(muon_params, lr=muon_lr)
104
+ adamw = AdamW(adamw_params, lr=adamw_lr, betas=betas, weight_decay=weight_decay)
105
+
106
+ n_muon = sum(p.numel() for p in muon_params)
107
+ n_adamw = sum(p.numel() for p in adamw_params)
108
+ print(f"[optim] Muon : {len(muon_params)} tensors, {n_muon:,} params")
109
+ print(f"[optim] AdamW : {len(adamw_params)} tensors, {n_adamw:,} params")
110
+ return muon, adamw
111
+
112
+
113
+ # --------------------------------------------------------------------------- #
114
+ # WSD learning-rate schedule
115
+ # --------------------------------------------------------------------------- #
116
+ def wsd_lr_multiplier(step: int, total_steps: int, warmup: int = 100,
117
+ decay_frac: float = 0.2) -> float:
118
+ """Warmup-Stable-Decay multiplier in [0, 1].
119
+
120
+ Linear warmup -> constant stable phase -> linear decay to ~0 over the final
121
+ `decay_frac` of training. Multiply each optimizer's base lr by this value.
122
+ """
123
+ decay_start = int(total_steps * (1 - decay_frac))
124
+ if step < warmup:
125
+ return step / max(1, warmup)
126
+ if step < decay_start:
127
+ return 1.0
128
+ # Linear decay over the final decay_frac of steps.
129
+ progress = (step - decay_start) / max(1, total_steps - decay_start)
130
+ return max(0.0, 1.0 - progress)
131
+
132
+
133
+ # --------------------------------------------------------------------------- #
134
+ # Self-test
135
+ # --------------------------------------------------------------------------- #
136
+ if __name__ == "__main__":
137
+ # 1) Newton-Schulz should produce a near-orthogonal matrix.
138
+ torch.manual_seed(0)
139
+ G = torch.randn(384, 1024)
140
+ Q = zeropower_via_newtonschulz5(G).float()
141
+ # For a wide matrix, Q @ Q.T should be close to identity.
142
+ I = Q @ Q.T
143
+ err = (I - torch.eye(Q.size(0))).abs().mean().item()
144
+ print(f"[ns] orthogonality error (lower=better): {err:.4f}")
145
+
146
+ # 2) WSD schedule shape.
147
+ total = 6000
148
+ pts = [0, 50, 100, 1000, 4800, 5400, 5999]
149
+ print("[wsd] step -> lr_mult")
150
+ for s in pts:
151
+ print(f" {s:>5} -> {wsd_lr_multiplier(s, total):.3f}")