Spaces:
Runtime error
Runtime error
File size: 12,509 Bytes
"""MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).
Extracted verbatim from train.py (W1 modularization). Semantics unchanged.
F1-F15 state preserved:
- F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
fresh. Persistent copies of param storage would be mutated between forward
passes (via lerp_/sub_ on stacked tensors that share storage with params),
triggering "modified in-place" errors on grad_accum=2 backwards.
- F11/F15: `@torch.compile` on `adamw_step_fused` / `muon_step_fused` intact.
- F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
stream-capture deadlock. See .omc/muon_compile_bug.md for the full
investigation.
"""
from __future__ import annotations
import os
import torch
# Environment switch: HYDRA_FUSED_ADAMW=1 (the default) selects the vectorized
# torch._fused_adamw_ kernel; any other value forces the per-parameter path.
_HYDRA_FUSED_ADAMW = os.getenv("HYDRA_FUSED_ADAMW", "1") == "1"

# The private fused kernel is only used when this torch build exposes it.
_HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")

# One (a, b, c) triple per orthogonalization iteration; muon_step_fused
# consumes the first `ns_steps` entries in order.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """Apply one AdamW update to a single parameter, entirely in place.

    This is the per-parameter fallback; the fast path is torch._fused_adamw_
    (one launch for a whole group), driven from MuonAdamW._step_adamw.

    Args:
        p: parameter tensor, updated in place.
        grad: gradient for ``p``; cast to ``p.dtype`` before use.
        exp_avg: first-moment running average, updated in place.
        exp_avg_sq: second-moment running average, updated in place.
        step_t: step count used for bias correction.
        lr_t, beta1_t, beta2_t, eps_t, wd_t: hyperparameters (the caller
            in this file passes 0-D tensors).

    Returns:
        None.
    """
    # Autocast can hand back a gradient whose dtype differs from the param's.
    g = grad.to(p.dtype)

    # Decoupled weight decay: shrink the weights before the Adam step.
    p.mul_(1 - lr_t * wd_t)

    # Exponential moving averages of the gradient and its square.
    exp_avg.lerp_(g, 1 - beta1_t)
    exp_avg_sq.lerp_(g.square(), 1 - beta2_t)

    # Bias corrections for the zero-initialized moment estimates.
    correction1 = 1 - beta1_t ** step_t
    correction2 = 1 - beta2_t ** step_t

    rms = (exp_avg_sq / correction2).sqrt() + eps_t
    p.add_(exp_avg / rms, alpha=-(lr_t / correction1))
# ---------------------------------------------------------------------------
# F15 muon_step_fused compile strategy.
#
# HYDRA_MUON_COMPILE env gate:
# "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
# Dynamic=True collapses the per-shape specialization cache so that N
# Muon param-groups with N distinct shapes trigger 1 compile, not N.
# mode="default" keeps the inductor codegen but disables cudagraphs,
# which is what caused the step-17→18 silent deadlock observed under
# the original dynamic=False configuration: cudagraph stream capture
# can deadlock against HTM's CUDA kernels running on the default
# stream, and the failure mode at capture-time is a silent hang
# (100% GPU util, no log output, process state R).
# "0" — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
# Keeps an escape hatch in case a future torch/inductor regression
# reintroduces a deadlock.
#
# Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
# alias-analysis edge case where inductor sees `g is stacked_grads` and
# subsequent `stacked_grads.square()` operating on the post-lerp storage.
# ---------------------------------------------------------------------------
# Compilation of the Muon step is on unless HYDRA_MUON_COMPILE is set to
# something other than "1".
_MUON_COMPILE = os.getenv("HYDRA_MUON_COMPILE", "1") == "1"


def _maybe_compile(fn):
    """Decorator: wrap ``fn`` with torch.compile when compilation is enabled.

    Returns ``fn`` itself, untouched, when ``_MUON_COMPILE`` is false.
    """
    if not _MUON_COMPILE:
        return fn
    # mode="default" is a deliberate opt-out of cudagraphs (which
    # reduce-overhead would enable), to avoid stream-capture deadlocks
    # against HTM's CUDA kernels. dynamic=True keeps the number of
    # recompiles low across differently shaped param groups.
    return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
@_maybe_compile
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """Run one Muon update on a stack of equally shaped matrices, in place.

    Stages, in order: Nesterov momentum, polar-express orthogonalization in
    bfloat16, NorMuon variance reduction, then a cautious-weight-decay
    parameter update.

    Args:
        stacked_grads: gradients stacked along dim 0 (not modified; a cast
            copy is taken first).
        stacked_params: parameters stacked along dim 0, updated in place.
        momentum_buffer: momentum state, updated in place.
        second_momentum_buffer: second-moment state, updated in place.
        momentum_t, lr_t, wd_t, beta2_t: hyperparameters as tensors; each is
            moved to the device/dtype of the tensor it is combined with.
        ns_steps: how many leading triples of ``polar_express_coeffs`` to apply.
        red_dim: dimension over which squared entries are averaged for the
            second-moment estimate.

    Returns:
        None.
    """
    # Cast to the state dtype and clone defensively, so the in-place lerp_
    # below can never alias the caller's freshly stacked input. Without the
    # clone, inductor's alias analysis can emit code that reads post-mutation
    # storage when the squared mean is computed later on.
    grads = stacked_grads.to(momentum_buffer.dtype).clone()

    # --- Nesterov momentum ---
    mu = momentum_t.to(device=momentum_buffer.device, dtype=grads.dtype)
    momentum_buffer.lerp_(grads, 1 - mu)
    update = grads.lerp_(momentum_buffer, mu)

    # --- Polar express orthogonalization ---
    X = update.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    # Form the Gram matrix on the smaller side: X^T X for tall matrices,
    # X X^T otherwise.
    is_tall = update.size(-2) > update.size(-1)
    for a, b, c in polar_express_coeffs[:ns_steps]:
        if is_tall:
            A = X.mT @ X
            X = a * X + X @ (b * A + c * (A @ A))
        else:
            A = X @ X.mT
            X = a * X + (b * A + c * (A @ A)) @ X
    ortho = X

    # --- NorMuon variance reduction ---
    # beta2 follows the state buffer's dtype (not the update's) so lerp_ on
    # the buffer never sees mismatched dtypes.
    ema = beta2_t.to(device=second_momentum_buffer.device, dtype=second_momentum_buffer.dtype)
    sq_mean = ortho.float().square().mean(dim=red_dim, keepdim=True)
    n_reduced = ortho.size(red_dim)
    norm_before = (sq_mean.sum(dim=(-2, -1), keepdim=True) * n_reduced).sqrt()
    second_momentum_buffer.lerp_(sq_mean.to(dtype=second_momentum_buffer.dtype), 1 - ema)
    inv_rms = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    rescaled_sq = (sq_mean * n_reduced) * inv_rms.float().square()
    norm_after = rescaled_sq.sum(dim=(-2, -1), keepdim=True).sqrt()
    # Rescale so the normalized update keeps the norm it had before.
    scale = inv_rms * (norm_before / norm_after.clamp_min(1e-10))
    ortho = ortho * scale.to(ortho.dtype)

    # --- Cautious weight decay + parameter update ---
    lr = lr_t.to(device=stacked_params.device, dtype=ortho.dtype)
    wd = wd_t.to(device=stacked_params.device, dtype=ortho.dtype)
    # Decay is applied only where update and weight have the same sign
    # (or either is zero).
    agree = (ortho * stacked_params) >= 0
    stacked_params.sub_(lr * ortho + lr * wd * stacked_params * agree)
class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Each param group must carry a ``'kind'`` key:

    * ``'adamw'`` groups are read for ``lr``, ``betas``, ``eps`` and
      ``weight_decay``.
    * ``'muon'`` groups are read for ``lr``, ``momentum``, ``beta2``
      (``None`` is treated as 0.0), ``weight_decay`` and ``ns_steps``. All
      parameters of a muon group are stacked into one tensor, so they must
      share a single shape.

    Groups with any other ``kind`` are left untouched by ``step``.
    """

    def __init__(self, param_groups):
        """Register ``param_groups`` and allocate reusable scalar tensors."""
        super().__init__(param_groups, defaults={})

        def scalar():
            return torch.tensor(0.0, dtype=torch.float32, device="cpu")

        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._adamw_step_t = scalar()
        self._adamw_lr_t = scalar()
        self._adamw_beta1_t = scalar()
        self._adamw_beta2_t = scalar()
        self._adamw_eps_t = scalar()
        self._adamw_wd_t = scalar()
        self._muon_momentum_t = scalar()
        self._muon_lr_t = scalar()
        self._muon_wd_t = scalar()
        self._muon_beta2_t = scalar()

    def _step_adamw(self, group):
        """Apply one AdamW step to every parameter in ``group`` that has a grad."""
        params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            if 'step_t' not in state:
                # _fused_adamw_ wants a per-param float step tensor on-device.
                state['step_t'] = torch.tensor(
                    float(state['step']), dtype=torch.float32, device=p.device
                )
            state['step'] += 1
            params.append(p)
            # Only pay for a cast when the grad dtype differs from the param's.
            grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            state_steps.append(state['step_t'])
        if not params:
            return

        if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
            # _fused_adamw_ needs uniform (device, dtype) within a call, so
            # group by (device, dtype) — same pattern as PyTorch's own
            # AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
            buckets = {}
            for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
                bucket = buckets.setdefault((p.device, p.dtype), ([], [], [], [], []))
                for dest, item in zip(bucket, (p, g, ea, es, st)):
                    dest.append(item)
            lr_f = float(group['lr'])
            b1_f = float(group['betas'][0])
            b2_f = float(group['betas'][1])
            wd_f = float(group['weight_decay'])
            eps_f = float(group['eps'])
            for b_p, b_g, b_ea, b_es, b_st in buckets.values():
                # The step tensors are advanced here, before the kernel call.
                torch._foreach_add_(b_st, 1.0)
                torch._fused_adamw_(
                    b_p, b_g, b_ea, b_es,
                    [],  # max_exp_avg_sqs unused (amsgrad=False)
                    b_st,
                    amsgrad=False,
                    lr=lr_f, beta1=b1_f, beta2=b2_f,
                    weight_decay=wd_f, eps=eps_f,
                    maximize=False,
                    grad_scale=None, found_inf=None,
                )
            return

        # Fallback per-param path.
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])
        for p, grad, exp_avg, exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs):
            self._adamw_step_t.fill_(self.state[p]['step'])
            adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Apply one Muon step to the stacked parameters of ``group``."""
        params = [p for p in group['params'] if p.grad is not None]
        if not params:
            return
        # Group-wide state lives on the first parameter that has a gradient.
        # NOTE(review): if the set of params with grads changes between
        # steps, the stacked buffers no longer line up — confirm callers
        # never do that.
        first = params[0]
        state = self.state[first]
        num_params = len(params)
        shape, device, dtype = first.shape, first.device, first.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        if "second_momentum_buffer" not in state:
            # Same shape as the stacked params, except the reduced dim is 1,
            # matching a mean over red_dim with keepdim=True.
            state_shape = [num_params, *shape]
            state_shape[red_dim] = 1  # red_dim is negative
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)

        # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
        # This was the autograd-safety fix that unblocks grad_accum>=2.
        stacked_grads = torch.stack([q.grad for q in params])
        stacked_params = torch.stack(params)

        beta2 = group["beta2"]
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(beta2 if beta2 is not None else 0.0)
        # Learning rate is scaled by sqrt(max(1, rows / cols)).
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # torch.stack made a copy, so write the updated slices back.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step over all param groups.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss. It runs with gradients enabled, before
                any parameter is updated.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on.
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            kind = group['kind']
            if kind == 'adamw':
                self._step_adamw(group)
            elif kind == 'muon':
                self._step_muon(group)
        return loss
|