File size: 16,306 Bytes
c383594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c475135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c383594
 
c475135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c383594
 
 
c475135
c383594
c475135
 
c383594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c475135
 
 
 
 
 
 
 
 
 
 
 
c383594
 
 
 
 
 
c475135
 
 
 
 
 
 
 
c383594
c475135
c383594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).

Extracted verbatim from train.py (W1 modularization). Semantics unchanged.

F1-F15 state preserved:
- F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
  step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
  fresh. Persistent copies of param storage would be mutated between forward
  passes (via lerp_/sub_ on stacked tensors that share storage with params),
  triggering "modified in-place" errors on grad_accum=2 backwards.
- F11/F15: `@torch.compile` on `adamw_step_fused` / `muon_step_fused` intact.
- F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
  dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
  stream-capture deadlock. See .omc/muon_compile_bug.md for the full
  investigation.
"""

from __future__ import annotations

import os

import torch

# HYDRA_FUSED_ADAMW=1 (default) -> vectorized torch._fused_adamw_ kernel.
# Any value other than the exact string "1" disables the fused kernel, so
# MuonAdamW._step_adamw uses the per-parameter adamw_step_fused fallback.
_HYDRA_FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
# Feature probe: the fused path is only taken when this torch build exposes
# the private torch._fused_adamw_ op.
_HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")


# One (a, b, c) coefficient triple per orthogonalization iteration.
# muon_step_fused slices this list with `[:ns_steps]`, so at most
# len(polar_express_coeffs) == 5 iterations are ever run, regardless of how
# large ns_steps is.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]


def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """Apply one in-place AdamW update to a single parameter (fallback path).

    The fast path is torch._fused_adamw_ (one launch for a whole group), driven
    from MuonAdamW._step_adamw; this function handles one tensor at a time.

    Args:
        p: parameter tensor, updated in place.
        grad: gradient for ``p``; cast to ``p.dtype`` before use.
        exp_avg: first-moment buffer, updated in place.
        exp_avg_sq: second-moment buffer, updated in place.
        step_t: 1-based step count used for bias correction.
        lr_t, beta1_t, beta2_t, eps_t, wd_t: hyperparameters (the caller
            passes 0-D tensors).

    Returns:
        None. ``p``, ``exp_avg`` and ``exp_avg_sq`` are mutated.
    """
    # Autocast can hand us bf16 grads for fp32 params (or vice versa).
    g = grad.to(p.dtype)

    # Decoupled weight decay is applied to the weights before the Adam step.
    decay_factor = 1 - lr_t * wd_t
    p.mul_(decay_factor)

    # Exponential moving averages of the gradient and of its square.
    exp_avg.lerp_(g, 1 - beta1_t)
    exp_avg_sq.lerp_(g.square(), 1 - beta2_t)

    # Bias corrections for the zero-initialised moment buffers.
    correction1 = 1 - beta1_t ** step_t
    correction2 = 1 - beta2_t ** step_t

    rms = (exp_avg_sq / correction2).sqrt()
    direction = exp_avg / (rms + eps_t)
    p.add_(direction, alpha=-(lr_t / correction1))


# ---------------------------------------------------------------------------
# F15 muon_step_fused compile strategy.
#
# HYDRA_MUON_COMPILE env gate:
#   "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
#       Dynamic=True collapses the per-shape specialization cache so that N
#       Muon param-groups with N distinct shapes trigger 1 compile, not N.
#       mode="default" keeps the inductor codegen but disables cudagraphs,
#       which is what caused the step-17→18 silent deadlock observed under
#       the original dynamic=False configuration: cudagraph stream capture
#       can deadlock against HTM's CUDA kernels running on the default
#       stream, and the failure mode at capture-time is a silent hang
#       (100% GPU util, no log output, process state R).
#   "0" — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
#       Keeps an escape hatch in case a future torch/inductor regression
#       reintroduces a deadlock.
#
# Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
# alias-analysis edge case where inductor sees `g is stacked_grads` and
# subsequent `stacked_grads.square()` operating on the post-lerp storage.
# ---------------------------------------------------------------------------
_MUON_COMPILE = os.getenv("HYDRA_MUON_COMPILE", "1") == "1"


def _maybe_compile(fn):
    """Return ``fn`` wrapped by torch.compile when HYDRA_MUON_COMPILE is "1".

    With the gate off, ``fn`` is returned unchanged (eager execution).
    """
    if not _MUON_COMPILE:
        return fn
    # mode="default" explicitly opts OUT of cudagraphs (which reduce-overhead
    # would enable) to avoid stream-capture deadlocks against HTM's CUDA
    # kernels. dynamic=True minimizes recompile count across param-group
    # shapes.
    compile_options = {"fullgraph": False, "dynamic": True, "mode": "default"}
    return torch.compile(fn, **compile_options)

@_maybe_compile
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """Run one Muon update on a stack of same-shaped matrices, in place.

    As called from MuonAdamW._step_muon, ``stacked_grads`` / ``stacked_params``
    / ``momentum_buffer`` have shape (num_params, rows, cols) and
    ``second_momentum_buffer`` has the same shape with ``red_dim`` collapsed
    to size 1.

    Args:
        stacked_grads: stacked gradients; not mutated (a cast copy is used).
        stacked_params: stacked parameters; updated in place via ``sub_``.
        momentum_buffer: first-moment state; updated in place.
        second_momentum_buffer: variance state; updated in place.
        momentum_t, lr_t, wd_t, beta2_t: hyperparameter tensors; each is moved
            to the relevant device/dtype before use.
        ns_steps: number of orthogonalization iterations; slices
            ``polar_express_coeffs``, so values beyond its length have no
            additional effect.
        red_dim: negative dim index (-1 or -2) over which the squared update
            is averaged for the variance statistics.

    Returns:
        None. The caller is responsible for copying ``stacked_params`` back
        into the individual parameter tensors.
    """
    # Cast grads to param dtype AND clone defensively to break any alias
    # between the (freshly-stacked) input and the in-place lerp_ below.
    # Without this, inductor's alias analysis can emit code that reads from
    # post-mutation storage when computing `v_mean = g.square().mean(...)`.
    stacked_grads = stacked_grads.to(momentum_buffer.dtype).clone()
    # Nesterov momentum
    momentum = momentum_t.to(device=momentum_buffer.device, dtype=stacked_grads.dtype)
    # buffer <- momentum * buffer + (1 - momentum) * grad
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    # g <- (1 - momentum) * grad + momentum * buffer (written into the private clone)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization
    # The iteration runs in bfloat16; each matrix is first scaled by its norm
    # over the last two dims (with a 1.02 safety factor and 1e-6 epsilon).
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    # Use whichever Gram matrix is smaller: X^T X (cols x cols) for tall
    # matrices, X X^T (rows x rows) otherwise.
    if g.size(-2) > g.size(-1):
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # NorMuon variance reduction
    # Keep beta2 in the state-buffer dtype, not g.dtype, so lerp_ on the
    # float32 second_momentum_buffer doesn't hit a dtype mismatch on h200.
    beta2 = beta2_t.to(device=second_momentum_buffer.device, dtype=second_momentum_buffer.dtype)
    # Mean of squares along red_dim, computed in float32.
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    # sum(mean * size) over the last two dims == sum of squares, so v_norm is
    # the per-matrix norm of g before rescaling.
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    # Inverse-RMS scale from the running variance (clamped to avoid rsqrt(0)).
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    # Norm g would have after applying step_size; used to renormalize so the
    # scaled update keeps its pre-scaling per-matrix norm.
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Cautious weight decay + parameter update
    lr = lr_t.to(device=stacked_params.device, dtype=g.dtype)
    wd = wd_t.to(device=stacked_params.device, dtype=g.dtype)
    # Decay is applied only to entries where the update and the weight have
    # the same sign (or their product is zero).
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)


class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Each param group is routed by its ``kind`` key (see ``step``). AdamW-style
    groups read ``lr``, ``betas``, ``eps`` and ``weight_decay``; Muon groups
    read ``lr``, ``momentum``, ``beta2``, ``weight_decay`` and ``ns_steps``.
    All grad-bearing params of a Muon group are stacked into one tensor, so
    they must share a single shape.
    """

    # Group-level keys that must never be persisted in or restored from a
    # checkpoint. The current implementation keeps its caches on the optimizer
    # instance, so these only matter for state dicts that still carry them.
    _TRANSIENT_GROUP_KEYS = ("_adamw_bucket_cache", "_muon_params_cache")

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        # Per-group caches keyed by id(group); each entry also records the
        # exact param objects it was built for so a reused id or an edited
        # group is detected.
        self._adamw_bucket_caches = {}
        self._muon_params_caches = {}

    @staticmethod
    def _same_params(cached, current):
        """Return True when both sequences hold the identical tensors in order.

        Identity (``is``) is used on purpose: comparing tuples of tensors with
        ``==`` falls back to ``Tensor.__eq__`` for non-identical elements,
        which either raises on multi-element tensors or reports value-equal
        but distinct parameters as "the same".
        """
        return len(cached) == len(current) and all(
            a is b for a, b in zip(cached, current)
        )

    def state_dict(self):
        """Return the optimizer state without transient fused-step entries.

        The per-param dicts returned by the base class are the live state
        dicts, so ``step_t`` is filtered into copies rather than popped; the
        running optimizer is left untouched.
        """
        sd = super().state_dict()
        # Transient fused-step caches and device step_t tensors must not enter
        # checkpoints. step_t is recreated from scalar state['step'] lazily.
        state = sd.get("state")
        if state is not None:
            sd["state"] = {
                key: {name: value for name, value in st.items() if name != "step_t"}
                for key, st in state.items()
            }
        # param_groups entries are already packed copies made by the base class.
        for group in sd.get("param_groups", []):
            for transient in self._TRANSIENT_GROUP_KEYS:
                group.pop(transient, None)
        return sd

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` after stripping transient entries.

        The caller's mapping is not modified; a sanitized shallow copy is
        handed to the base class. Local caches are dropped because loading
        replaces the param-group dicts.
        """
        sanitized = dict(state_dict)
        if "state" in sanitized:
            sanitized["state"] = {
                key: {name: value for name, value in st.items() if name != "step_t"}
                for key, st in sanitized["state"].items()
            }
        if "param_groups" in sanitized:
            sanitized["param_groups"] = [
                {
                    name: value
                    for name, value in group.items()
                    if name not in self._TRANSIENT_GROUP_KEYS
                }
                for group in sanitized["param_groups"]
            ]
        self._adamw_bucket_caches.clear()
        self._muon_params_caches.clear()
        return super().load_state_dict(sanitized)

    def _ensure_adamw_state(self, p):
        """Return the AdamW state for ``p``, creating missing entries lazily."""
        state = self.state[p]
        if not state:
            state['step'] = 0
            state['exp_avg'] = torch.zeros_like(p)
            state['exp_avg_sq'] = torch.zeros_like(p)
        if 'step_t' not in state:
            # _fused_adamw_ wants a per-param float step tensor on-device.
            state['step_t'] = torch.tensor(
                float(state['step']), dtype=torch.float32, device=p.device
            )
        return state

    def _adamw_cached_buckets(self, group):
        """Return stable (device,dtype) param buckets for fused AdamW.

        Cache topology only. Optimizer state remains lazy for grad-bearing
        params so unused/frozen tensors do not bloat checkpoints.
        """
        params_tuple = tuple(group['params'])
        cache = self._adamw_bucket_caches.get(id(group))
        if cache is not None and self._same_params(cache['params_tuple'], params_tuple):
            return cache['buckets']

        buckets = {}
        for p in params_tuple:
            key = (p.device, p.dtype)
            buckets.setdefault(key, {'params': []})
            buckets[key]['params'].append(p)
        self._adamw_bucket_caches[id(group)] = {'params_tuple': params_tuple, 'buckets': buckets}
        return buckets

    def _step_adamw(self, group):
        """Apply one AdamW step to every grad-bearing param in ``group``.

        Uses torch._fused_adamw_ per (device, dtype) bucket when enabled and
        all grad-bearing params are on CUDA; otherwise falls back to the
        per-parameter ``adamw_step_fused``.
        """
        if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW:
            # Mixed CPU/CUDA groups are unusual in Feather but skipping CPU
            # grads would be a correctness bug; disable fused path in that case.
            if not any(p.grad is not None and not p.is_cuda for p in group['params']):
                buckets = self._adamw_cached_buckets(group)
                lr_f = float(group['lr'])
                b1_f = float(group['betas'][0])
                b2_f = float(group['betas'][1])
                wd_f = float(group['weight_decay'])
                eps_f = float(group['eps'])
                launched = False
                for (_dev, _dt), bucket in buckets.items():
                    b_p = [p for p in bucket['params'] if p.grad is not None]
                    if not b_p or not b_p[0].is_cuda:
                        continue
                    b_g = [p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad for p in b_p]
                    b_ea, b_es, b_st = [], [], []
                    for p in b_p:
                        state = self._ensure_adamw_state(p)
                        # Keep the scalar step in sync with the device step_t
                        # incremented below; checkpoints store only the scalar.
                        state['step'] += 1
                        b_ea.append(state['exp_avg'])
                        b_es.append(state['exp_avg_sq'])
                        b_st.append(state['step_t'])
                    torch._foreach_add_(b_st, 1.0)
                    torch._fused_adamw_(
                        b_p, b_g, b_ea, b_es,
                        [],  # max_exp_avg_sqs unused (amsgrad=False)
                        b_st,
                        amsgrad=False,
                        lr=lr_f, beta1=b1_f, beta2=b2_f,
                        weight_decay=wd_f, eps=eps_f,
                        maximize=False,
                        grad_scale=None, found_inf=None,
                    )
                    launched = True
                if launched:
                    return

        params, grads, exp_avgs, exp_avg_sqs = [], [], [], []
        for p in group['params']:
            if p.grad is None:
                continue
            state = self._ensure_adamw_state(p)
            state['step'] += 1
            if 'step_t' in state:
                state['step_t'].fill_(float(state['step']))
            params.append(p)
            grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])

        if not params:
            return

        # Fallback per-param path.
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])
        for p, grad, exp_avg, exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs):
            self._adamw_step_t.fill_(self.state[p]['step'])
            adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Apply one Muon step to every grad-bearing param in ``group``.

        Gradients and params are stacked fresh each step, updated by
        ``muon_step_fused``, and the result is copied back into the params.
        The shared momentum buffers live in the state of the first
        grad-bearing param.
        """
        params_tuple = tuple(group['params'])
        cache = self._muon_params_caches.get(id(group))
        if cache is None or not self._same_params(cache['params_tuple'], params_tuple):
            cache = {'params_tuple': params_tuple, 'params': list(params_tuple)}
            self._muon_params_caches[id(group)] = cache
        params_all = cache['params']
        # Common Feather path: all Muon matrix params receive grads every step.
        # Preserve sparse/None-grad correctness by filtering only when needed.
        if all(p.grad is not None for p in params_all):
            params = params_all
        else:
            params = [p for p in params_all if p.grad is not None]
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if (
            "momentum_buffer" not in state
            or state["momentum_buffer"].shape[0] != num_params
            or tuple(state["momentum_buffer"].shape[1:]) != tuple(shape)
        ):
            # If grad-bearing Muon params change (rare; usually all matrix params
            # have grads), resize instead of crashing compiled Muon on a stale
            # leading dimension. This preserves skip-None-grad semantics.
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
            state.pop("second_momentum_buffer", None)
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        if "second_momentum_buffer" not in state:
            # Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
            full_shape = (num_params, *shape)
            state_shape = list(full_shape)
            state_shape[len(state_shape) + red_dim] = 1  # red_dim is negative
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
        # This was the autograd-safety fix that unblocks grad_accum>=2.
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Tall matrices (rows > cols) get their lr scaled by sqrt(rows / cols).
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # torch.stack produced a copy, so write the updated values back.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self):
        """Run one optimization step over all param groups.

        Groups whose ``kind`` is 'adamw', 'dt_bias' or 'retina_contrastive' use
        AdamW; 'muon' groups use Muon. Groups with any other ``kind`` are
        skipped without error.
        """
        for group in self.param_groups:
            kind = group['kind']
            # Audit 2026-05-09 issue #19: 'dt_bias' (and the pre-existing
            # 'retina_contrastive') are AdamW-style groups with their own
            # cosine-LR exemption upstream. Route them through _step_adamw
            # so they actually update; the exempt-from-cosine treatment is
            # applied in training.py where group['lr'] is set.
            if kind == 'adamw' or kind == 'dt_bias' or kind == 'retina_contrastive':
                self._step_adamw(group)
            elif kind == 'muon':
                self._step_muon(group)