"""MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).

Extracted verbatim from train.py (W1 modularization). Semantics unchanged.

F1-F15 state preserved:
  - F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
    step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
    fresh. Persistent copies of param storage would be mutated between forward
    passes (via lerp_/sub_ on stacked tensors that share storage with params),
    triggering "modified in-place" errors on grad_accum=2 backwards.
  - F11/F15: `@torch.compile` on `adamw_step_fused` / `muon_step_fused` intact.
  - F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
    dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
    stream-capture deadlock. See .omc/muon_compile_bug.md for the full
    investigation.
"""
|
|
| from __future__ import annotations |
|
|
| import os |
|
|
| import torch |
|
|
| |
# Opt-out switch for the fused AdamW path: enabled unless the environment
# variable HYDRA_FUSED_ADAMW is set to something other than "1". Evaluated
# once, at import time.
_HYDRA_FUSED_ADAMW = os.getenv("HYDRA_FUSED_ADAMW", "1") == "1"
# Whether the installed torch build exposes the private `_fused_adamw_` op.
_HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")
|
|
|
|
# (a, b, c) coefficient triples, one per iteration of the matrix polynomial
# loop in `muon_step_fused`. They are consumed in order and only the first
# `ns_steps` entries are used, so the list length caps the iteration count.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),    # iteration 1
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),    # iteration 2
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),   # iteration 3
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),  # iteration 4
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), # iteration 5
]
|
|
|
|
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """Apply a single AdamW update to one parameter tensor, in place.

    ``p``, ``exp_avg`` and ``exp_avg_sq`` are mutated; nothing is returned.

    Args:
        p: parameter tensor to update.
        grad: gradient for ``p``; cast to ``p.dtype`` before use.
        exp_avg: running average of the gradient (first moment).
        exp_avg_sq: running average of the squared gradient (second moment).
        step_t: step count used for bias correction (1 on the first update).
        lr_t, beta1_t, beta2_t, eps_t, wd_t: learning rate, the two decay
            rates, the denominator epsilon and the weight-decay coefficient.
            The optimizer in this file passes them as 0-dim tensors.
    """
    g = grad.to(p.dtype)

    # Decoupled weight decay: shrink the weights before the Adam update.
    p.mul_(1 - lr_t * wd_t)

    # Exponential moving averages of the gradient and of its square.
    exp_avg.lerp_(g, 1 - beta1_t)
    exp_avg_sq.lerp_(g.square(), 1 - beta2_t)

    # Bias corrections for the zero-initialised moving averages.
    correction1 = 1 - beta1_t ** step_t
    correction2 = 1 - beta2_t ** step_t

    rms = (exp_avg_sq / correction2).sqrt() + eps_t
    p.add_(exp_avg / rms, alpha=-(lr_t / correction1))
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Compilation of the Muon step is on by default; setting HYDRA_MUON_COMPILE to
# anything other than "1" disables it. Evaluated once, at import time.
_MUON_COMPILE = os.getenv("HYDRA_MUON_COMPILE", "1") == "1"


def _maybe_compile(fn):
    """Decorator: wrap ``fn`` with ``torch.compile`` if compilation is enabled.

    When ``_MUON_COMPILE`` is false the function is returned untouched.
    Otherwise it is compiled with ``fullgraph=False``, ``dynamic=True`` and
    ``mode="default"`` (the module docstring attributes the dynamic/default
    choice to avoiding a cudagraphs stream-capture deadlock).
    """
    if not _MUON_COMPILE:
        return fn
    return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
