Muon optimizer: expert batching, parallel caching, A2A overlap [skip-build]
Browse files- Batched expert NS path for plain-tensor MoE params (skip expansion)
- Expert expansion cache to eliminate per-step detach overhead
- _setup_parallel() extraction for parallel metadata reuse
- Prelaunch first chunk A2A gather to overlap with expert NS compute
- Profiler annotations and clarify distributed_muon as test-only
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- torch-ext/optimizer/muon.py +475 -98
torch-ext/optimizer/muon.py
CHANGED
|
@@ -10,14 +10,15 @@ from torch.profiler import record_function
|
|
| 10 |
|
| 11 |
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
-
from .core import (_muon_state, adjust_lr_for_muon,
|
| 14 |
-
get_default_muon_param_groups, is_expert_param,
|
| 15 |
-
update_p)
|
| 16 |
from .distributed.utils import (_is_shard, construct_shard_mesh,
|
| 17 |
get_slices_of_dtensor)
|
| 18 |
from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
|
| 19 |
-
_zeropower_via_newtonschulz5
|
| 20 |
-
|
|
|
|
|
|
|
| 21 |
from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
|
| 22 |
|
| 23 |
logger = logging.getLogger(__name__)
|
|
@@ -49,6 +50,18 @@ def _expand_expert_params(names, params, expert_keys):
|
|
| 49 |
is_expert = is_expert_param(n, expert_keys)
|
| 50 |
is_dtensor = isinstance(p.data, DTensor)
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
if not is_expert:
|
| 53 |
assert p.data.ndim <= 2, (
|
| 54 |
f"Param {n} has ndim={p.data.ndim} but does not match "
|
|
@@ -169,7 +182,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 169 |
Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
|
| 170 |
use_distributed_muon: Use distributed muon by Liu et al. (2024).
|
| 171 |
For testing purpose only.
|
| 172 |
-
small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
|
| 173 |
expert_keys: List of strings to identify expert-parallel parameters.
|
| 174 |
If any key appears in a parameter's name, its outermost
|
| 175 |
dimension is treated as the expert dimension and expanded
|
|
@@ -194,7 +206,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 194 |
warmup_step=5,
|
| 195 |
chunk_size=-1,
|
| 196 |
use_distributed_muon=False,
|
| 197 |
-
small_param_numel_threshold=65536,
|
| 198 |
expert_keys=None):
|
| 199 |
defaults = dict(
|
| 200 |
lr=lr,
|
|
@@ -229,8 +240,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 229 |
self.warmup_step = warmup_step
|
| 230 |
self.chunk_size = chunk_size
|
| 231 |
self.use_distributed_muon = use_distributed_muon
|
| 232 |
-
self.small_param_numel_threshold = small_param_numel_threshold
|
| 233 |
self.expert_keys = expert_keys
|
|
|
|
|
|
|
| 234 |
|
| 235 |
def _calc_flops(self, G, steps):
|
| 236 |
assert len(G.shape) == 2
|
|
@@ -334,8 +346,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 334 |
if g is None:
|
| 335 |
continue
|
| 336 |
|
| 337 |
-
u =
|
| 338 |
-
|
| 339 |
|
| 340 |
adjusted_lr = adjust_lr_for_muon(lr, p.shape)
|
| 341 |
update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
@@ -356,52 +368,269 @@ class Muon(torch.optim.Optimizer):
|
|
| 356 |
weight_decay: float,
|
| 357 |
qk_logits: list[torch.Tensor | DTensor] | None,
|
| 358 |
):
|
| 359 |
-
"""
|
| 360 |
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
#
|
| 368 |
-
|
| 369 |
-
g_full = g.full_tensor()
|
| 370 |
-
p_full = p.data.full_tensor()
|
| 371 |
-
else:
|
| 372 |
-
g_full = g
|
| 373 |
-
p_full = p
|
| 374 |
-
|
| 375 |
-
u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
|
| 376 |
-
steps=group["ns_steps"])
|
| 377 |
-
|
| 378 |
-
adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
|
| 379 |
-
update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
|
| 380 |
|
| 381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
|
| 383 |
-
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
-
|
| 387 |
-
|
|
|
|
|
|
|
| 388 |
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
p_replicate = DTensor.from_local(
|
| 392 |
-
p_full,
|
| 393 |
-
device_mesh=p.device_mesh,
|
| 394 |
-
placements=[Replicate() for _ in range(ndims)],
|
| 395 |
-
)
|
| 396 |
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
)
|
| 401 |
|
| 402 |
-
|
| 403 |
|
| 404 |
-
def parallel(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
"""
|
| 406 |
Perform a parallel optimization step using Muon.
|
| 407 |
|
|
@@ -410,31 +639,23 @@ class Muon(torch.optim.Optimizer):
|
|
| 410 |
interleaves multiple chunks so that communication and computation
|
| 411 |
overlap across chunks (the same overlap previously achieved by the
|
| 412 |
warmup + main-loop index scheduling).
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
"""
|
| 414 |
|
| 415 |
# Momentum is already applied by _step_muon before this method.
|
| 416 |
|
| 417 |
-
param_to_state,
|
| 418 |
-
names, params, group, qk_logits)
|
| 419 |
-
|
| 420 |
-
# Compute local rank for this group's shard process group.
|
| 421 |
-
shard_pg = param_to_state[id(ordered_params[0])].process_group
|
| 422 |
-
rank = dist.get_rank(group=shard_pg)
|
| 423 |
-
|
| 424 |
-
if self.chunk_size == -1:
|
| 425 |
-
shard_ranks = dist.get_world_size(param_to_state[id(
|
| 426 |
-
ordered_params[0])].process_group)
|
| 427 |
-
chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
|
| 428 |
-
elif self.chunk_size > 0:
|
| 429 |
-
chunk_size = self.chunk_size
|
| 430 |
-
else:
|
| 431 |
-
raise ValueError("chunk_size must be -1 or a positive integer.")
|
| 432 |
|
| 433 |
def pipelines():
|
|
|
|
| 434 |
for start in range(0, len(ordered_params), chunk_size):
|
| 435 |
chunk = ordered_params[start:start + chunk_size]
|
| 436 |
if chunk:
|
| 437 |
-
|
| 438 |
params=chunk,
|
| 439 |
param_to_state=param_to_state,
|
| 440 |
rank=rank,
|
|
@@ -443,9 +664,11 @@ class Muon(torch.optim.Optimizer):
|
|
| 443 |
weight_decay=weight_decay,
|
| 444 |
none_grad=group["none_grad"],
|
| 445 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
|
| 447 |
-
with record_function("muon::barrier"):
|
| 448 |
-
dist.barrier()
|
| 449 |
with record_function("muon::pipeline"):
|
| 450 |
run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
|
| 451 |
|
|
@@ -457,16 +680,152 @@ class Muon(torch.optim.Optimizer):
|
|
| 457 |
names = group["names"]
|
| 458 |
|
| 459 |
# Apply momentum to all params before routing/expansion.
|
|
|
|
| 460 |
with record_function("muon::momentum"):
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
continue
|
| 465 |
-
|
| 466 |
-
p.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 467 |
|
| 468 |
# Expand expert params by splitting on dim 0.
|
| 469 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
param_dtensors = []
|
| 472 |
name_dtensors = []
|
|
@@ -474,10 +833,10 @@ class Muon(torch.optim.Optimizer):
|
|
| 474 |
param_tensors = []
|
| 475 |
name_tensors = []
|
| 476 |
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
if self.use_distributed_muon:
|
|
|
|
| 481 |
self.distributed_muon(names=names,
|
| 482 |
params=params,
|
| 483 |
group=group,
|
|
@@ -486,8 +845,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 486 |
qk_logits=qk_logits)
|
| 487 |
return
|
| 488 |
|
| 489 |
-
# For simplicity, we use distributed Muon for small parameters
|
| 490 |
-
# whose number of elements is below a threshold.
|
| 491 |
for n, p in zip(names, params):
|
| 492 |
if p is None or p.grad is None:
|
| 493 |
continue
|
|
@@ -495,23 +852,28 @@ class Muon(torch.optim.Optimizer):
|
|
| 495 |
if all(
|
| 496 |
isinstance(placement, Replicate)
|
| 497 |
for placement in p.placements):
|
|
|
|
|
|
|
|
|
|
| 498 |
param_tensors.append(p)
|
| 499 |
name_tensors.append(n)
|
| 500 |
-
elif p.data.numel() <= self.small_param_numel_threshold:
|
| 501 |
-
param_dtensors_small.append(p)
|
| 502 |
-
name_dtensors_small.append(n)
|
| 503 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
param_dtensors.append(p)
|
| 505 |
name_dtensors.append(n)
|
| 506 |
elif isinstance(p.data, torch.Tensor):
|
|
|
|
|
|
|
| 507 |
param_tensors.append(p)
|
| 508 |
name_tensors.append(n)
|
| 509 |
else:
|
| 510 |
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 511 |
|
| 512 |
-
logger.debug(
|
| 513 |
-
|
| 514 |
-
f"{len(param_dtensors_small)} Small DTensors")
|
| 515 |
|
| 516 |
def group_dtensors(dtensors, names):
|
| 517 |
# To support different placements, we group parameters by placements
|
|
@@ -527,21 +889,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 527 |
p.device_mesh])][1].append(p)
|
| 528 |
return placement_to_params
|
| 529 |
|
| 530 |
-
if len(param_dtensors_small) > 0:
|
| 531 |
-
if not dist.is_initialized():
|
| 532 |
-
raise RuntimeError(
|
| 533 |
-
"Parallel Muon requires torch.distributed to be initialized."
|
| 534 |
-
)
|
| 535 |
-
|
| 536 |
-
self.distributed_muon(
|
| 537 |
-
params=param_dtensors_small,
|
| 538 |
-
names=name_dtensors_small,
|
| 539 |
-
group=group,
|
| 540 |
-
lr=lr,
|
| 541 |
-
weight_decay=weight_decay,
|
| 542 |
-
qk_logits=qk_logits,
|
| 543 |
-
)
|
| 544 |
-
|
| 545 |
if len(param_dtensors) > 0:
|
| 546 |
if not dist.is_initialized():
|
| 547 |
raise RuntimeError(
|
|
@@ -549,7 +896,26 @@ class Muon(torch.optim.Optimizer):
|
|
| 549 |
)
|
| 550 |
|
| 551 |
dtensor_group = group_dtensors(param_dtensors, name_dtensors)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
for _, (names, params) in dtensor_group.items():
|
|
|
|
|
|
|
| 553 |
self.parallel(
|
| 554 |
names,
|
| 555 |
params,
|
|
@@ -557,7 +923,10 @@ class Muon(torch.optim.Optimizer):
|
|
| 557 |
lr=lr,
|
| 558 |
weight_decay=weight_decay,
|
| 559 |
qk_logits=qk_logits,
|
|
|
|
| 560 |
)
|
|
|
|
|
|
|
| 561 |
|
| 562 |
if len(param_tensors) > 0:
|
| 563 |
self.base(
|
|
@@ -586,10 +955,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 586 |
with torch.enable_grad():
|
| 587 |
loss = closure()
|
| 588 |
|
| 589 |
-
|
|
|
|
|
|
|
|
|
|
| 590 |
if group["use_muon"]:
|
|
|
|
|
|
|
| 591 |
self._step_muon(group, qk_logits=qk_logits)
|
| 592 |
else:
|
|
|
|
|
|
|
|
|
|
| 593 |
step_adamw(self.state, group)
|
| 594 |
|
| 595 |
return loss
|
|
|
|
| 10 |
|
| 11 |
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
+
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
+
get_default_muon_param_groups, is_expert_param, update_p)
|
|
|
|
| 15 |
from .distributed.utils import (_is_shard, construct_shard_mesh,
|
| 16 |
get_slices_of_dtensor)
|
| 17 |
from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
|
| 18 |
+
_zeropower_via_newtonschulz5,
|
| 19 |
+
zeropower_via_newtonschulz5,
|
| 20 |
+
zeropower_via_newtonschulz5_batched)
|
| 21 |
+
from .pipeline import muon_chunk_pipeline, prelaunch_first_gather
|
| 22 |
from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
|
| 23 |
|
| 24 |
logger = logging.getLogger(__name__)
|
|
|
|
| 50 |
is_expert = is_expert_param(n, expert_keys)
|
| 51 |
is_dtensor = isinstance(p.data, DTensor)
|
| 52 |
|
| 53 |
+
if is_expert:
|
| 54 |
+
if is_dtensor:
|
| 55 |
+
logger.debug(
|
| 56 |
+
"[expand_expert] %s: expert DTensor, shape=%s, "
|
| 57 |
+
"placements=%s, mesh=%s, local_shape=%s", n, p.shape,
|
| 58 |
+
p.placements, p.device_mesh.mesh_dim_names,
|
| 59 |
+
p.to_local().shape)
|
| 60 |
+
else:
|
| 61 |
+
logger.debug(
|
| 62 |
+
"[expand_expert] %s: expert plain tensor, shape=%s", n,
|
| 63 |
+
p.data.shape)
|
| 64 |
+
|
| 65 |
if not is_expert:
|
| 66 |
assert p.data.ndim <= 2, (
|
| 67 |
f"Param {n} has ndim={p.data.ndim} but does not match "
|
|
|
|
| 182 |
Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
|
| 183 |
use_distributed_muon: Use distributed muon by Liu et al. (2024).
|
| 184 |
For testing purpose only.
|
|
|
|
| 185 |
expert_keys: List of strings to identify expert-parallel parameters.
|
| 186 |
If any key appears in a parameter's name, its outermost
|
| 187 |
dimension is treated as the expert dimension and expanded
|
|
|
|
| 206 |
warmup_step=5,
|
| 207 |
chunk_size=-1,
|
| 208 |
use_distributed_muon=False,
|
|
|
|
| 209 |
expert_keys=None):
|
| 210 |
defaults = dict(
|
| 211 |
lr=lr,
|
|
|
|
| 240 |
self.warmup_step = warmup_step
|
| 241 |
self.chunk_size = chunk_size
|
| 242 |
self.use_distributed_muon = use_distributed_muon
|
|
|
|
| 243 |
self.expert_keys = expert_keys
|
| 244 |
+
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 245 |
+
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
| 246 |
|
| 247 |
def _calc_flops(self, G, steps):
|
| 248 |
assert len(G.shape) == 2
|
|
|
|
| 346 |
if g is None:
|
| 347 |
continue
|
| 348 |
|
| 349 |
+
u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
|
| 350 |
+
steps=group["ns_steps"])
|
| 351 |
|
| 352 |
adjusted_lr = adjust_lr_for_muon(lr, p.shape)
|
| 353 |
update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
|
| 368 |
weight_decay: float,
|
| 369 |
qk_logits: list[torch.Tensor | DTensor] | None,
|
| 370 |
):
|
| 371 |
+
"""Batched Distributed Muon — for testing/correctness verification only.
|
| 372 |
|
| 373 |
+
Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
|
| 374 |
+
the full grad, then slices back to local shards. This is simpler but
|
| 375 |
+
slower than the parallel pipeline (all2all) path, so it serves as a
|
| 376 |
+
reference implementation for verifying correctness.
|
| 377 |
+
"""
|
| 378 |
+
with record_function("distributed_muon"):
|
| 379 |
+
# Momentum is already applied by _step_muon before this method.
|
| 380 |
+
ns_steps = group["ns_steps"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
+
# Separate plain tensors (no communication) from DTensors.
|
| 383 |
+
plain_names, plain_params = [], []
|
| 384 |
+
dtensor_names, dtensor_params = [], []
|
| 385 |
+
for n, p in zip(names, params):
|
| 386 |
+
if p.grad is None:
|
| 387 |
+
continue
|
| 388 |
+
if isinstance(p.data, DTensor):
|
| 389 |
+
dtensor_names.append(n)
|
| 390 |
+
dtensor_params.append(p)
|
| 391 |
+
else:
|
| 392 |
+
plain_names.append(n)
|
| 393 |
+
plain_params.append(p)
|
| 394 |
+
|
| 395 |
+
# Process plain tensors per-param (no communication).
|
| 396 |
+
for n, p in zip(plain_names, plain_params):
|
| 397 |
+
u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
|
| 398 |
+
steps=ns_steps)
|
| 399 |
+
adjusted_lr = adjust_lr_for_muon(lr, p.shape)
|
| 400 |
+
update_p(p, u, lr, adjusted_lr, weight_decay)
|
| 401 |
+
|
| 402 |
+
qk_clip_state = get_qk_clip_info(self.clip_config, n,
|
| 403 |
+
qk_logits)
|
| 404 |
+
scales_full = compute_scales(
|
| 405 |
+
p, qk_clip_state) if qk_clip_state is not None else None
|
| 406 |
+
if scales_full is not None:
|
| 407 |
+
qk_clip(p, scales_full, qk_clip_state.head_dim)
|
| 408 |
+
|
| 409 |
+
if not dtensor_params:
|
| 410 |
+
return
|
| 411 |
+
|
| 412 |
+
# Group DTensors by (placements, mesh) for batched all-gather.
|
| 413 |
+
placement_groups: dict[tuple,
|
| 414 |
+
tuple[list,
|
| 415 |
+
list]] = defaultdict(lambda: ([], []))
|
| 416 |
+
for n, p in zip(dtensor_names, dtensor_params):
|
| 417 |
+
key = (p.placements, p.device_mesh)
|
| 418 |
+
placement_groups[key][0].append(n)
|
| 419 |
+
placement_groups[key][1].append(p)
|
| 420 |
+
|
| 421 |
+
logger.info(
|
| 422 |
+
"distributed_muon: %d placement groups, %d total dtensors",
|
| 423 |
+
len(placement_groups), len(dtensor_params))
|
| 424 |
+
|
| 425 |
+
for (placements, mesh), (grp_names,
|
| 426 |
+
grp_params) in placement_groups.items():
|
| 427 |
+
shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
|
| 428 |
+
placements, mesh)
|
| 429 |
+
rank = dist.get_rank(shard_pg)
|
| 430 |
+
world_size = dist.get_world_size(shard_pg)
|
| 431 |
+
|
| 432 |
+
logger.info(" group: %d params, placements=%s, world_size=%d",
|
| 433 |
+
len(grp_params), placements, world_size)
|
| 434 |
+
|
| 435 |
+
# Separate params that can be batched (all shard dims evenly
|
| 436 |
+
# divisible) from those needing per-param full_tensor
|
| 437 |
+
# (e.g. MoE gate weights with fewer rows than shard ranks).
|
| 438 |
+
# all_gather_into_tensor requires equal buffer sizes across
|
| 439 |
+
# ranks, so uneven splits must use DTensor full_tensor().
|
| 440 |
+
batch_names, batch_params = [], []
|
| 441 |
+
single_names, single_params = [], []
|
| 442 |
+
for n, p in zip(grp_names, grp_params):
|
| 443 |
+
even = all(p.shape[pl.dim] %
|
| 444 |
+
shard_mesh.mesh.shape[dim_idx] == 0
|
| 445 |
+
for dim_idx, pl in enumerate(shard_placements))
|
| 446 |
+
if even:
|
| 447 |
+
batch_names.append(n)
|
| 448 |
+
batch_params.append(p)
|
| 449 |
+
else:
|
| 450 |
+
single_names.append(n)
|
| 451 |
+
single_params.append(p)
|
| 452 |
+
|
| 453 |
+
# Process uneven-split params per-param via full_tensor().
|
| 454 |
+
for n, p in zip(single_names, single_params):
|
| 455 |
+
with record_function("distributed_muon::newton_schulz"):
|
| 456 |
+
g_full = p.grad.full_tensor().to(COMM_DTYPE)
|
| 457 |
+
u_full = _zeropower_via_newtonschulz5(g_full,
|
| 458 |
+
steps=ns_steps)
|
| 459 |
+
del g_full
|
| 460 |
+
with record_function("distributed_muon::update"):
|
| 461 |
+
adjusted_lr = adjust_lr_for_muon(lr, p.shape)
|
| 462 |
+
p._local_tensor.mul_(1 - lr * weight_decay)
|
| 463 |
+
local_indices = get_slices_of_dtensor(
|
| 464 |
+
p, rank, shard_mesh, shard_placements)
|
| 465 |
+
u_local = u_full[local_indices]
|
| 466 |
+
p._local_tensor.add_(u_local, alpha=-adjusted_lr)
|
| 467 |
+
del u_full
|
| 468 |
+
|
| 469 |
+
qk_clip_state = get_qk_clip_info(
|
| 470 |
+
self.clip_config, n, qk_logits)
|
| 471 |
+
scales_full = compute_scales(
|
| 472 |
+
p, qk_clip_state
|
| 473 |
+
) if qk_clip_state is not None else None
|
| 474 |
+
if scales_full is not None:
|
| 475 |
+
ratio = p.shape[0] // scales_full.shape[0]
|
| 476 |
+
idx0 = local_indices[0]
|
| 477 |
+
if isinstance(idx0, slice):
|
| 478 |
+
start = idx0.start or 0
|
| 479 |
+
idx0 = torch.arange(start,
|
| 480 |
+
idx0.stop,
|
| 481 |
+
device=scales_full.device)
|
| 482 |
+
row_scales = scales_full[idx0 // ratio]
|
| 483 |
+
p._local_tensor.mul_(row_scales.view(-1, 1))
|
| 484 |
+
|
| 485 |
+
if not batch_params:
|
| 486 |
+
continue
|
| 487 |
|
| 488 |
+
logger.info(" batched=%d, single=%d", len(batch_params),
|
| 489 |
+
len(single_params))
|
| 490 |
+
|
| 491 |
+
# Concat all local grad shards into a single flat buffer.
|
| 492 |
+
with record_function("distributed_muon::gather"):
|
| 493 |
+
grad_locals = [
|
| 494 |
+
p.grad.to_local().to(COMM_DTYPE).flatten()
|
| 495 |
+
for p in batch_params
|
| 496 |
+
]
|
| 497 |
+
numels = [g.numel() for g in grad_locals]
|
| 498 |
+
grad_concat = torch.cat(grad_locals)
|
| 499 |
+
del grad_locals
|
| 500 |
+
|
| 501 |
+
# Single all-gather (replaces N separate full_tensor).
|
| 502 |
+
grad_gathered = torch.empty(
|
| 503 |
+
grad_concat.numel() * world_size,
|
| 504 |
+
dtype=COMM_DTYPE,
|
| 505 |
+
device="cuda",
|
| 506 |
+
)
|
| 507 |
+
dist.all_gather_into_tensor(grad_gathered,
|
| 508 |
+
grad_concat,
|
| 509 |
+
group=shard_pg)
|
| 510 |
+
|
| 511 |
+
total_numel = grad_concat.numel()
|
| 512 |
+
del grad_concat
|
| 513 |
+
|
| 514 |
+
# Precompute per-param offsets within the concat buffer.
|
| 515 |
+
offsets = []
|
| 516 |
+
off = 0
|
| 517 |
+
for ne in numels:
|
| 518 |
+
offsets.append(off)
|
| 519 |
+
off += ne
|
| 520 |
+
|
| 521 |
+
# Per-param: reconstruct full grad → NS → local update.
|
| 522 |
+
for i, (n, p) in enumerate(zip(batch_names, batch_params)):
|
| 523 |
+
with record_function("distributed_muon::newton_schulz"):
|
| 524 |
+
g_full = torch.empty(p.shape,
|
| 525 |
+
dtype=COMM_DTYPE,
|
| 526 |
+
device="cuda")
|
| 527 |
+
for r in range(world_size):
|
| 528 |
+
r_start = r * total_numel + offsets[i]
|
| 529 |
+
shard = grad_gathered[r_start:r_start + numels[i]]
|
| 530 |
+
indices = get_slices_of_dtensor(
|
| 531 |
+
p, r, shard_mesh, shard_placements)
|
| 532 |
+
g_full[indices] = shard.reshape(
|
| 533 |
+
g_full[indices].shape)
|
| 534 |
+
|
| 535 |
+
u_full = _zeropower_via_newtonschulz5(g_full,
|
| 536 |
+
steps=ns_steps)
|
| 537 |
+
del g_full
|
| 538 |
+
|
| 539 |
+
with record_function("distributed_muon::update"):
|
| 540 |
+
adjusted_lr = adjust_lr_for_muon(lr, p.shape)
|
| 541 |
+
p._local_tensor.mul_(1 - lr * weight_decay)
|
| 542 |
+
local_indices = get_slices_of_dtensor(
|
| 543 |
+
p, rank, shard_mesh, shard_placements)
|
| 544 |
+
u_local = u_full[local_indices]
|
| 545 |
+
p._local_tensor.add_(u_local, alpha=-adjusted_lr)
|
| 546 |
+
del u_full
|
| 547 |
+
|
| 548 |
+
qk_clip_state = get_qk_clip_info(
|
| 549 |
+
self.clip_config, n, qk_logits)
|
| 550 |
+
scales_full = compute_scales(
|
| 551 |
+
p, qk_clip_state
|
| 552 |
+
) if qk_clip_state is not None else None
|
| 553 |
+
if scales_full is not None:
|
| 554 |
+
ratio = p.shape[0] // scales_full.shape[0]
|
| 555 |
+
idx0 = local_indices[0]
|
| 556 |
+
if isinstance(idx0, slice):
|
| 557 |
+
start = idx0.start or 0
|
| 558 |
+
idx0 = torch.arange(start,
|
| 559 |
+
idx0.stop,
|
| 560 |
+
device=scales_full.device)
|
| 561 |
+
row_scales = scales_full[idx0 // ratio]
|
| 562 |
+
p._local_tensor.mul_(row_scales.view(-1, 1))
|
| 563 |
+
|
| 564 |
+
def _setup_parallel(self, names, params, group, qk_logits):
|
| 565 |
+
"""Compute (or retrieve cached) parallel pipeline metadata.
|
| 566 |
+
|
| 567 |
+
Returns:
|
| 568 |
+
(ordered_params, param_to_state, rank, chunk_size)
|
| 569 |
+
"""
|
| 570 |
+
cache_key = tuple(names)
|
| 571 |
|
| 572 |
+
if cache_key not in self._parallel_cache:
|
| 573 |
+
# First call: compute metadata and populate cache.
|
| 574 |
+
param_to_state, ordered_params = self.init_state_and_assign_params(
|
| 575 |
+
names, params, group, qk_logits)
|
| 576 |
|
| 577 |
+
shard_pg = param_to_state[id(ordered_params[0])].process_group
|
| 578 |
+
rank = dist.get_rank(group=shard_pg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
|
| 580 |
+
if self.chunk_size == -1:
|
| 581 |
+
shard_ranks = dist.get_world_size(shard_pg)
|
| 582 |
+
chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
|
| 583 |
+
elif self.chunk_size > 0:
|
| 584 |
+
chunk_size = self.chunk_size
|
| 585 |
+
else:
|
| 586 |
+
raise ValueError(
|
| 587 |
+
"chunk_size must be -1 or a positive integer.")
|
| 588 |
+
|
| 589 |
+
ordered_names = [
|
| 590 |
+
param_to_state[id(p)].name for p in ordered_params
|
| 591 |
+
]
|
| 592 |
+
name_to_state = {
|
| 593 |
+
param_to_state[id(p)].name: param_to_state[id(p)]
|
| 594 |
+
for p in ordered_params
|
| 595 |
+
}
|
| 596 |
+
self._parallel_cache[cache_key] = {
|
| 597 |
+
'ordered_names': ordered_names,
|
| 598 |
+
'name_to_state': name_to_state,
|
| 599 |
+
'rank': rank,
|
| 600 |
+
'chunk_size': chunk_size,
|
| 601 |
+
}
|
| 602 |
+
else:
|
| 603 |
+
# Cached path: rebuild param_to_state with current id(p) keys.
|
| 604 |
+
cache = self._parallel_cache[cache_key]
|
| 605 |
+
rank = cache['rank']
|
| 606 |
+
chunk_size = cache['chunk_size']
|
| 607 |
+
|
| 608 |
+
name_to_param = dict(zip(names, params))
|
| 609 |
+
ordered_params = [name_to_param[n] for n in cache['ordered_names']]
|
| 610 |
+
|
| 611 |
+
param_to_state = {}
|
| 612 |
+
for p, n in zip(ordered_params, cache['ordered_names']):
|
| 613 |
+
cached_state = cache['name_to_state'][n]
|
| 614 |
+
param_to_state[id(p)] = _muon_state(
|
| 615 |
+
worker_rank=cached_state.worker_rank,
|
| 616 |
+
process_group=cached_state.process_group,
|
| 617 |
+
rank_indices=cached_state.rank_indices,
|
| 618 |
+
rank_numels=cached_state.rank_numels,
|
| 619 |
+
name=n,
|
| 620 |
+
qk_clip_state=get_qk_clip_info(self.clip_config, n,
|
| 621 |
+
qk_logits),
|
| 622 |
)
|
| 623 |
|
| 624 |
+
return ordered_params, param_to_state, rank, chunk_size
|
| 625 |
|
| 626 |
+
def parallel(self,
|
| 627 |
+
names,
|
| 628 |
+
params,
|
| 629 |
+
group,
|
| 630 |
+
lr,
|
| 631 |
+
weight_decay,
|
| 632 |
+
qk_logits,
|
| 633 |
+
prelaunch_gather=None):
|
| 634 |
"""
|
| 635 |
Perform a parallel optimization step using Muon.
|
| 636 |
|
|
|
|
| 639 |
interleaves multiple chunks so that communication and computation
|
| 640 |
overlap across chunks (the same overlap previously achieved by the
|
| 641 |
warmup + main-loop index scheduling).
|
| 642 |
+
|
| 643 |
+
If ``prelaunch_gather`` is provided, it is passed to the first
|
| 644 |
+
chunk's generator to skip re-launching the already in-flight
|
| 645 |
+
A2A gather.
|
| 646 |
"""
|
| 647 |
|
| 648 |
# Momentum is already applied by _step_muon before this method.
|
| 649 |
|
| 650 |
+
ordered_params, param_to_state, rank, chunk_size = (
|
| 651 |
+
self._setup_parallel(names, params, group, qk_logits))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
|
| 653 |
def pipelines():
|
| 654 |
+
first = True
|
| 655 |
for start in range(0, len(ordered_params), chunk_size):
|
| 656 |
chunk = ordered_params[start:start + chunk_size]
|
| 657 |
if chunk:
|
| 658 |
+
kwargs = dict(
|
| 659 |
params=chunk,
|
| 660 |
param_to_state=param_to_state,
|
| 661 |
rank=rank,
|
|
|
|
| 664 |
weight_decay=weight_decay,
|
| 665 |
none_grad=group["none_grad"],
|
| 666 |
)
|
| 667 |
+
if first and prelaunch_gather is not None:
|
| 668 |
+
kwargs['prelaunch_gather'] = prelaunch_gather
|
| 669 |
+
first = False
|
| 670 |
+
yield muon_chunk_pipeline(**kwargs)
|
| 671 |
|
|
|
|
|
|
|
| 672 |
with record_function("muon::pipeline"):
|
| 673 |
run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
|
| 674 |
|
|
|
|
| 680 |
names = group["names"]
|
| 681 |
|
| 682 |
# Apply momentum to all params before routing/expansion.
|
| 683 |
+
# Batched using _foreach_* ops (compiled, fullgraph=True).
|
| 684 |
with record_function("muon::momentum"):
|
| 685 |
+
active_params = [p for p in params if p.grad is not None]
|
| 686 |
+
if active_params:
|
| 687 |
+
# Ensure momentum buffers exist (avoid zeros_like when already present).
|
| 688 |
+
for p in active_params:
|
| 689 |
+
if "momentum_buffer" not in self.state[p]:
|
| 690 |
+
self.state[p]["momentum_buffer"] = torch.zeros_like(
|
| 691 |
+
p.grad)
|
| 692 |
+
|
| 693 |
+
# Extract local tensors for compiled batch function.
|
| 694 |
+
local_grads = [
|
| 695 |
+
p.grad._local_tensor
|
| 696 |
+
if isinstance(p.grad, DTensor) else p.grad
|
| 697 |
+
for p in active_params
|
| 698 |
+
]
|
| 699 |
+
local_bufs = [
|
| 700 |
+
self.state[p]["momentum_buffer"]._local_tensor
|
| 701 |
+
if isinstance(self.state[p]["momentum_buffer"], DTensor)
|
| 702 |
+
else self.state[p]["momentum_buffer"]
|
| 703 |
+
for p in active_params
|
| 704 |
+
]
|
| 705 |
+
|
| 706 |
+
# Wrap momentum as tensor for torch.compile.
|
| 707 |
+
batch_pre_ortho(local_grads, local_bufs,
|
| 708 |
+
torch.tensor(momentum), group["nesterov"])
|
| 709 |
+
|
| 710 |
+
# For non-nesterov, the result is the momentum buffer.
|
| 711 |
+
if not group["nesterov"]:
|
| 712 |
+
for p in active_params:
|
| 713 |
+
p.grad = self.state[p]["momentum_buffer"]
|
| 714 |
+
|
| 715 |
+
# Identify batched experts for deferred NS.
|
| 716 |
+
# Detection is cheap (condition checks only); actual NS compute is
|
| 717 |
+
# deferred so it can overlap with the first chunk's A2A gather.
|
| 718 |
+
deferred_expert_work = []
|
| 719 |
+
if self.expert_keys:
|
| 720 |
+
batched_expert_indices = []
|
| 721 |
+
for i, (n, p) in enumerate(zip(names, params)):
|
| 722 |
+
if not (is_expert_param(n, self.expert_keys)
|
| 723 |
+
and p.grad is not None):
|
| 724 |
continue
|
| 725 |
+
# Eligible: plain tensor, or DTensor with no non-dim-0 shards.
|
| 726 |
+
if isinstance(p.data, DTensor):
|
| 727 |
+
has_tp = any(
|
| 728 |
+
_is_shard(pl) and pl.dim != 0 for pl in p.placements)
|
| 729 |
+
if has_tp:
|
| 730 |
+
continue
|
| 731 |
+
batched_expert_indices.append(i)
|
| 732 |
+
|
| 733 |
+
if batched_expert_indices:
|
| 734 |
+
# Save refs for deferred NS; free grads from param list.
|
| 735 |
+
for i in batched_expert_indices:
|
| 736 |
+
p = params[i]
|
| 737 |
+
g = p.grad
|
| 738 |
+
local_g = (g._local_tensor
|
| 739 |
+
if isinstance(g, DTensor) else g)
|
| 740 |
+
local_data = (p.data._local_tensor if isinstance(
|
| 741 |
+
p.data, DTensor) else p.data)
|
| 742 |
+
deferred_expert_work.append((local_data, local_g))
|
| 743 |
+
p.grad = None
|
| 744 |
+
|
| 745 |
+
# Remove batched experts from lists before expansion.
|
| 746 |
+
keep = sorted(
|
| 747 |
+
set(range(len(params))) - set(batched_expert_indices))
|
| 748 |
+
names = [names[i] for i in keep]
|
| 749 |
+
params = [params[i] for i in keep]
|
| 750 |
+
|
| 751 |
+
def _run_deferred_expert_ns():
|
| 752 |
+
"""Execute deferred batched expert NS."""
|
| 753 |
+
if not deferred_expert_work:
|
| 754 |
+
return
|
| 755 |
+
with record_function("muon::batched_expert_ns"):
|
| 756 |
+
ns_steps = group["ns_steps"]
|
| 757 |
+
for local_data, local_g in deferred_expert_work:
|
| 758 |
+
u = zeropower_via_newtonschulz5_batched(
|
| 759 |
+
local_g.to(COMM_DTYPE), steps=ns_steps)
|
| 760 |
+
adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
|
| 761 |
+
local_data.mul_(1 - lr * weight_decay)
|
| 762 |
+
local_data.add_(u, alpha=-adjusted_lr)
|
| 763 |
|
| 764 |
# Expand expert params by splitting on dim 0.
|
| 765 |
+
logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
|
| 766 |
+
len(params), self.expert_keys)
|
| 767 |
+
if self.expert_keys:
|
| 768 |
+
cache_key = tuple(id(p) for p in params)
|
| 769 |
+
cache = self._expert_expand_cache.get(cache_key)
|
| 770 |
+
|
| 771 |
+
if cache is None:
|
| 772 |
+
# Cold path: full expansion + build cache metadata.
|
| 773 |
+
exp_names, exp_params = _expand_expert_params(
|
| 774 |
+
names, params, self.expert_keys)
|
| 775 |
+
|
| 776 |
+
# Build per-expert-group info for hot-path grad updates.
|
| 777 |
+
grad_info = []
|
| 778 |
+
exp_idx = 0
|
| 779 |
+
for orig_idx, (n, p) in enumerate(zip(names, params)):
|
| 780 |
+
if not is_expert_param(n, self.expert_keys):
|
| 781 |
+
exp_idx += 1
|
| 782 |
+
continue
|
| 783 |
+
|
| 784 |
+
is_dt = isinstance(p.data, DTensor)
|
| 785 |
+
num_experts = (p.to_local() if is_dt else p.data).shape[0]
|
| 786 |
+
|
| 787 |
+
# Detect TP mesh from the first expanded expert param.
|
| 788 |
+
tp_mesh = None
|
| 789 |
+
tp_pls = None
|
| 790 |
+
sample = exp_params[exp_idx]
|
| 791 |
+
if isinstance(sample.data, DTensor):
|
| 792 |
+
tp_mesh = sample.data.device_mesh
|
| 793 |
+
tp_pls = list(sample.data.placements)
|
| 794 |
+
|
| 795 |
+
grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
|
| 796 |
+
tp_mesh, tp_pls))
|
| 797 |
+
exp_idx += num_experts
|
| 798 |
+
|
| 799 |
+
self._expert_expand_cache[cache_key] = {
|
| 800 |
+
'names': exp_names,
|
| 801 |
+
'params': exp_params,
|
| 802 |
+
'grad_info': grad_info,
|
| 803 |
+
}
|
| 804 |
+
names, params = exp_names, exp_params
|
| 805 |
+
else:
|
| 806 |
+
# Hot path: reuse cached params, only update expert grads.
|
| 807 |
+
for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
|
| 808 |
+
tp_pls) in cache['grad_info']:
|
| 809 |
+
p = params[orig_idx]
|
| 810 |
+
g = p.grad
|
| 811 |
+
local_grad = (g.to_local()
|
| 812 |
+
if is_dt and isinstance(g, DTensor) else g)
|
| 813 |
+
for i in range(num_experts):
|
| 814 |
+
expert_p = cache['params'][exp_start + i]
|
| 815 |
+
sg = local_grad[i]
|
| 816 |
+
if tp_mesh is not None:
|
| 817 |
+
expert_p.grad = DTensor.from_local(
|
| 818 |
+
sg, device_mesh=tp_mesh, placements=tp_pls)
|
| 819 |
+
else:
|
| 820 |
+
expert_p.grad = sg
|
| 821 |
+
p.grad = None
|
| 822 |
+
|
| 823 |
+
names = cache['names']
|
| 824 |
+
params = cache['params']
|
| 825 |
+
else:
|
| 826 |
+
names, params = _expand_expert_params(names, params,
|
| 827 |
+
self.expert_keys)
|
| 828 |
+
logger.debug("[_step_muon] after expand: %d params", len(params))
|
| 829 |
|
| 830 |
param_dtensors = []
|
| 831 |
name_dtensors = []
|
|
|
|
| 833 |
param_tensors = []
|
| 834 |
name_tensors = []
|
| 835 |
|
| 836 |
+
# distributed_muon is a reference implementation for testing only.
|
| 837 |
+
# The parallel pipeline (all2all) path below is the production path.
|
|
|
|
| 838 |
if self.use_distributed_muon:
|
| 839 |
+
_run_deferred_expert_ns()
|
| 840 |
self.distributed_muon(names=names,
|
| 841 |
params=params,
|
| 842 |
group=group,
|
|
|
|
| 845 |
qk_logits=qk_logits)
|
| 846 |
return
|
| 847 |
|
|
|
|
|
|
|
| 848 |
for n, p in zip(names, params):
|
| 849 |
if p is None or p.grad is None:
|
| 850 |
continue
|
|
|
|
| 852 |
if all(
|
| 853 |
isinstance(placement, Replicate)
|
| 854 |
for placement in p.placements):
|
| 855 |
+
logger.debug(
|
| 856 |
+
"[route] %s → base (DTensor all-Replicate), "
|
| 857 |
+
"shape=%s, placements=%s", n, p.shape, p.placements)
|
| 858 |
param_tensors.append(p)
|
| 859 |
name_tensors.append(n)
|
|
|
|
|
|
|
|
|
|
| 860 |
else:
|
| 861 |
+
logger.debug(
|
| 862 |
+
"[route] %s → parallel (DTensor), shape=%s, "
|
| 863 |
+
"placements=%s, mesh=%s", n, p.shape, p.placements,
|
| 864 |
+
p.device_mesh.mesh_dim_names)
|
| 865 |
param_dtensors.append(p)
|
| 866 |
name_dtensors.append(n)
|
| 867 |
elif isinstance(p.data, torch.Tensor):
|
| 868 |
+
logger.debug("[route] %s → base (plain tensor), shape=%s", n,
|
| 869 |
+
p.data.shape)
|
| 870 |
param_tensors.append(p)
|
| 871 |
name_tensors.append(n)
|
| 872 |
else:
|
| 873 |
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 874 |
|
| 875 |
+
logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, "
|
| 876 |
+
f"{len(param_tensors)} Tensors → base")
|
|
|
|
| 877 |
|
| 878 |
def group_dtensors(dtensors, names):
|
| 879 |
# To support different placements, we group parameters by placements
|
|
|
|
| 889 |
p.device_mesh])][1].append(p)
|
| 890 |
return placement_to_params
|
| 891 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 892 |
if len(param_dtensors) > 0:
|
| 893 |
if not dist.is_initialized():
|
| 894 |
raise RuntimeError(
|
|
|
|
| 896 |
)
|
| 897 |
|
| 898 |
dtensor_group = group_dtensors(param_dtensors, name_dtensors)
|
| 899 |
+
|
| 900 |
+
# Pre-launch the first chunk's A2A gather so that the NCCL
|
| 901 |
+
# communication overlaps with the (deferred) batched expert NS
|
| 902 |
+
# compute on the default CUDA stream.
|
| 903 |
+
prelaunch = None
|
| 904 |
+
if deferred_expert_work:
|
| 905 |
+
first_names, first_params = next(iter(dtensor_group.values()))
|
| 906 |
+
ordered, pts, rnk, csz = self._setup_parallel(
|
| 907 |
+
first_names, first_params, group, qk_logits)
|
| 908 |
+
first_chunk = ordered[:csz]
|
| 909 |
+
if first_chunk:
|
| 910 |
+
prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
|
| 911 |
+
group["none_grad"])
|
| 912 |
+
|
| 913 |
+
_run_deferred_expert_ns()
|
| 914 |
+
|
| 915 |
+
first_group = True
|
| 916 |
for _, (names, params) in dtensor_group.items():
|
| 917 |
+
pg = prelaunch if first_group else None
|
| 918 |
+
first_group = False
|
| 919 |
self.parallel(
|
| 920 |
names,
|
| 921 |
params,
|
|
|
|
| 923 |
lr=lr,
|
| 924 |
weight_decay=weight_decay,
|
| 925 |
qk_logits=qk_logits,
|
| 926 |
+
prelaunch_gather=pg,
|
| 927 |
)
|
| 928 |
+
else:
|
| 929 |
+
_run_deferred_expert_ns()
|
| 930 |
|
| 931 |
if len(param_tensors) > 0:
|
| 932 |
self.base(
|
|
|
|
| 955 |
with torch.enable_grad():
|
| 956 |
loss = closure()
|
| 957 |
|
| 958 |
+
logger.debug("[Muon.step] expert_keys=%s, %d param groups",
|
| 959 |
+
self.expert_keys, len(self.param_groups))
|
| 960 |
+
|
| 961 |
+
for i, group in enumerate(self.param_groups):
|
| 962 |
if group["use_muon"]:
|
| 963 |
+
logger.debug("[Muon.step] group %d: use_muon=True, %d params",
|
| 964 |
+
i, len(group["params"]))
|
| 965 |
self._step_muon(group, qk_logits=qk_logits)
|
| 966 |
else:
|
| 967 |
+
logger.debug(
|
| 968 |
+
"[Muon.step] group %d: use_muon=False (AdamW), %d params",
|
| 969 |
+
i, len(group["params"]))
|
| 970 |
step_adamw(self.state, group)
|
| 971 |
|
| 972 |
return loss
|