File size: 16,306 Bytes
"""MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).
Extracted verbatim from train.py (W1 modularization). Semantics unchanged.
F1-F15 state preserved:
- F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
  step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
  fresh. Persistent copies of param storage would be mutated between forward
  passes (via lerp_/sub_ on stacked tensors that share storage with params),
  triggering "modified in-place" errors on grad_accum=2 backwards.
- F11/F15: `@torch.compile` on `adamw_step_fused` / `muon_step_fused` intact.
- F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
  dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
  stream-capture deadlock. See .omc/muon_compile_bug.md for the full
  investigation.
"""
from __future__ import annotations
import os
import torch
# HYDRA_FUSED_ADAMW=1 (default) -> vectorized torch._fused_adamw_ kernel.
# Any other value makes MuonAdamW._step_adamw use the per-parameter fallback
# (adamw_step_fused) instead.
_HYDRA_FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
# True when this torch build exposes the private torch._fused_adamw_ op. The
# fused path is only taken when both this flag and _HYDRA_FUSED_ADAMW are set.
_HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")
# (a, b, c) coefficient triples for the "Polar express" orthogonalization loop
# in muon_step_fused. One triple is consumed per iteration and only the first
# `ns_steps` entries are used, so at most five iterations are available.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """Apply one AdamW update to a single parameter, in place.

    This is the per-parameter fallback; the fast path is torch._fused_adamw_
    (one launch for the whole group), driven from MuonAdamW._step_adamw.

    Args:
        p: Parameter tensor; decayed and updated in place.
        grad: Gradient for ``p``; cast to ``p.dtype`` before use.
        exp_avg: First-moment running average; updated in place.
        exp_avg_sq: Second-moment running average; updated in place.
        step_t: Current (1-based) step count used for bias correction.
        lr_t: Learning rate.
        beta1_t: First-moment decay rate.
        beta2_t: Second-moment decay rate.
        eps_t: Term added to the denominator.
        wd_t: Decoupled weight-decay coefficient.

    The scalar hyper-parameters are passed by MuonAdamW as 0-D tensors.
    """
    # Autocast can hand us a gradient whose dtype differs from the parameter's.
    g = grad.to(p.dtype)

    # Decoupled weight decay happens before the moment updates.
    p.mul_(1 - lr_t * wd_t)

    # Exponential moving averages of the gradient and of its square.
    exp_avg.lerp_(g, 1 - beta1_t)
    exp_avg_sq.lerp_(g.square(), 1 - beta2_t)

    # Bias corrections for the zero-initialised moments.
    first_correction = 1 - beta1_t ** step_t
    second_correction = 1 - beta2_t ** step_t

    rms = (exp_avg_sq / second_correction).sqrt() + eps_t
    p.add_(exp_avg / rms, alpha=-(lr_t / first_correction))
# ---------------------------------------------------------------------------
# F15 muon_step_fused compile strategy.
#
# HYDRA_MUON_COMPILE env gate:
#   "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
#       dynamic=True collapses the per-shape specialization cache so that N
#       Muon param-groups with N distinct shapes trigger 1 compile, not N.
#       mode="default" keeps the inductor codegen but does not enable
#       cudagraphs. Cudagraphs caused the step-17→18 silent deadlock observed
#       under the original dynamic=False configuration: cudagraph stream
#       capture can deadlock against HTM's CUDA kernels running on the default
#       stream, and the failure mode at capture-time is a silent hang
#       (100% GPU util, no log output, process state R).
#   "0" — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
#       Keeps an escape hatch in case a future torch/inductor regression
#       reintroduces a deadlock.
#
# A defensive .clone() on stacked_grads before the in-place lerp_ eliminates
# the alias-analysis edge case where inductor sees `g is stacked_grads` and a
# subsequent `stacked_grads.square()` ends up operating on the post-lerp storage.
# ---------------------------------------------------------------------------
# Compilation of the Muon step is on unless HYDRA_MUON_COMPILE is set to
# something other than "1".
_MUON_COMPILE = os.getenv("HYDRA_MUON_COMPILE", "1") == "1"


def _maybe_compile(fn):
    """Return ``fn`` wrapped in torch.compile when compilation is enabled.

    With compilation disabled the function is returned untouched, so the
    decorated callable runs in eager mode.
    """
    if not _MUON_COMPILE:
        return fn
    # mode="default" explicitly opts OUT of cudagraphs (which reduce-overhead
    # would enable) to avoid stream-capture deadlocks against HTM's CUDA
    # kernels. dynamic=True minimizes recompile count across param-group
    # shapes.
    return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
@_maybe_compile
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """Run one Muon update on a stack of matrix parameters.

    The step is: Nesterov-style momentum, "Polar express" orthogonalization in
    bfloat16, NorMuon variance reduction, then a cautious weight-decay update.
    ``momentum_buffer``, ``second_momentum_buffer`` and ``stacked_params`` are
    modified in place; ``stacked_grads`` is cloned first and left untouched.
    Nothing is returned.

    Args:
        stacked_grads: Gradients stacked along a new leading dim
            (MuonAdamW._step_muon builds it with ``torch.stack``).
        stacked_params: Parameters stacked the same way; receives the update.
        momentum_buffer: First-moment state, same shape as the stacked grads.
        second_momentum_buffer: Second-moment state, shaped like the stacked
            grads with ``red_dim`` collapsed to size 1.
        momentum_t: Momentum coefficient as a tensor.
        lr_t: Learning rate as a tensor.
        wd_t: Weight-decay coefficient as a tensor.
        beta2_t: Second-moment decay rate as a tensor.
        ns_steps: Number of entries of ``polar_express_coeffs`` to iterate.
        red_dim: Dimension reduced for the per-row/column statistics
            (the caller passes -1 or -2).

    The ``*_t`` hyper-parameters are 0-D CPU float32 tensors when called from
    MuonAdamW; they are moved to the right device/dtype here.
    """
    # Cast grads to param dtype AND clone defensively to break any alias
    # between the (freshly-stacked) input and the in-place lerp_ below.
    # Without this, inductor's alias analysis can emit code that reads from
    # post-mutation storage when computing `v_mean = g.square().mean(...)`.
    stacked_grads = stacked_grads.to(momentum_buffer.dtype).clone()
    # Nesterov momentum: first fold the gradient into the buffer, then blend
    # the gradient towards the updated buffer.
    momentum = momentum_t.to(device=momentum_buffer.device, dtype=stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization
    X = g.bfloat16()
    # Normalize each matrix by (slightly more than) its Frobenius norm; the
    # 1e-6 guards against division by zero.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):
        # Tall matrices: iterate with the smaller Gram matrix X^T X.
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        # Wide or square matrices: iterate with X X^T instead.
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # NorMuon variance reduction
    # Keep beta2 in the state-buffer dtype, not g.dtype, so lerp_ on the
    # float32 second_momentum_buffer doesn't hit a dtype mismatch on h200.
    beta2 = beta2_t.to(device=second_momentum_buffer.device, dtype=second_momentum_buffer.dtype)
    # Mean of squares along red_dim, computed in float32.
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    # mean * size == sum along red_dim, so this is each matrix's squared
    # Frobenius norm before rescaling.
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    # Inverse-RMS scale from the running second moment (clamped away from 0).
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    # Frobenius norm each matrix would have after applying step_size.
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    # Rescale so the per-matrix norm after scaling matches the norm before it.
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Cautious weight decay + parameter update
    lr = lr_t.to(device=stacked_params.device, dtype=g.dtype)
    wd = wd_t.to(device=stacked_params.device, dtype=g.dtype)
    # Decay only the entries where the update and the parameter share a sign
    # (their product is non-negative).
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Every param group must carry a ``kind`` key. ``step`` sends ``'muon'``
    groups through the fused Muon update and ``'adamw'`` / ``'dt_bias'`` /
    ``'retina_contrastive'`` groups through AdamW. Groups of any other kind
    are left untouched.

    AdamW groups read ``lr``, ``betas``, ``eps`` and ``weight_decay``; Muon
    groups read ``lr``, ``momentum``, ``beta2``, ``weight_decay`` and
    ``ns_steps``. All parameters of one Muon group are stacked together, so
    they must share a shape.
    """

    # Group kinds that are updated with AdamW.
    _ADAMW_KINDS = ('adamw', 'dt_bias', 'retina_contrastive')
    # Keys that are step-time scratch data and must never be checkpointed.
    _TRANSIENT_STATE_KEYS = ('step_t',)
    _TRANSIENT_GROUP_KEYS = ('_adamw_bucket_cache', '_muon_params_cache')

    def __init__(self, param_groups):
        """Create the optimizer from a list of param-group dicts."""
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        # Per-group topology caches, keyed by id(group).
        self._adamw_bucket_caches = {}
        self._muon_params_caches = {}

    @staticmethod
    def _same_params(cached, current):
        """Return True when both sequences hold the very same tensor objects.

        Identity is compared on purpose: ``tuple == tuple`` falls back to
        ``Tensor.__eq__`` for entries that are not the same object, which
        raises for multi-element tensors and compares by value otherwise.
        """
        return len(cached) == len(current) and all(
            a is b for a, b in zip(cached, current)
        )

    @classmethod
    def _without_transients(cls, state_dict):
        """Return a copy of ``state_dict`` with transient keys filtered out.

        Only the outer dict, the per-parameter state dicts and the group dicts
        are copied; tensors are shared with the input. The input is never
        modified.
        """
        cleaned = dict(state_dict)
        if "state" in state_dict:
            cleaned["state"] = {
                key: {
                    name: value
                    for name, value in entry.items()
                    if name not in cls._TRANSIENT_STATE_KEYS
                }
                for key, entry in state_dict["state"].items()
            }
        if "param_groups" in state_dict:
            cleaned["param_groups"] = [
                {
                    name: value
                    for name, value in group.items()
                    if name not in cls._TRANSIENT_GROUP_KEYS
                }
                for group in state_dict["param_groups"]
            ]
        return cleaned

    def state_dict(self):
        """Return the optimizer state without transient fused-step data.

        Transient fused-step caches and device step_t tensors must not enter
        checkpoints. The per-parameter dicts returned by the base class are
        the live state dicts, so they are filtered into copies here instead
        of being popped in place; the running optimizer keeps its step_t.
        """
        return self._without_transients(super().state_dict())

    def load_state_dict(self, state_dict):
        """Load a checkpoint, ignoring any transient keys it may contain.

        The caller's ``state_dict`` is left unmodified. step_t is recreated
        lazily from the scalar state['step'] on the next AdamW step, and the
        per-group caches are dropped because they are keyed by group identity.
        """
        self._adamw_bucket_caches.clear()
        self._muon_params_caches.clear()
        return super().load_state_dict(self._without_transients(state_dict))

    def _ensure_adamw_state(self, p):
        """Return the AdamW state for ``p``, creating missing entries."""
        state = self.state[p]
        if not state:
            state['step'] = 0
            state['exp_avg'] = torch.zeros_like(p)
            state['exp_avg_sq'] = torch.zeros_like(p)
        if 'step_t' not in state:
            # _fused_adamw_ wants a per-param float step tensor on-device.
            state['step_t'] = torch.tensor(
                float(state['step']), dtype=torch.float32, device=p.device
            )
        return state

    def _adamw_cached_buckets(self, group):
        """Return stable (device,dtype) param buckets for fused AdamW.

        Cache topology only. Optimizer state remains lazy for grad-bearing
        params so unused/frozen tensors do not bloat checkpoints. The cache is
        rebuilt whenever the group's parameter objects change.
        """
        params_tuple = tuple(group['params'])
        cache = self._adamw_bucket_caches.get(id(group))
        if cache is not None and self._same_params(cache['params_tuple'], params_tuple):
            return cache['buckets']
        buckets = {}
        for p in params_tuple:
            key = (p.device, p.dtype)
            buckets.setdefault(key, {'params': []})
            buckets[key]['params'].append(p)
        self._adamw_bucket_caches[id(group)] = {'params_tuple': params_tuple, 'buckets': buckets}
        return buckets

    def _step_adamw(self, group):
        """Apply one AdamW step to every grad-bearing parameter in ``group``.

        Uses torch._fused_adamw_ per (device, dtype) bucket when it is enabled,
        available, and no grad-bearing parameter lives off-CUDA; otherwise
        falls back to the per-parameter adamw_step_fused.
        """
        if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW:
            # Mixed CPU/CUDA groups are unusual in Feather but skipping CPU
            # grads would be a correctness bug; disable fused path in that case.
            if not any(p.grad is not None and not p.is_cuda for p in group['params']):
                buckets = self._adamw_cached_buckets(group)
                lr_f = float(group['lr'])
                b1_f = float(group['betas'][0])
                b2_f = float(group['betas'][1])
                wd_f = float(group['weight_decay'])
                eps_f = float(group['eps'])
                launched = False
                for bucket in buckets.values():
                    b_p = [p for p in bucket['params'] if p.grad is not None]
                    if not b_p or not b_p[0].is_cuda:
                        continue
                    b_g = [p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad for p in b_p]
                    b_ea, b_es, b_st = [], [], []
                    for p in b_p:
                        state = self._ensure_adamw_state(p)
                        state['step'] += 1
                        b_ea.append(state['exp_avg'])
                        b_es.append(state['exp_avg_sq'])
                        b_st.append(state['step_t'])
                    # Advance the device-side counters in lockstep with the
                    # scalar state['step'] bumped above.
                    torch._foreach_add_(b_st, 1.0)
                    torch._fused_adamw_(
                        b_p, b_g, b_ea, b_es,
                        [],  # max_exp_avg_sqs unused (amsgrad=False)
                        b_st,
                        amsgrad=False,
                        lr=lr_f, beta1=b1_f, beta2=b2_f,
                        weight_decay=wd_f, eps=eps_f,
                        maximize=False,
                        grad_scale=None, found_inf=None,
                    )
                    launched = True
                if launched:
                    return
        params, grads, exp_avgs, exp_avg_sqs = [], [], [], []
        for p in group['params']:
            if p.grad is None:
                continue
            state = self._ensure_adamw_state(p)
            state['step'] += 1
            if 'step_t' in state:
                # Keep the device-side counter consistent for a later fused step.
                state['step_t'].fill_(float(state['step']))
            params.append(p)
            grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
        if not params:
            return
        # Fallback per-param path.
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])
        for p, grad, exp_avg, exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs):
            self._adamw_step_t.fill_(self.state[p]['step'])
            adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Apply one fused Muon step to the grad-bearing parameters of ``group``.

        The stacked momentum buffers for the whole group are stored in the
        state of the first grad-bearing parameter. Updated values are copied
        back from the temporary stack into the individual parameters.
        """
        params_tuple = tuple(group['params'])
        cache = self._muon_params_caches.get(id(group))
        if cache is None or not self._same_params(cache['params_tuple'], params_tuple):
            cache = {'params_tuple': params_tuple, 'params': list(params_tuple)}
            self._muon_params_caches[id(group)] = cache
        params_all = cache['params']
        # Common Feather path: all Muon matrix params receive grads every step.
        # Preserve sparse/None-grad correctness by filtering only when needed.
        if all(p.grad is not None for p in params_all):
            params = params_all
        else:
            params = [p for p in params_all if p.grad is not None]
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if (
            "momentum_buffer" not in state
            or state["momentum_buffer"].shape[0] != num_params
            or tuple(state["momentum_buffer"].shape[1:]) != tuple(shape)
        ):
            # If grad-bearing Muon params change (rare; usually all matrix params
            # have grads), resize instead of crashing compiled Muon on a stale
            # leading dimension. This preserves skip-None-grad semantics.
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
            state.pop("second_momentum_buffer", None)
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        if "second_momentum_buffer" not in state:
            # Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
            full_shape = (num_params, *shape)
            state_shape = list(full_shape)
            state_shape[len(state_shape) + red_dim] = 1  # red_dim is negative
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
        # This was the autograd-safety fix that unblocks grad_accum>=2.
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Learning rate is scaled by sqrt(rows / cols) for tall matrices.
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # torch.stack made a copy, so write the updated values back.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step over all param groups.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is run with gradients enabled before the update.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            kind = group['kind']
            # Audit 2026-05-09 issue #19: 'dt_bias' (and the pre-existing
            # 'retina_contrastive') are AdamW-style groups with their own
            # cosine-LR exemption upstream. Route them through _step_adamw
            # so they actually update; the exempt-from-cosine treatment is
            # applied in training.py where group['lr'] is set.
            if kind in self._ADAMW_KINDS:
                self._step_adamw(group)
            elif kind == 'muon':
                self._step_muon(group)
            # NOTE(review): any other kind is skipped without an update.
        return loss
|