"""MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).

Extracted from train.py (W1 modularization). F1-F15 state preserved:

- F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
  step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
  fresh. Persistent copies of param storage would be mutated between forward
  passes (via lerp_/sub_ on stacked tensors that share storage with params),
  triggering "modified in-place" errors on grad_accum=2 backwards.
- F11/F15: `muon_step_fused` is wrapped with `torch.compile` via
  `_maybe_compile`. `adamw_step_fused` is the eager per-param AdamW fallback;
  the fast AdamW path is the vectorized `torch._fused_adamw_` kernel.
- F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
  dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
  stream-capture deadlock. See .omc/muon_compile_bug.md for the full
  investigation.

Post-extraction changes:

- `step()` accepts the standard optional `closure` and returns its loss.
- Param groups with an unrecognized 'kind' emit a warning instead of being
  skipped silently.
- The per-param AdamW `step_t` tensor is rebuilt from the integer 'step' if a
  checkpoint load changed its dtype or device.
- If `torch.compile` itself raises while wrapping, Muon runs eagerly with a
  warning.
"""

from __future__ import annotations

import os
import warnings

import torch

# HYDRA_FUSED_ADAMW=1 (default) -> vectorized torch._fused_adamw_ kernel.
_HYDRA_FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
_HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")

# Param-group kinds updated with AdamW. Audit 2026-05-09 issue #19: 'dt_bias'
# (and the pre-existing 'retina_contrastive') are AdamW-style groups with their
# own cosine-LR exemption upstream; the exemption is applied where group['lr']
# is set, so here they are plain AdamW groups.
_ADAMW_KINDS = frozenset({"adamw", "dt_bias", "retina_contrastive"})

# (a, b, c) per iteration of the polynomial orthogonalization in
# muon_step_fused: X <- a*X + X @ (b*A + c*A@A). Only the first `ns_steps`
# rows are used.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]


def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t,
                     beta2_t, eps_t, wd_t):
    """Apply one AdamW update to a single parameter, in place.

    Per-param AdamW fallback. The fast path is torch._fused_adamw_ (one CUDA
    launch for the whole group) driven from MuonAdamW._step_adamw.

    Args:
        p: parameter, updated in place (decoupled weight decay + Adam step).
        grad: gradient of ``p``; cast to ``p.dtype`` before use.
        exp_avg: first-moment buffer, updated in place.
        exp_avg_sq: second-moment buffer, updated in place.
        step_t: step count (already incremented) used for bias correction.
        lr_t, beta1_t, beta2_t, eps_t, wd_t: hyperparameters. MuonAdamW passes
            0-D CPU tensors that it refills each call.
    """
    grad = grad.to(p.dtype)  # handle mixed bf16/fp32 from autocast
    p.mul_(1 - lr_t * wd_t)  # decoupled weight decay
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)


# ---------------------------------------------------------------------------
# F15 muon_step_fused compile strategy.
#
# HYDRA_MUON_COMPILE env gate:
#   "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
#       Dynamic=True collapses the per-shape specialization cache so that N
#       Muon param-groups with N distinct shapes trigger 1 compile, not N.
#       mode="default" keeps the inductor codegen but disables cudagraphs,
#       which is what caused the step-17→18 silent deadlock observed under
#       the original dynamic=False configuration: cudagraph stream capture
#       can deadlock against HTM's CUDA kernels running on the default
#       stream, and the failure mode at capture-time is a silent hang
#       (100% GPU util, no log output, process state R).
#   "0" — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
#       Keeps an escape hatch in case a future torch/inductor regression
#       reintroduces a deadlock.
#
# Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
# alias-analysis edge case where inductor sees `g is stacked_grads` and
# subsequent `stacked_grads.square()` operating on the post-lerp storage.
# ---------------------------------------------------------------------------
_MUON_COMPILE = os.environ.get("HYDRA_MUON_COMPILE", "1") == "1"


