# Source: Hugging Face upload by icarus112 ("Upload folder using huggingface_hub"),
# commit c383594 (verified).
"""MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).
Extracted verbatim from train.py (W1 modularization). Semantics unchanged.
F1-F15 state preserved:
- F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
fresh. Persistent copies of param storage would be mutated between forward
passes (via lerp_/sub_ on stacked tensors that share storage with params),
triggering "modified in-place" errors on grad_accum=2 backwards.
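- F11/F15: `muon_step_fused` is compiled via `_maybe_compile` (torch.compile);
`adamw_step_fused` is an uncompiled per-param fallback — the AdamW fast path
is `torch._fused_adamw_`.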
- F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
stream-capture deadlock. See .omc/muon_compile_bug.md for the full
investigation.
"""
from __future__ import annotations
import os
import torch
# HYDRA_FUSED_ADAMW=1 (default) -> vectorized torch._fused_adamw_ kernel.
# Any other value disables the fused kernel; MuonAdamW then uses the
# per-param AdamW fallback instead.
_HYDRA_FUSED_ADAMW = os.getenv("HYDRA_FUSED_ADAMW", "1") == "1"
# The fused kernel is a private torch op that older builds may not ship.
_HAS_FUSED_ADAMW = getattr(torch, "_fused_adamw_", None) is not None

# Polar Express coefficients: one (a, b, c) triple per orthogonalization
# iteration, applied as X = a*X + b*(X X^T) X + c*(X X^T)^2 X.
# Written column-wise (all a's, all b's, all c's) and zipped back into the
# per-iteration tuples that muon_step_fused iterates over.
polar_express_coeffs = list(zip(
    (8.156554524902461, 4.042929935166739, 3.8916678022926607,
     3.285753657755655, 2.3465413258596377),
    (-22.48329292557795, -2.808917465908714, -2.772484153217685,
     -2.3681294933425376, -1.7097828382687081),
    (15.878769915207462, 0.5000178451051316, 0.5060648178503393,
     0.46449024233003106, 0.42323551169305323),
))
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """Apply one decoupled-weight-decay AdamW update to ``p`` in place.

    Per-param eager fallback. The fast path is ``torch._fused_adamw_`` (one
    CUDA launch for the whole group), driven from ``MuonAdamW._step_adamw``.

    Args:
        p: parameter tensor; updated in place.
        grad: gradient of ``p``; cast to ``p.dtype`` first because it may
            arrive in another dtype (mixed bf16/fp32 under autocast).
        exp_avg: first-moment state; updated in place.
        exp_avg_sq: second-moment state; updated in place.
        step_t: step count used for bias correction; must be >= 1, otherwise
            ``1 - beta1_t ** step_t`` is zero.
        lr_t, beta1_t, beta2_t, eps_t, wd_t: hyperparameters as tensors
            (MuonAdamW passes 0-D float32 CPU tensors).
    """
    grad = grad.to(p.dtype)  # handle mixed bf16/fp32 from autocast
    # Decoupled weight decay: shrink the weights before the Adam step.
    p.mul_(1 - lr_t * wd_t)
    # lerp_ with a tensor weight requires weight.dtype == self.dtype, so the
    # (float32) hyperparameter weights are cast to the state dtype; without
    # this, bf16 params/state raise a dtype-mismatch error.
    exp_avg.lerp_(grad, (1 - beta1_t).to(exp_avg.dtype))
    exp_avg_sq.lerp_(grad.square(), (1 - beta2_t).to(exp_avg_sq.dtype))
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)
# ---------------------------------------------------------------------------
# F15 muon_step_fused compile strategy.
#
# HYDRA_MUON_COMPILE env gate:
# "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
# dynamic=True collapses the per-shape specialization cache so that N
# Muon param-groups with N distinct shapes trigger 1 compile, not N.
# mode="default" keeps the inductor codegen but does not enable
# cudagraphs (mode="reduce-overhead" would). Cudagraphs are what caused
# the step-17→18 silent deadlock observed under the original
# dynamic=False configuration: cudagraph stream capture
# can deadlock against HTM's CUDA kernels running on the default
# stream, and the failure mode at capture-time is a silent hang
# (100% GPU util, no log output, process state R).
# "0" — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
# Keeps an escape hatch in case a future torch/inductor regression
# reintroduces a deadlock.
#
# Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
# alias-analysis edge case where inductor sees `g is stacked_grads` and
# subsequent `stacked_grads.square()` operating on the post-lerp storage.
# ---------------------------------------------------------------------------
# HYDRA_MUON_COMPILE=1 (default) compiles muon_step_fused; any other value
# keeps it in eager Python.
_MUON_COMPILE = os.getenv("HYDRA_MUON_COMPILE", "1") == "1"


def _maybe_compile(fn):
    """Decorator: wrap ``fn`` in torch.compile unless HYDRA_MUON_COMPILE is off.

    The module-level ``_MUON_COMPILE`` flag is consulted at decoration time.
    When it is set, ``fn`` is compiled with:

    * ``dynamic=True`` to keep the recompile count low across Muon
      param-group shapes;
    * ``mode="default"``, which does not enable cudagraphs (unlike
      ``"reduce-overhead"``), avoiding stream-capture deadlocks against
      HTM's CUDA kernels;
    * ``fullgraph=False`` so graph breaks do not raise.

    When the flag is off, ``fn`` is returned unchanged (eager escape hatch).
    """
    if not _MUON_COMPILE:
        return fn
    return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
@_maybe_compile
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """Apply one fused Muon (NorMuon) update to a stack of same-shape params.

    Pipeline: momentum (Nesterov-style lerp) -> Polar Express
    orthogonalization in bfloat16 -> NorMuon variance normalization along
    ``red_dim`` -> cautious weight decay + parameter update.

    Args:
        stacked_grads: grads stacked along a new leading dim, one entry per
            param; the last two dims are the matrix. The caller's tensor is
            not mutated (a cast + cloned copy is used).
        stacked_params: params stacked the same way; updated in place. The
            result still has to be copied back into the real parameters by
            the caller (MuonAdamW._step_muon does this with _foreach_copy_).
        momentum_buffer: momentum state, same shape as ``stacked_grads``;
            its dtype is the compute dtype of the momentum step. Updated in
            place.
        second_momentum_buffer: EMA of the mean squared orthogonalized
            update, shaped like ``stacked_grads`` with ``red_dim`` reduced
            to size 1. Updated in place.
        momentum_t, lr_t, wd_t, beta2_t: hyperparameters as tensors (the
            caller passes 0-D float32 CPU tensors); moved to the state
            device here.
        ns_steps: number of Polar Express iterations. Values above
            ``len(polar_express_coeffs)`` are silently capped by the slice.
        red_dim: ``-1`` or ``-2``; the matrix dim averaged over for the
            variance statistics.

    Returns:
        None; all results are written in place.
    """
    # Cast grads to param dtype AND clone defensively to break any alias
    # between the (freshly-stacked) input and the in-place lerp_ below.
    # Without this, inductor's alias analysis can emit code that reads from
    # post-mutation storage when computing `v_mean = g.square().mean(...)`.
    stacked_grads = stacked_grads.to(momentum_buffer.dtype).clone()
    # Nesterov momentum: buf = m*buf + (1-m)*grad, then g = (1-m)*grad + m*buf.
    # `momentum` is cast to the buffer dtype so the tensor-weight lerp_ calls
    # see matching dtypes.
    momentum = momentum_t.to(device=momentum_buffer.device, dtype=stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization (bfloat16).
    # Dividing by 1.02 * Frobenius norm keeps the spectral norm below 1
    # (spectral <= Frobenius) before iterating; 1e-6 guards against zero.
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):
        # Tall matrix: iterate on the smaller (cols x cols) Gram matrix X^T X.
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        # Wide or square matrix: iterate on the (rows x rows) Gram matrix X X^T.
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # NorMuon variance reduction
    # lerp_ with a tensor weight needs weight.dtype == buffer dtype, so beta2
    # is cast to the state-buffer dtype rather than g.dtype (g is bfloat16
    # here); this avoided a dtype mismatch seen on H200.
    # NOTE(review): MuonAdamW._step_muon allocates second_momentum_buffer in
    # the param dtype, so it is float32 only when the params are.
    beta2 = beta2_t.to(device=second_momentum_buffer.device, dtype=second_momentum_buffer.dtype)
    # Mean squared update along red_dim, computed in float32.
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    # Frobenius norm of g (sum of squares = mean * count, summed over matrix).
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    # Per-row/column scale = 1/sqrt(EMA of mean square).
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    # Frobenius norm of g * step_size, used to rescale back to ||g||_F so
    # the normalization only redistributes magnitude across rows/columns.
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Cautious weight decay + parameter update
    # NOTE(review): lr/wd are cast to g.dtype (bfloat16 here), which rounds
    # them to ~3 significant digits — confirm this precision is intended.
    lr = lr_t.to(device=stacked_params.device, dtype=g.dtype)
    wd = wd_t.to(device=stacked_params.device, dtype=g.dtype)
    # Decay only entries where the update and the weight have the same sign.
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Every param group must carry a ``'kind'`` key that selects its rule:

    * ``'adamw'``, ``'dt_bias'``, ``'retina_contrastive'`` -> AdamW; the
      group must provide ``lr``, ``betas``, ``eps`` and ``weight_decay``.
    * ``'muon'`` -> Muon with NorMuon variance reduction; the group must
      provide ``lr``, ``momentum``, ``beta2`` (may be ``None``),
      ``weight_decay`` and ``ns_steps``. Its params are stacked into one
      batch, so they must all share one shape.

    Groups with any other ``kind`` are not updated by :meth:`step`.
    """

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # Hyperparameters reach the step kernels as 0-D CPU tensors that are
        # refilled in place each step, so value changes (e.g. a scheduled LR)
        # do not trigger torch.compile recompilation.
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _step_adamw(self, group):
        """Apply one AdamW step to every param in ``group`` that has a grad.

        Uses ``torch._fused_adamw_`` (one launch per (device, dtype) bucket)
        when HYDRA_FUSED_ADAMW is on, the kernel exists, and the first param
        with a grad is on CUDA; otherwise calls ``adamw_step_fused`` per param.
        Both paths advance ``state['step']`` and ``state['step_t']`` together,
        so state saved by one path resumes correctly on the other.
        """
        params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            if 'step_t' not in state:
                # _fused_adamw_ wants a per-param float step tensor on-device.
                state['step_t'] = torch.tensor(
                    float(state['step']), dtype=torch.float32, device=p.device
                )
            state['step'] += 1
            params.append(p)
            # Both paths need grads in the param dtype (autocast may differ).
            grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            state_steps.append(state['step_t'])
        if not params:
            return
        if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
            # _fused_adamw_ needs uniform (device, dtype) within a call, so
            # group by (device, dtype) — same pattern as PyTorch's own
            # AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
            buckets = {}
            for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
                b_p, b_g, b_ea, b_es, b_st = buckets.setdefault(
                    (p.device, p.dtype), ([], [], [], [], [])
                )
                b_p.append(p)
                b_g.append(g)
                b_ea.append(ea)
                b_es.append(es)
                b_st.append(st)
            lr_f = float(group['lr'])
            b1_f = float(group['betas'][0])
            b2_f = float(group['betas'][1])
            wd_f = float(group['weight_decay'])
            eps_f = float(group['eps'])
            for b_p, b_g, b_ea, b_es, b_st in buckets.values():
                torch._foreach_add_(b_st, 1.0)
                torch._fused_adamw_(
                    b_p, b_g, b_ea, b_es,
                    [],  # max_exp_avg_sqs unused (amsgrad=False)
                    b_st,
                    amsgrad=False,
                    lr=lr_f, beta1=b1_f, beta2=b2_f,
                    weight_decay=wd_f, eps=eps_f,
                    maximize=False,
                    grad_scale=None, found_inf=None,
                )
            return
        # Fallback per-param path. step_t is not read here, but it is part of
        # the saved optimizer state: advance it so it keeps matching
        # state['step'] (otherwise a checkpoint resumed on the fused path
        # would restart bias correction from a stale step).
        torch._foreach_add_(state_steps, 1.0)
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])
        for p, grad, exp_avg, exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs):
            self._adamw_step_t.fill_(self.state[p]['step'])
            adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Apply one Muon (NorMuon) step to every param in ``group`` that has a grad.

        All params in the group share one shape and are updated together as
        a stacked batch by ``muon_step_fused``. Optimizer state lives on the
        group's first param and holds one momentum / second-moment slot per
        param of the group, so slot ``i`` always belongs to
        ``group['params'][i]``. Params whose grad is ``None`` are skipped and
        their slots are left untouched (previously the slots were sized and
        assigned by the params that happened to have grads, so a missing grad
        shifted momentum onto the wrong params or broke on a shape mismatch).
        """
        group_params = group['params']
        active = [i for i, q in enumerate(group_params) if q.grad is not None]
        if not active:
            return
        params = [group_params[i] for i in active]
        p = group_params[0]
        state = self.state[p]
        num_params = len(group_params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        # Variance statistics are per row when rows >= cols (reduce dim -1),
        # otherwise per column (reduce dim -2).
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        if "second_momentum_buffer" not in state:
            # Shape must match v_mean = g.square().mean(dim=red_dim, keepdim=True)
            state_shape = [num_params, *shape]
            state_shape[len(state_shape) + red_dim] = 1  # red_dim is negative
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)

        momentum_buffer = state["momentum_buffer"]
        second_momentum_buffer = state["second_momentum_buffer"]
        partial = len(active) != num_params
        if partial:
            # Gather the slots of the params that have grads; the fused step
            # mutates these copies, which are scattered back afterwards.
            index = torch.tensor(active, dtype=torch.long, device=momentum_buffer.device)
            momentum_buffer = momentum_buffer.index_select(0, index)
            second_momentum_buffer = second_momentum_buffer.index_select(0, index)

        # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
        # This was the autograd-safety fix that unblocks grad_accum>=2.
        stacked_grads = torch.stack([q.grad for q in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # LR is scaled by sqrt(max(1, rows / cols)) for tall matrices.
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        momentum_buffer, second_momentum_buffer,
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        if partial:
            state["momentum_buffer"].index_copy_(0, index, momentum_buffer)
            state["second_momentum_buffer"].index_copy_(0, index, second_momentum_buffer)
        # The stacked copy was updated; write it back into the real params.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self, closure=None):
        """Update every param group once.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (standard ``torch.optim.Optimizer`` contract). It is
                run with grad enabled before any parameter is updated.

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            kind = group['kind']
            # Audit 2026-05-09 issue #19: 'dt_bias' (and the pre-existing
            # 'retina_contrastive') are AdamW-style groups with their own
            # cosine-LR exemption upstream. Route them through _step_adamw
            # so they actually update; the exempt-from-cosine treatment is
            # applied in training.py where group['lr'] is set.
            if kind in ('adamw', 'dt_bias', 'retina_contrastive'):
                self._step_adamw(group)
            elif kind == 'muon':
                self._step_muon(group)
            # Any other kind is intentionally left un-updated.
        return loss