Spaces:
Runtime error
Runtime error
| """MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else). | |
| Extracted verbatim from train.py (W1 modularization). Semantics unchanged. | |
| F1-F15 state preserved: | |
| - F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each | |
| step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)` | |
| fresh. Persistent copies of param storage would be mutated between forward | |
| passes (via lerp_/sub_ on stacked tensors that share storage with params), | |
| triggering "modified in-place" errors on grad_accum=2 backwards. | |
| - F11/F15: `@torch.compile` on `adamw_step_fused` / `muon_step_fused` intact. | |
| - F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with | |
| dynamic=True + mode="default" to avoid the step-17→18 cudagraphs | |
| stream-capture deadlock. See .omc/muon_compile_bug.md for the full | |
| investigation. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import torch | |
# HYDRA_FUSED_ADAMW=1 (the default) lets AdamW groups use the vectorized
# torch._fused_adamw_ kernel; any other value disables that fast path.
_HYDRA_FUSED_ADAMW = os.getenv("HYDRA_FUSED_ADAMW", "1") == "1"
# torch._fused_adamw_ is a private torch API, so probe for it instead of assuming it.
_HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")
# Per-iteration (a, b, c) coefficients for the polar-express orthogonalization
# loop in muon_step_fused. Each iteration computes the Gram matrix
# A = X^T X (tall matrices) or X X^T (otherwise) and updates
#     X <- a*X + X @ (b*A + c*A@A)      or      X <- a*X + (b*A + c*A@A) @ X,
# i.e. an odd quintic polynomial in X. The loop iterates over
# polar_express_coeffs[:ns_steps], so at most len(polar_express_coeffs) == 5
# iterations run even if ns_steps is larger.
# NOTE(review): the provenance of these tuned values isn't recorded here; treat
# them as exact constants and don't round or edit them casually.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """Apply one AdamW update to a single parameter, in place (eager fallback).

    The fast path is ``torch._fused_adamw_``, which MuonAdamW._step_adamw uses
    to update a whole group in one launch; this function handles one tensor.

    Args:
        p: parameter tensor; updated in place.
        grad: gradient of ``p``; cast to ``p.dtype`` first because it may
            arrive in a different dtype (mixed bf16/fp32 under autocast).
        exp_avg: first-moment EMA buffer; updated in place.
        exp_avg_sq: second-moment EMA buffer; updated in place.
        step_t: step count used for bias correction (1 on the first update).
        lr_t, beta1_t, beta2_t, eps_t, wd_t: hyper-parameters as 0-D tensors.
    """
    g = grad.to(p.dtype)
    # Decoupled (AdamW-style) weight decay, applied before the Adam step.
    p.mul_(1 - lr_t * wd_t)
    # EMAs written as lerp: m <- m + (1 - beta1) * (g - m); likewise for g^2.
    exp_avg.lerp_(g, 1 - beta1_t)
    exp_avg_sq.lerp_(g.square(), 1 - beta2_t)
    correction1 = 1 - beta1_t ** step_t
    correction2 = 1 - beta2_t ** step_t
    denom = exp_avg_sq.div(correction2).sqrt().add(eps_t)
    p.add_(exp_avg / denom, alpha=-(lr_t / correction1))
# ---------------------------------------------------------------------------
# F15 muon_step_fused compile strategy.
#
# HYDRA_MUON_COMPILE env gate (read once, at import time):
#   "1" (default ON) — wrap muon_step_fused (below, via @_maybe_compile) with
#       torch.compile(dynamic=True, mode="default").
#       dynamic=True collapses the per-shape specialization cache so that N
#       Muon param-groups with N distinct shapes trigger 1 compile, not N.
#       mode="default" keeps the inductor codegen but disables cudagraphs,
#       which is what caused the step-17→18 silent deadlock observed under
#       the original dynamic=False configuration: cudagraph stream capture
#       can deadlock against HTM's CUDA kernels running on the default
#       stream, and the failure mode at capture-time is a silent hang
#       (100% GPU util, no log output, process state R).
#   "0" (or any value other than "1") — run eager Python (slower, ~43k tps
#       vs ~63k compiled). Keeps an escape hatch in case a future
#       torch/inductor regression reintroduces a deadlock.
#
# Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
# alias-analysis edge case where inductor sees `g is stacked_grads` and
# subsequent `stacked_grads.square()` operating on the post-lerp storage.
# ---------------------------------------------------------------------------
_MUON_COMPILE = os.environ.get("HYDRA_MUON_COMPILE", "1") == "1"


def _maybe_compile(fn):
    """Return *fn* wrapped with torch.compile when HYDRA_MUON_COMPILE is on.

    With the gate off, *fn* is returned unchanged. If torch.compile refuses to
    wrap *fn* up front (it raises RuntimeError on some platforms Dynamo does
    not support), a warning is emitted and *fn* runs eagerly instead. Because
    this wrapper is applied at import time, a crash here would otherwise break
    importing the whole module.
    """
    if not _MUON_COMPILE:
        return fn
    try:
        # mode="default" explicitly opts OUT of cudagraphs (which reduce-overhead
        # would enable) to avoid stream-capture deadlocks against HTM's CUDA
        # kernels. dynamic=True minimizes recompile count across param-group
        # shapes.
        return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
    except RuntimeError as err:
        import warnings

        warnings.warn(
            f"torch.compile unavailable ({err}); running {fn.__name__} eagerly",
            RuntimeWarning,
            stacklevel=2,
        )
        return fn


@_maybe_compile
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """One Muon update (Nesterov momentum + polar-express orthogonalization +
    NorMuon variance reduction + cautious weight decay) over stacked matrices.

    MuonAdamW._step_muon builds the inputs with torch.stack, so the leading dim
    indexes parameters and the last two dims are the matrix. The function
    mutates ``stacked_params``, ``momentum_buffer`` and
    ``second_momentum_buffer`` in place and returns None; the caller copies
    ``stacked_params`` back into the real parameters.

    Args:
        stacked_grads: gradients; not mutated (a cast copy is taken first).
        stacked_params: parameters, updated in place.
        momentum_buffer: momentum EMA, updated in place.
        second_momentum_buffer: EMA of the mean squared orthogonalized update
            along ``red_dim`` (size 1 in that dim), updated in place.
        momentum_t, lr_t, wd_t, beta2_t: 0-D hyper-parameter tensors.
        ns_steps: number of polar-express iterations; capped at
            ``len(polar_express_coeffs)`` by the slice below.
        red_dim: -1 or -2, the matrix dim averaged over for the second moment.
    """
    # Cast grads to the buffer dtype AND clone defensively to break any alias
    # between the (freshly-stacked) input and the in-place lerp_ below.
    # Without this, inductor's alias analysis can emit code that reads from
    # post-mutation storage when computing the second-moment statistics.
    stacked_grads = stacked_grads.to(momentum_buffer.dtype).clone()
    # Nesterov momentum: buf <- mu*buf + (1-mu)*g, then g <- (1-mu)*g + mu*buf.
    momentum = momentum_t.to(device=momentum_buffer.device, dtype=stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization (in bf16). Dividing by 1.02 * Frobenius
    # norm bounds the spectral norm below 1 before iterating.
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):
        # Tall matrices: use the smaller (cols x cols) Gram matrix X^T X.
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # NorMuon variance reduction: rescale each row/column by the inverse root
    # of its second-moment EMA, then restore the update's original Frobenius
    # norm (v_norm / v_norm_new).
    # Keep beta2 in the state-buffer dtype, not g's bf16, so the lerp_ weight
    # matches second_momentum_buffer's dtype.
    beta2 = beta2_t.to(device=second_momentum_buffer.device, dtype=second_momentum_buffer.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Cautious weight decay (only where update and param share a sign) plus
    # the parameter update itself.
    lr = lr_t.to(device=stacked_params.device, dtype=g.dtype)
    wd = wd_t.to(device=stacked_params.device, dtype=g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for everything else.

    Every param group carries a ``'kind'`` entry that selects its update rule:

    * ``'adamw'``: uses ``lr``, ``betas``, ``eps`` and ``weight_decay``.
    * ``'muon'``: uses ``lr``, ``momentum``, ``beta2`` (``None`` is treated
      as 0.0), ``ns_steps`` and ``weight_decay``. The group's params are
      combined with ``torch.stack``, so they must all share one shape.

    Groups whose ``kind`` is neither value are skipped by :meth:`step`.
    """

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change;
        # hyper-parameters are written into them with fill_ on every step.
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _step_adamw(self, group):
        """Apply one AdamW update to every param in *group* that has a gradient.

        ``state['step']`` (a Python int) is the authoritative step counter.
        ``state['step_t']`` is a float32 copy on the param's device. It is
        passed to ``torch._fused_adamw_``, and the fused path increments it in
        place.
        """
        params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            step_t = state.get('step_t')
            if (
                step_t is None
                or step_t.dtype != torch.float32
                or step_t.device != p.device
            ):
                # (Re)build from the int counter. It is missing in older state,
                # and Optimizer.load_state_dict casts every tensor entry except
                # 'step' to the param's floating dtype (e.g. bf16, which cannot
                # count steps exactly past 256), whereas _fused_adamw_ wants a
                # per-param float32 step tensor on-device.
                state['step_t'] = torch.tensor(
                    float(state['step']), dtype=torch.float32, device=p.device
                )
            state['step'] += 1
            params.append(p)
            # Grads can differ in dtype from the param (mixed bf16/fp32 autocast).
            grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            state_steps.append(state['step_t'])
        if not params:
            return
        if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
            # _fused_adamw_ needs uniform (device, dtype) within a call, so
            # group by (device, dtype) — same pattern as PyTorch's own
            # AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
            buckets = {}
            for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
                key = (p.device, p.dtype)
                buckets.setdefault(key, ([], [], [], [], []))
                b_p, b_g, b_ea, b_es, b_st = buckets[key]
                b_p.append(p)
                b_g.append(g)
                b_ea.append(ea)
                b_es.append(es)
                b_st.append(st)
            lr_f = float(group['lr'])
            b1_f = float(group['betas'][0])
            b2_f = float(group['betas'][1])
            wd_f = float(group['weight_decay'])
            eps_f = float(group['eps'])
            for (_dev, _dt), (b_p, b_g, b_ea, b_es, b_st) in buckets.items():
                # Advance the on-device step counters so they match state['step'].
                torch._foreach_add_(b_st, 1.0)
                torch._fused_adamw_(
                    b_p, b_g, b_ea, b_es,
                    [],  # max_exp_avg_sqs unused (amsgrad=False)
                    b_st,
                    amsgrad=False,
                    lr=lr_f, beta1=b1_f, beta2=b2_f,
                    weight_decay=wd_f, eps=eps_f,
                    maximize=False,
                    grad_scale=None, found_inf=None,
                )
            return
        # Fallback per-param path.
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])
        for p, grad, exp_avg, exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs):
            self._adamw_step_t.fill_(self.state[p]['step'])
            adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Apply one Muon update to the params in *group* that have gradients.

        Muon state lives on the first param with a gradient and holds one
        stacked row per param with a gradient. The set of such params must
        therefore stay the same from step to step; a size change raises
        RuntimeError.
        """
        params = [p for p in group['params'] if p.grad is not None]
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        elif state["momentum_buffer"].size(0) != num_params:
            # Fail clearly here instead of with a shape error inside the fused
            # step, after the momentum buffer may already have been overwritten
            # by a broadcast lerp_.
            raise RuntimeError(
                f"MuonAdamW: Muon state holds {state['momentum_buffer'].size(0)} "
                f"stacked params but {num_params} have gradients this step; the "
                "set of params with gradients in a Muon group must not change "
                "between steps."
            )
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        if "second_momentum_buffer" not in state:
            # Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
            full_shape = (num_params, *shape)
            state_shape = list(full_shape)
            state_shape[len(state_shape) + red_dim] = 1  # red_dim is negative
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
        # This was the autograd-safety fix that unblocks grad_accum>=2.
        stacked_grads = torch.stack([q.grad for q in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Tall matrices get lr scaled by sqrt(rows / cols).
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # stacked_params is a copy; write the updated rows back into the params.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self, closure=None):
        """Run one update over every param group.

        Parameters are modified in place, which autograd rejects on leaf
        tensors that require grad unless grad mode is off; hence
        ``torch.no_grad()``. ``closure`` is optional, as in
        ``torch.optim.Optimizer.step``. If given, it is called with grad
        re-enabled before any update, and its return value is returned.
        Otherwise None is returned.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            kind = group['kind']
            if kind == 'adamw':
                self._step_adamw(group)
            elif kind == 'muon':
                self._step_muon(group)
        return loss