def _maybe_compile(fn):
    """Return ``fn`` wrapped with torch.compile when HYDRA_MUON_COMPILE=1.

    If torch.compile itself raises while wrapping (RuntimeError, e.g. an
    unsupported Python/platform for the installed torch), warn and return
    ``fn`` unchanged. That is the same eager path as HYDRA_MUON_COMPILE=0,
    and it avoids failing at module import time.
    """
    if not _MUON_COMPILE:
        return fn
    try:
        # mode="default" explicitly opts OUT of cudagraphs (which reduce-overhead
        # would enable) to avoid stream-capture deadlocks against HTM's CUDA
        # kernels. dynamic=True minimizes recompile count across param-group
        # shapes.
        return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
    except RuntimeError as err:
        warnings.warn(
            f"torch.compile unavailable ({err}); {fn.__name__} runs eagerly "
            "(equivalent to HYDRA_MUON_COMPILE=0)."
        )
        return fn


@_maybe_compile
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer,
                    second_momentum_buffer, momentum_t, lr_t, wd_t, beta2_t,
                    ns_steps, red_dim):
    """Apply one Muon (NorMuon-style) update to a stack of same-shaped matrices.

    Tensors are stacked along dim 0 by MuonAdamW._step_muon. The last two dims
    of each slice are treated as the matrix.

    Args:
        stacked_grads: stacked gradients; cast to the momentum-buffer dtype
            and copied (the caller's tensor is not modified).
        stacked_params: stacked parameters, updated in place. The caller
            copies them back into the real params.
        momentum_buffer: momentum state, updated in place.
        second_momentum_buffer: running mean of squared update magnitudes,
            reduced along ``red_dim``, updated in place.
        momentum_t, lr_t, wd_t, beta2_t: 0-D hyperparameter tensors.
        ns_steps: number of orthogonalization iterations (<= len of
            polar_express_coeffs).
        red_dim: -1 or -2, the matrix dim reduced for the variance estimate.
    """
    # Cast grads to param dtype AND clone defensively to break any alias
    # between the (freshly-stacked) input and the in-place lerp_ below.
    # Without this, inductor's alias analysis can emit code that reads from
    # post-mutation storage when computing `v_mean = g.square().mean(...)`.
    stacked_grads = stacked_grads.to(momentum_buffer.dtype).clone()

    # Nesterov momentum: update the buffer, then look ahead along it.
    momentum = momentum_t.to(device=momentum_buffer.device, dtype=stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)

    # Polar express orthogonalization (in bf16). The Gram matrix is formed
    # on the smaller side so A is the cheaper square product.
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X

    # NorMuon variance reduction
    # Keep beta2 in the state-buffer dtype, not g.dtype, so lerp_ on the
    # float32 second_momentum_buffer doesn't hit a dtype mismatch on h200.
    beta2 = beta2_t.to(device=second_momentum_buffer.device, dtype=second_momentum_buffer.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    # Rescale so each matrix keeps its pre-normalization overall norm.
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)

    # Cautious weight decay (only where update and param agree in sign)
    # + parameter update.
    lr = lr_t.to(device=stacked_params.device, dtype=g.dtype)
    wd = wd_t.to(device=stacked_params.device, dtype=g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)


