File size: 12,509 Bytes
e317e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).

Extracted verbatim from train.py (W1 modularization). Semantics unchanged.

F1-F15 state preserved:
- F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
  step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
  fresh. Persistent copies of param storage would be mutated between forward
  passes (via lerp_/sub_ on stacked tensors that share storage with params),
  triggering "modified in-place" errors on grad_accum=2 backwards.
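- F11/F15: `muon_step_fused` is still wrapped in `torch.compile` (via
  `_maybe_compile`). `adamw_step_fused` is NOT compiled in this module; it is
  only the per-param fallback behind the `torch._fused_adamw_` fast path.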
- F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
  dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
  stream-capture deadlock. See .omc/muon_compile_bug.md for the full
  investigation.
"""

from __future__ import annotations

import os

import torch

# HYDRA_FUSED_ADAMW=1 (default) -> vectorized torch._fused_adamw_ kernel for
# AdamW groups on CUDA. Any value other than "1" forces the per-param
# adamw_step_fused fallback.
_HYDRA_FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
# torch._fused_adamw_ is a private op, so feature-detect it instead of
# assuming the installed torch provides it.
_HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")


# Per-iteration (a, b, c) coefficients for the orthogonalization loop in
# muon_step_fused. Each row applies X <- a*X + X @ (b*A + c*A@A) with
# A = X^T X (or the mirrored form B @ X with A = X X^T for wide matrices).
# Only the first `ns_steps` rows are used.
# NOTE(review): the name suggests these are the published "Polar Express"
# coefficients; confirm their provenance before editing any value.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]


def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """Apply one in-place AdamW update to a single parameter tensor.

    This is the per-param fallback. The fast path is ``torch._fused_adamw_``
    (one CUDA launch for a whole group), driven from ``MuonAdamW._step_adamw``.

    Args:
        p: Parameter tensor, updated in place.
        grad: Gradient of ``p``; cast to ``p.dtype`` first, since autocast
            can produce mixed bf16/fp32.
        exp_avg, exp_avg_sq: First/second moment buffers, updated in place.
        step_t: Current (already incremented) step count used for bias
            correction.
        lr_t, beta1_t, beta2_t, eps_t, wd_t: AdamW hyperparameters.
    """
    g = grad.to(p.dtype)

    # Decoupled weight decay: shrink the weights before the Adam step.
    decay_factor = 1 - lr_t * wd_t
    p.mul_(decay_factor)

    # Exponential moving averages of the gradient and its square.
    exp_avg.lerp_(g, 1 - beta1_t)
    exp_avg_sq.lerp_(g.square(), 1 - beta2_t)

    # Bias-corrected update: exp_avg / (sqrt(exp_avg_sq / c2) + eps) * lr / c1.
    first_correction = 1 - beta1_t ** step_t
    second_correction = 1 - beta2_t ** step_t
    update = exp_avg / ((exp_avg_sq / second_correction).sqrt() + eps_t)
    p.add_(update, alpha=-(lr_t / first_correction))


# ---------------------------------------------------------------------------
# F15 muon_step_fused compile strategy.
#
# HYDRA_MUON_COMPILE env gate:
#   "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
#       Dynamic=True collapses the per-shape specialization cache so that N
#       Muon param-groups with N distinct shapes trigger 1 compile, not N.
#       mode="default" keeps the inductor codegen but disables cudagraphs,
#       which is what caused the step-17→18 silent deadlock observed under
#       the original dynamic=False configuration: cudagraph stream capture
#       can deadlock against HTM's CUDA kernels running on the default
#       stream, and the failure mode at capture-time is a silent hang
#       (100% GPU util, no log output, process state R).
#   any other value (e.g. "0") — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
#       Keeps an escape hatch in case a future torch/inductor regression
#       reintroduces a deadlock.
#
# Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
# alias-analysis edge case where inductor sees `g is stacked_grads` and
# subsequent `stacked_grads.square()` operating on the post-lerp storage.
# ---------------------------------------------------------------------------
_MUON_COMPILE = os.environ.get("HYDRA_MUON_COMPILE", "1") == "1"

def _maybe_compile(fn):
    if _MUON_COMPILE:
        # mode="default" explicitly opts OUT of cudagraphs (which reduce-overhead
        # would enable) to avoid stream-capture deadlocks against HTM's CUDA
        # kernels. dynamic=True minimizes recompile count across param-group
        # shapes.
        return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
    return fn

@_maybe_compile
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """Run one Muon update on a stack of same-shaped parameters.

    The update has four stages: Nesterov momentum, polar-express
    orthogonalization, NorMuon variance reduction, and cautious weight decay.

    Args:
        stacked_grads: Gradients stacked along a new leading dim, matching
            ``stacked_params``. Cast to ``momentum_buffer.dtype`` and cloned
            before any in-place op, so the caller's tensor is not mutated.
        stacked_params: Parameters stacked the same way; updated IN PLACE.
            In ``MuonAdamW._step_muon`` this is a fresh ``torch.stack`` copy
            that is copied back into the real params afterwards.
        momentum_buffer: Momentum state, same shape as ``stacked_grads``;
            updated in place.
        second_momentum_buffer: EMA of the mean squared update along
            ``red_dim``. It has the shape of ``v_mean`` below (size 1 on
            ``red_dim``) and is updated in place.
        momentum_t, lr_t, wd_t, beta2_t: Scalar tensors; each is moved to the
            relevant device/dtype inside this function.
        ns_steps: How many rows of ``polar_express_coeffs`` to apply.
        red_dim: Axis (-1 or -2) that the variance statistics reduce over.

    Returns:
        None. All results are written into the in-place arguments above.
    """
    # Cast grads to the momentum-buffer dtype AND clone defensively to break any
    # alias between the (freshly-stacked) input and the in-place lerp_ below.
    # Without this, inductor's alias analysis can emit code that reads from
    # post-mutation storage when computing `v_mean = g.square().mean(...)`.
    stacked_grads = stacked_grads.to(momentum_buffer.dtype).clone()
    # Nesterov momentum: buf <- lerp(buf, grad, 1 - mu); the applied direction
    # is then lerp(grad, buf, mu), computed in place on the clone.
    momentum = momentum_t.to(device=momentum_buffer.device, dtype=stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization, in bf16. Dividing by 1.02x the Frobenius
    # norm (which bounds the spectral norm) puts every singular value below 1
    # before the polynomial iterations.
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):
        # Tall matrices: iterate on the smaller (cols x cols) Gram matrix X^T X.
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        # Wide/square matrices: iterate on the (rows x rows) Gram matrix X X^T.
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # NorMuon variance reduction
    # Keep beta2 in the state-buffer dtype, not g.dtype, so lerp_ on the
    # float32 second_momentum_buffer doesn't hit a dtype mismatch on h200.
    # NOTE(review): _step_muon allocates second_momentum_buffer in the param
    # dtype, so it is float32 only when the params are.
    beta2 = beta2_t.to(device=second_momentum_buffer.device, dtype=second_momentum_buffer.dtype)
    # Mean squared update along red_dim, computed in float32.
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    # mean * size == sum of squares, so v_norm is each matrix's Frobenius norm.
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    # Per-slice scale = 1 / sqrt(EMA of the mean squared update).
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    # Frobenius norm the update would have after multiplying by step_size.
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    # Rescale so the variance-normalized update keeps the original norm.
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Cautious weight decay + parameter update: decay is applied only where
    # the update and the parameter share sign (g * p >= 0).
    lr = lr_t.to(device=stacked_params.device, dtype=g.dtype)
    wd = wd_t.to(device=stacked_params.device, dtype=g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)


class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Every param group must carry a ``'kind'`` key, and :meth:`step`
    dispatches on it:

    * ``'adamw'``: reads ``lr``, ``betas``, ``eps`` and ``weight_decay``.
    * ``'muon'``: reads ``lr``, ``momentum``, ``beta2`` (``None`` is treated
      as 0.0), ``weight_decay`` and ``ns_steps``. Its params are combined
      with ``torch.stack``, so they must share one shape, device and dtype.

    Groups with any other ``kind`` are left untouched.
    """

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change:
        # hyperparameters are written into them with fill_() on every step.
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _step_adamw(self, group):
        """Apply one AdamW update to every param in ``group`` that has a grad.

        Params are bucketed by (device, dtype). CUDA buckets use one
        ``torch._fused_adamw_`` launch when HYDRA_FUSED_ADAMW=1 and the kernel
        exists. Every other bucket uses the per-param ``adamw_step_fused``.
        """
        params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            if 'step_t' not in state:
                # _fused_adamw_ wants a per-param float step tensor on-device.
                # It starts equal to state['step'] before this step's increment;
                # both update paths below advance it by exactly one per step.
                state['step_t'] = torch.tensor(
                    float(state['step']), dtype=torch.float32, device=p.device
                )
            state['step'] += 1
            params.append(p)
            # Autocast can leave grads in a different dtype than the params.
            grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            state_steps.append(state['step_t'])

        if not params:
            return

        # _fused_adamw_ needs uniform (device, dtype) within a call, so group by
        # (device, dtype) -- same pattern as PyTorch's own AdamW(fused=True)
        # path (_group_tensors_by_device_and_dtype).
        buckets = {}
        for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
            bucket = buckets.setdefault((p.device, p.dtype), ([], [], [], [], []))
            for dest, item in zip(bucket, (p, g, ea, es, st)):
                dest.append(item)

        # Decide per bucket (not from params[0]) so that a mixed-device group
        # never feeds non-CUDA tensors to the fused CUDA kernel.
        use_fused = _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW
        for (device, _dtype), bucket in buckets.items():
            if use_fused and device.type == 'cuda':
                self._adamw_fused_bucket(group, *bucket)
            else:
                self._adamw_fallback_bucket(group, *bucket)

    def _adamw_fused_bucket(self, group, b_p, b_g, b_ea, b_es, b_st):
        """Update one (device, dtype)-uniform bucket with a single _fused_adamw_ call."""
        # Bring the device step counters up to the already-incremented state['step'].
        torch._foreach_add_(b_st, 1.0)
        torch._fused_adamw_(
            b_p, b_g, b_ea, b_es,
            [],  # max_exp_avg_sqs unused (amsgrad=False)
            b_st,
            amsgrad=False,
            lr=float(group['lr']),
            beta1=float(group['betas'][0]),
            beta2=float(group['betas'][1]),
            weight_decay=float(group['weight_decay']),
            eps=float(group['eps']),
            maximize=False,
            grad_scale=None, found_inf=None,
        )

    def _adamw_fallback_bucket(self, group, b_p, b_g, b_ea, b_es, b_st):
        """Update one bucket param-by-param with ``adamw_step_fused``."""
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])
        for p, grad, exp_avg, exp_avg_sq in zip(b_p, b_g, b_ea, b_es):
            self._adamw_step_t.fill_(self.state[p]['step'])
            adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
        # Keep state['step_t'] equal to state['step'] on this path too. Otherwise
        # a later fused step on this state would use a stale bias correction.
        torch._foreach_add_(b_st, 1.0)

    def _step_muon(self, group):
        """Apply one Muon update to the params of ``group`` that have grads.

        All such params are stacked and updated together by
        ``muon_step_fused``. Both momentum buffers are stored in the state of
        the first param that has a grad, with a leading dim equal to the number
        of params. NOTE(review): this assumes the set of params with grads does
        not change between steps.
        """
        params = [p for p in group['params'] if p.grad is not None]
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        # Reduce over the shorter matrix dimension's partner: columns for
        # tall/square matrices, rows for wide ones.
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        if "second_momentum_buffer" not in state:
            # Shape must match v_mean = g.square().mean(dim=red_dim, keepdim=True)
            full_shape = (num_params, *shape)
            state_shape = list(full_shape)
            state_shape[len(state_shape) + red_dim] = 1  # red_dim is negative
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
        # This was the autograd-safety fix that unblocks grad_accum>=2.
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Learning rate is scaled by sqrt(max(1, rows / cols)).
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # stacked_params is a copy, so write the updated values back.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step over all param groups.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, as in ``torch.optim.Optimizer.step``. It is called
                with grad enabled, before any parameter is updated.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            kind = group['kind']
            if kind == 'adamw':
                self._step_adamw(group)
            elif kind == 'muon':
                self._step_muon(group)
        return loss