|
|
@_maybe_compile
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """Apply one Muon update, in place, to a stack of same-shaped matrices.

    ``stacked_params``, ``momentum_buffer`` and ``second_momentum_buffer`` are
    mutated. ``stacked_grads`` is cloned first and is not modified. Nothing is
    returned.

    Args:
        stacked_grads: gradients stacked along dim 0; the last two dims are
            treated as the matrix dims.
        stacked_params: parameters stacked the same way; updated in place.
        momentum_buffer: running momentum, same shape as the stacked gradients.
        second_momentum_buffer: running mean of the squared update, with
            ``red_dim`` reduced to size 1.
        momentum_t, lr_t, wd_t, beta2_t: tensors holding the scalar
            hyperparameters; each is moved to the device/dtype of the tensor
            it is combined with.
        ns_steps: how many leading entries of ``polar_express_coeffs`` to apply.
        red_dim: dimension over which the squared update is averaged.
    """
    # Private copy in the momentum dtype, so the in-place lerp_ below never
    # writes into the caller's gradient tensor.
    grads = stacked_grads.to(momentum_buffer.dtype).clone()

    # Momentum: update the buffer toward the gradient, then blend the gradient
    # back toward the freshly updated buffer to form the raw update.
    mu = momentum_t.to(device=momentum_buffer.device, dtype=grads.dtype)
    momentum_buffer.lerp_(grads, 1 - mu)
    update = grads.lerp_(momentum_buffer, mu)

    # Matrix polynomial iteration in bfloat16 on the norm-scaled update
    # (presumably a Newton-Schulz-style orthogonalisation, going by the
    # `ns_steps` / `polar_express_coeffs` names — confirm against the design notes).
    X = update.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    schedule = polar_express_coeffs[:ns_steps]
    if update.size(-2) > update.size(-1):
        # More rows than columns: form the (cols x cols) Gram matrix.
        for a, b, c in schedule:
            gram = X.mT @ X
            poly = b * gram + c * (gram @ gram)
            X = a * X + X @ poly
    else:
        # Otherwise form the (rows x rows) Gram matrix.
        for a, b, c in schedule:
            gram = X @ X.mT
            poly = b * gram + c * (gram @ gram)
            X = a * X + poly @ X
    update = X

    # Second-moment normalisation along `red_dim`, rescaled so that each
    # matrix's overall norm is the same before and after the normalisation.
    sq_mean = update.float().square().mean(dim=red_dim, keepdim=True)
    reduced_size = update.size(red_dim)
    norm_before = (sq_mean.sum(dim=(-2, -1), keepdim=True) * reduced_size).sqrt()
    beta2 = beta2_t.to(device=second_momentum_buffer.device, dtype=second_momentum_buffer.dtype)
    second_momentum_buffer.lerp_(sq_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    inv_rms = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq = (sq_mean * reduced_size) * inv_rms.float().square()
    norm_after = scaled_sq.sum(dim=(-2, -1), keepdim=True).sqrt()
    rescale = inv_rms * (norm_before / norm_after.clamp_min(1e-10))
    update = update * rescale.to(update.dtype)

    # Parameter update. Weight decay is applied only to entries where the
    # update and the parameter have the same sign (product >= 0).
    lr = lr_t.to(device=stacked_params.device, dtype=update.dtype)
    wd = wd_t.to(device=stacked_params.device, dtype=update.dtype)
    decay_mask = (update * stacked_params) >= 0
    stacked_params.sub_(lr * update + lr * wd * stacked_params * decay_mask)
|
|
|
|
class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Every param group must carry a ``kind`` entry that selects its update rule:

    * ``'adamw'``, ``'dt_bias'``, ``'retina_contrastive'`` -> AdamW. The group
      must also provide ``lr``, ``betas``, ``eps`` and ``weight_decay``.
    * ``'muon'`` -> Muon. The group must also provide ``lr``, ``momentum``,
      ``beta2`` (may be ``None``), ``weight_decay`` and ``ns_steps``. Its
      parameters are stacked into one tensor, so they must share one shape.

    Groups with any other ``kind`` are skipped by :meth:`step`.
    """

    def __init__(self, param_groups):
        """Register ``param_groups`` and allocate reusable scalar tensors."""
        super().__init__(param_groups, defaults={})

        def _scalar():
            return torch.tensor(0.0, dtype=torch.float32, device="cpu")

        # 0-dim CPU tensors that are overwritten with fill_() on every step and
        # handed to the step kernels in place of Python floats (presumably so
        # compiled kernels do not specialise on the values — confirm).
        self._adamw_step_t = _scalar()
        self._adamw_lr_t = _scalar()
        self._adamw_beta1_t = _scalar()
        self._adamw_beta2_t = _scalar()
        self._adamw_eps_t = _scalar()
        self._adamw_wd_t = _scalar()
        self._muon_momentum_t = _scalar()
        self._muon_lr_t = _scalar()
        self._muon_wd_t = _scalar()
        self._muon_beta2_t = _scalar()

    def _step_adamw(self, group):
        """Run one AdamW update for every parameter in ``group`` that has a grad.

        Per-parameter state holds ``step`` (Python int), ``step_t`` (tensor
        counter on the parameter's device), ``exp_avg`` and ``exp_avg_sq``.
        """
        params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            if 'step_t' not in state:
                # Tensor step counter for the fused kernel, seeded from the
                # integer counter before it is incremented below.
                state['step_t'] = torch.tensor(
                    float(state['step']), dtype=torch.float32, device=p.device
                )
            state['step'] += 1
            params.append(p)
            grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            state_steps.append(state['step_t'])

        if not params:
            return

        if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
            # Fused path: one private `torch._fused_adamw_` call per
            # (device, dtype) bucket, so each call gets homogeneous lists.
            buckets = {}
            for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
                b_p, b_g, b_ea, b_es, b_st = buckets.setdefault(
                    (p.device, p.dtype), ([], [], [], [], [])
                )
                b_p.append(p)
                b_g.append(g)
                b_ea.append(ea)
                b_es.append(es)
                b_st.append(st)

            lr_f = float(group['lr'])
            b1_f = float(group['betas'][0])
            b2_f = float(group['betas'][1])
            wd_f = float(group['weight_decay'])
            eps_f = float(group['eps'])
            for b_p, b_g, b_ea, b_es, b_st in buckets.values():
                # The tensor step counters are advanced here, right before the
                # fused kernel reads them.
                torch._foreach_add_(b_st, 1.0)
                torch._fused_adamw_(
                    b_p, b_g, b_ea, b_es,
                    [],  # max_exp_avg_sqs: unused because amsgrad=False
                    b_st,
                    amsgrad=False,
                    lr=lr_f, beta1=b1_f, beta2=b2_f,
                    weight_decay=wd_f, eps=eps_f,
                    maximize=False,
                    grad_scale=None, found_inf=None,
                )
            return

        # Fallback path: per-parameter update driven by the integer `step`
        # counter; `step_t` is not advanced on this path.
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])
        for p, grad, exp_avg, exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs):
            self._adamw_step_t.fill_(self.state[p]['step'])
            adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Run one Muon update over all parameters in ``group`` that have a grad.

        Gradients and parameters are stacked fresh on every call, updated by
        ``muon_step_fused``, and the result is copied back into the parameters.

        The two group-level buffers live in the state of the first parameter
        that has a gradient and are sized by the number of such parameters at
        creation time.
        NOTE(review): if the set of parameters with gradients changes between
        steps, the buffer owner and/or leading dimension no longer match —
        confirm callers guarantee all-or-none gradients per Muon group.
        """
        params = [p for p in group['params'] if p.grad is not None]
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        # Reduce over the shorter of the two matrix dims (columns on ties).
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        if "second_momentum_buffer" not in state:
            # Same shape as the stacked params, but with `red_dim` collapsed to 1.
            state_shape = [num_params, *shape]
            state_shape[red_dim] = 1
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)

        # Fresh stacks every step; torch.stack copies, so the fused step never
        # writes directly into parameter storage.
        stacked_grads = torch.stack([q.grad for q in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Learning rate is scaled up by sqrt(rows / cols) for tall matrices.
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # Write the updated slices back into the real parameters.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step over all param groups.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, as in ``torch.optim.Optimizer.step``. It is called
                with gradients enabled, before any parameter is updated.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            kind = group['kind']
            if kind in ('adamw', 'dt_bias', 'retina_contrastive'):
                self._step_adamw(group)
            elif kind == 'muon':
                self._step_muon(group)
            # Any other kind is left untouched.
        return loss
|
|