class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Every param group must carry a 'kind' key:

    * 'adamw', 'dt_bias', 'retina_contrastive' -> AdamW. Reads 'lr', 'betas',
      'eps' and 'weight_decay'.
    * 'muon' -> Muon. Reads 'lr', 'momentum', 'beta2' (None means 0.0),
      'weight_decay' and 'ns_steps'. The group's params are combined with
      torch.stack, so they must all share one shape.

    Any other kind is not updated, and a warning is emitted.
    """

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})

        def _scalar():
            return torch.tensor(0.0, dtype=torch.float32, device="cpu")

        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._adamw_step_t = _scalar()
        self._adamw_lr_t = _scalar()
        self._adamw_beta1_t = _scalar()
        self._adamw_beta2_t = _scalar()
        self._adamw_eps_t = _scalar()
        self._adamw_wd_t = _scalar()
        self._muon_momentum_t = _scalar()
        self._muon_lr_t = _scalar()
        self._muon_wd_t = _scalar()
        self._muon_beta2_t = _scalar()

    def _step_adamw(self, group):
        """Apply one AdamW step to every param in ``group`` that has a gradient.

        On CUDA with torch._fused_adamw_ available (and HYDRA_FUSED_ADAMW=1),
        it runs one fused kernel per (device, dtype) bucket. Otherwise it calls
        adamw_step_fused once per param.
        """
        params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            step_t = state.get('step_t')
            if (step_t is None or step_t.dtype != torch.float32
                    or step_t.device != p.device):
                # _fused_adamw_ wants a per-param float32 step tensor on-device.
                # Also rebuild it from the integer 'step' when a checkpoint load
                # changed it: Optimizer.load_state_dict casts floating state
                # (every key except 'step') to the param's dtype/device. bf16
                # cannot hold step counts above 256 exactly.
                state['step_t'] = torch.tensor(
                    float(state['step']), dtype=torch.float32, device=p.device
                )
            state['step'] += 1
            params.append(p)
            grads.append(p.grad.to(p.dtype))  # no-op when dtypes already match
            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            state_steps.append(state['step_t'])

        if not params:
            return

        if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
            # _fused_adamw_ needs uniform (device, dtype) within a call, so
            # group by (device, dtype) — same pattern as PyTorch's own
            # AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
            buckets = {}
            for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
                b_p, b_g, b_ea, b_es, b_st = buckets.setdefault(
                    (p.device, p.dtype), ([], [], [], [], [])
                )
                b_p.append(p)
                b_g.append(g)
                b_ea.append(ea)
                b_es.append(es)
                b_st.append(st)
            lr_f = float(group['lr'])
            b1_f = float(group['betas'][0])
            b2_f = float(group['betas'][1])
            wd_f = float(group['weight_decay'])
            eps_f = float(group['eps'])
            for b_p, b_g, b_ea, b_es, b_st in buckets.values():
                # After this add, each step tensor equals state['step'].
                torch._foreach_add_(b_st, 1.0)
                torch._fused_adamw_(
                    b_p, b_g, b_ea, b_es,
                    [],  # max_exp_avg_sqs unused (amsgrad=False)
                    b_st,
                    amsgrad=False,
                    lr=lr_f,
                    beta1=b1_f,
                    beta2=b2_f,
                    weight_decay=wd_f,
                    eps=eps_f,
                    maximize=False,
                    grad_scale=None,
                    found_inf=None,
                )
            return

        # Fallback per-param path.
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])
        for p, grad, exp_avg, exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs):
            self._adamw_step_t.fill_(self.state[p]['step'])
            adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Apply one Muon step to all params in ``group`` that have a gradient.

        The params are stacked and updated by a single muon_step_fused call.
        Both momentum buffers live in the state of the first such param, with
        a leading ``num_params`` dim.

        NOTE(review): if the subset of params with gradients changes between
        steps, the stacked buffers no longer line up with the stack. Confirm
        that every muon param always receives a gradient.
        """
        params = [p for p in group['params'] if p.grad is not None]
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        # Reduce the variance estimate along the smaller matrix dim (ties -> last).
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        if "second_momentum_buffer" not in state:
            # Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
            state_shape = [num_params, *shape]
            state_shape[len(state_shape) + red_dim] = 1  # red_dim is negative
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)

        # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
        # This was the autograd-safety fix that unblocks grad_accum>=2.
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)

        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Scale lr up for tall matrices by sqrt(rows / cols).
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # Write the updated stack back into the real parameter tensors.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step over every param group.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (standard torch.optim contract). It runs with grad
                enabled, before any parameter is updated.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            kind = group['kind']
            # Audit 2026-05-09 issue #19: 'dt_bias' and 'retina_contrastive'
            # are routed through _step_adamw so they actually update; their
            # cosine-LR exemption is applied in training.py where group['lr']
            # is set.
            if kind in _ADAMW_KINDS:
                self._step_adamw(group)
            elif kind == 'muon':
                self._step_muon(group)
            else:
                # Silent skipping previously hid a group that never updated.
                warnings.warn(
                    f"MuonAdamW: param group kind {kind!r} is not recognized; "
                    "its params are NOT updated."
                )
        return loss