Kernels
wyldecat Claude Opus 4.6 committed on
Commit
0f37d63
·
1 Parent(s): 2816b64

Muon optimizer: expert batching, parallel caching, A2A overlap [skip-build]

Browse files

- Batched expert NS path for plain-tensor MoE params (skip expansion)
- Expert expansion cache to eliminate per-step detach overhead
- _setup_parallel() extraction for parallel metadata reuse
- Prelaunch first chunk A2A gather to overlap with expert NS compute
- Profiler annotations and clarify distributed_muon as test-only

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. torch-ext/optimizer/muon.py +475 -98
torch-ext/optimizer/muon.py CHANGED
@@ -10,14 +10,15 @@ from torch.profiler import record_function
10
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
- from .core import (_muon_state, adjust_lr_for_muon,
14
- get_default_muon_param_groups, is_expert_param, update_g,
15
- update_p)
16
  from .distributed.utils import (_is_shard, construct_shard_mesh,
17
  get_slices_of_dtensor)
18
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
19
- _zeropower_via_newtonschulz5)
20
- from .pipeline import muon_chunk_pipeline
 
 
21
  from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
22
 
23
  logger = logging.getLogger(__name__)
@@ -49,6 +50,18 @@ def _expand_expert_params(names, params, expert_keys):
49
  is_expert = is_expert_param(n, expert_keys)
50
  is_dtensor = isinstance(p.data, DTensor)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  if not is_expert:
53
  assert p.data.ndim <= 2, (
54
  f"Param {n} has ndim={p.data.ndim} but does not match "
@@ -169,7 +182,6 @@ class Muon(torch.optim.Optimizer):
169
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
170
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
171
  For testing purpose only.
172
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
173
  expert_keys: List of strings to identify expert-parallel parameters.
174
  If any key appears in a parameter's name, its outermost
175
  dimension is treated as the expert dimension and expanded
@@ -194,7 +206,6 @@ class Muon(torch.optim.Optimizer):
194
  warmup_step=5,
195
  chunk_size=-1,
196
  use_distributed_muon=False,
197
- small_param_numel_threshold=65536,
198
  expert_keys=None):
199
  defaults = dict(
200
  lr=lr,
@@ -229,8 +240,9 @@ class Muon(torch.optim.Optimizer):
229
  self.warmup_step = warmup_step
230
  self.chunk_size = chunk_size
231
  self.use_distributed_muon = use_distributed_muon
232
- self.small_param_numel_threshold = small_param_numel_threshold
233
  self.expert_keys = expert_keys
 
 
234
 
235
  def _calc_flops(self, G, steps):
236
  assert len(G.shape) == 2
@@ -334,8 +346,8 @@ class Muon(torch.optim.Optimizer):
334
  if g is None:
335
  continue
336
 
337
- u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
338
- steps=group["ns_steps"])
339
 
340
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
341
  update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -356,52 +368,269 @@ class Muon(torch.optim.Optimizer):
356
  weight_decay: float,
357
  qk_logits: list[torch.Tensor | DTensor] | None,
358
  ):
359
- """ Implementation of Distributed Muon by Liu et al. """
360
 
361
- # Momentum is already applied by _step_muon before this method.
362
- for n, p in zip(names, params):
363
- g = p.grad
364
- if g is None:
365
- continue
366
-
367
- # Gather G
368
- if isinstance(p.data, DTensor):
369
- g_full = g.full_tensor()
370
- p_full = p.data.full_tensor()
371
- else:
372
- g_full = g
373
- p_full = p
374
-
375
- u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
376
- steps=group["ns_steps"])
377
-
378
- adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
379
- update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
380
 
381
- qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
- scales_full = compute_scales(
384
- p_full, qk_clip_state) if qk_clip_state is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
- if scales_full is not None:
387
- qk_clip(p_full, scales_full, qk_clip_state.head_dim)
 
 
388
 
389
- if isinstance(p.data, DTensor):
390
- ndims = len(p.device_mesh.mesh.shape)
391
- p_replicate = DTensor.from_local(
392
- p_full,
393
- device_mesh=p.device_mesh,
394
- placements=[Replicate() for _ in range(ndims)],
395
- )
396
 
397
- p_sharded = p_replicate.redistribute(
398
- device_mesh=p.device_mesh,
399
- placements=p.placements,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  )
401
 
402
- p.copy_(p_sharded)
403
 
404
- def parallel(self, names, params, group, lr, weight_decay, qk_logits):
 
 
 
 
 
 
 
405
  """
406
  Perform a parallel optimization step using Muon.
407
 
@@ -410,31 +639,23 @@ class Muon(torch.optim.Optimizer):
410
  interleaves multiple chunks so that communication and computation
411
  overlap across chunks (the same overlap previously achieved by the
412
  warmup + main-loop index scheduling).
 
 
 
 
413
  """
414
 
415
  # Momentum is already applied by _step_muon before this method.
416
 
417
- param_to_state, ordered_params = self.init_state_and_assign_params(
418
- names, params, group, qk_logits)
419
-
420
- # Compute local rank for this group's shard process group.
421
- shard_pg = param_to_state[id(ordered_params[0])].process_group
422
- rank = dist.get_rank(group=shard_pg)
423
-
424
- if self.chunk_size == -1:
425
- shard_ranks = dist.get_world_size(param_to_state[id(
426
- ordered_params[0])].process_group)
427
- chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
428
- elif self.chunk_size > 0:
429
- chunk_size = self.chunk_size
430
- else:
431
- raise ValueError("chunk_size must be -1 or a positive integer.")
432
 
433
  def pipelines():
 
434
  for start in range(0, len(ordered_params), chunk_size):
435
  chunk = ordered_params[start:start + chunk_size]
436
  if chunk:
437
- yield muon_chunk_pipeline(
438
  params=chunk,
439
  param_to_state=param_to_state,
440
  rank=rank,
@@ -443,9 +664,11 @@ class Muon(torch.optim.Optimizer):
443
  weight_decay=weight_decay,
444
  none_grad=group["none_grad"],
445
  )
 
 
 
 
446
 
447
- with record_function("muon::barrier"):
448
- dist.barrier()
449
  with record_function("muon::pipeline"):
450
  run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
451
 
@@ -457,16 +680,152 @@ class Muon(torch.optim.Optimizer):
457
  names = group["names"]
458
 
459
  # Apply momentum to all params before routing/expansion.
 
460
  with record_function("muon::momentum"):
461
- for n, p in zip(names, params):
462
- g = p.grad
463
- if g is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  continue
465
- g = update_g(self.state, p, g, group, momentum)
466
- p.grad = g
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
 
468
  # Expand expert params by splitting on dim 0.
469
- names, params = _expand_expert_params(names, params, self.expert_keys)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
  param_dtensors = []
472
  name_dtensors = []
@@ -474,10 +833,10 @@ class Muon(torch.optim.Optimizer):
474
  param_tensors = []
475
  name_tensors = []
476
 
477
- param_dtensors_small = []
478
- name_dtensors_small = []
479
-
480
  if self.use_distributed_muon:
 
481
  self.distributed_muon(names=names,
482
  params=params,
483
  group=group,
@@ -486,8 +845,6 @@ class Muon(torch.optim.Optimizer):
486
  qk_logits=qk_logits)
487
  return
488
 
489
- # For simplicity, we use distributed Muon for small parameters
490
- # whose number of elements is below a threshold.
491
  for n, p in zip(names, params):
492
  if p is None or p.grad is None:
493
  continue
@@ -495,23 +852,28 @@ class Muon(torch.optim.Optimizer):
495
  if all(
496
  isinstance(placement, Replicate)
497
  for placement in p.placements):
 
 
 
498
  param_tensors.append(p)
499
  name_tensors.append(n)
500
- elif p.data.numel() <= self.small_param_numel_threshold:
501
- param_dtensors_small.append(p)
502
- name_dtensors_small.append(n)
503
  else:
 
 
 
 
504
  param_dtensors.append(p)
505
  name_dtensors.append(n)
506
  elif isinstance(p.data, torch.Tensor):
 
 
507
  param_tensors.append(p)
508
  name_tensors.append(n)
509
  else:
510
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
511
 
512
- logger.debug(
513
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
514
- f"{len(param_dtensors_small)} Small DTensors")
515
 
516
  def group_dtensors(dtensors, names):
517
  # To support different placements, we group parameters by placements
@@ -527,21 +889,6 @@ class Muon(torch.optim.Optimizer):
527
  p.device_mesh])][1].append(p)
528
  return placement_to_params
529
 
530
- if len(param_dtensors_small) > 0:
531
- if not dist.is_initialized():
532
- raise RuntimeError(
533
- "Parallel Muon requires torch.distributed to be initialized."
534
- )
535
-
536
- self.distributed_muon(
537
- params=param_dtensors_small,
538
- names=name_dtensors_small,
539
- group=group,
540
- lr=lr,
541
- weight_decay=weight_decay,
542
- qk_logits=qk_logits,
543
- )
544
-
545
  if len(param_dtensors) > 0:
546
  if not dist.is_initialized():
547
  raise RuntimeError(
@@ -549,7 +896,26 @@ class Muon(torch.optim.Optimizer):
549
  )
550
 
551
  dtensor_group = group_dtensors(param_dtensors, name_dtensors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
  for _, (names, params) in dtensor_group.items():
 
 
553
  self.parallel(
554
  names,
555
  params,
@@ -557,7 +923,10 @@ class Muon(torch.optim.Optimizer):
557
  lr=lr,
558
  weight_decay=weight_decay,
559
  qk_logits=qk_logits,
 
560
  )
 
 
561
 
562
  if len(param_tensors) > 0:
563
  self.base(
@@ -586,10 +955,18 @@ class Muon(torch.optim.Optimizer):
586
  with torch.enable_grad():
587
  loss = closure()
588
 
589
- for group in self.param_groups:
 
 
 
590
  if group["use_muon"]:
 
 
591
  self._step_muon(group, qk_logits=qk_logits)
592
  else:
 
 
 
593
  step_adamw(self.state, group)
594
 
595
  return loss
 
10
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
+ from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
+ get_default_muon_param_groups, is_expert_param, update_p)
 
15
  from .distributed.utils import (_is_shard, construct_shard_mesh,
16
  get_slices_of_dtensor)
17
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
18
+ _zeropower_via_newtonschulz5,
19
+ zeropower_via_newtonschulz5,
20
+ zeropower_via_newtonschulz5_batched)
21
+ from .pipeline import muon_chunk_pipeline, prelaunch_first_gather
22
  from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
23
 
24
  logger = logging.getLogger(__name__)
 
50
  is_expert = is_expert_param(n, expert_keys)
51
  is_dtensor = isinstance(p.data, DTensor)
52
 
53
+ if is_expert:
54
+ if is_dtensor:
55
+ logger.debug(
56
+ "[expand_expert] %s: expert DTensor, shape=%s, "
57
+ "placements=%s, mesh=%s, local_shape=%s", n, p.shape,
58
+ p.placements, p.device_mesh.mesh_dim_names,
59
+ p.to_local().shape)
60
+ else:
61
+ logger.debug(
62
+ "[expand_expert] %s: expert plain tensor, shape=%s", n,
63
+ p.data.shape)
64
+
65
  if not is_expert:
66
  assert p.data.ndim <= 2, (
67
  f"Param {n} has ndim={p.data.ndim} but does not match "
 
182
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
183
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
184
  For testing purpose only.
 
185
  expert_keys: List of strings to identify expert-parallel parameters.
186
  If any key appears in a parameter's name, its outermost
187
  dimension is treated as the expert dimension and expanded
 
206
  warmup_step=5,
207
  chunk_size=-1,
208
  use_distributed_muon=False,
 
209
  expert_keys=None):
210
  defaults = dict(
211
  lr=lr,
 
240
  self.warmup_step = warmup_step
241
  self.chunk_size = chunk_size
242
  self.use_distributed_muon = use_distributed_muon
 
243
  self.expert_keys = expert_keys
244
+ self._parallel_cache: dict[tuple[str, ...], dict] = {}
245
+ self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
246
 
247
  def _calc_flops(self, G, steps):
248
  assert len(G.shape) == 2
 
346
  if g is None:
347
  continue
348
 
349
+ u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
350
+ steps=group["ns_steps"])
351
 
352
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
353
  update_p(p, u, lr, adjusted_lr, weight_decay)
 
368
  weight_decay: float,
369
  qk_logits: list[torch.Tensor | DTensor] | None,
370
  ):
371
+ """Batched Distributed Muon for testing/correctness verification only.
372
 
373
+ Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
374
+ the full grad, then slices back to local shards. This is simpler but
375
+ slower than the parallel pipeline (all2all) path, so it serves as a
376
+ reference implementation for verifying correctness.
377
+ """
378
+ with record_function("distributed_muon"):
379
+ # Momentum is already applied by _step_muon before this method.
380
+ ns_steps = group["ns_steps"]
 
 
 
 
 
 
 
 
 
 
 
381
 
382
+ # Separate plain tensors (no communication) from DTensors.
383
+ plain_names, plain_params = [], []
384
+ dtensor_names, dtensor_params = [], []
385
+ for n, p in zip(names, params):
386
+ if p.grad is None:
387
+ continue
388
+ if isinstance(p.data, DTensor):
389
+ dtensor_names.append(n)
390
+ dtensor_params.append(p)
391
+ else:
392
+ plain_names.append(n)
393
+ plain_params.append(p)
394
+
395
+ # Process plain tensors per-param (no communication).
396
+ for n, p in zip(plain_names, plain_params):
397
+ u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
398
+ steps=ns_steps)
399
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
400
+ update_p(p, u, lr, adjusted_lr, weight_decay)
401
+
402
+ qk_clip_state = get_qk_clip_info(self.clip_config, n,
403
+ qk_logits)
404
+ scales_full = compute_scales(
405
+ p, qk_clip_state) if qk_clip_state is not None else None
406
+ if scales_full is not None:
407
+ qk_clip(p, scales_full, qk_clip_state.head_dim)
408
+
409
+ if not dtensor_params:
410
+ return
411
+
412
+ # Group DTensors by (placements, mesh) for batched all-gather.
413
+ placement_groups: dict[tuple,
414
+ tuple[list,
415
+ list]] = defaultdict(lambda: ([], []))
416
+ for n, p in zip(dtensor_names, dtensor_params):
417
+ key = (p.placements, p.device_mesh)
418
+ placement_groups[key][0].append(n)
419
+ placement_groups[key][1].append(p)
420
+
421
+ logger.info(
422
+ "distributed_muon: %d placement groups, %d total dtensors",
423
+ len(placement_groups), len(dtensor_params))
424
+
425
+ for (placements, mesh), (grp_names,
426
+ grp_params) in placement_groups.items():
427
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
428
+ placements, mesh)
429
+ rank = dist.get_rank(shard_pg)
430
+ world_size = dist.get_world_size(shard_pg)
431
+
432
+ logger.info(" group: %d params, placements=%s, world_size=%d",
433
+ len(grp_params), placements, world_size)
434
+
435
+ # Separate params that can be batched (all shard dims evenly
436
+ # divisible) from those needing per-param full_tensor
437
+ # (e.g. MoE gate weights with fewer rows than shard ranks).
438
+ # all_gather_into_tensor requires equal buffer sizes across
439
+ # ranks, so uneven splits must use DTensor full_tensor().
440
+ batch_names, batch_params = [], []
441
+ single_names, single_params = [], []
442
+ for n, p in zip(grp_names, grp_params):
443
+ even = all(p.shape[pl.dim] %
444
+ shard_mesh.mesh.shape[dim_idx] == 0
445
+ for dim_idx, pl in enumerate(shard_placements))
446
+ if even:
447
+ batch_names.append(n)
448
+ batch_params.append(p)
449
+ else:
450
+ single_names.append(n)
451
+ single_params.append(p)
452
+
453
+ # Process uneven-split params per-param via full_tensor().
454
+ for n, p in zip(single_names, single_params):
455
+ with record_function("distributed_muon::newton_schulz"):
456
+ g_full = p.grad.full_tensor().to(COMM_DTYPE)
457
+ u_full = _zeropower_via_newtonschulz5(g_full,
458
+ steps=ns_steps)
459
+ del g_full
460
+ with record_function("distributed_muon::update"):
461
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
462
+ p._local_tensor.mul_(1 - lr * weight_decay)
463
+ local_indices = get_slices_of_dtensor(
464
+ p, rank, shard_mesh, shard_placements)
465
+ u_local = u_full[local_indices]
466
+ p._local_tensor.add_(u_local, alpha=-adjusted_lr)
467
+ del u_full
468
+
469
+ qk_clip_state = get_qk_clip_info(
470
+ self.clip_config, n, qk_logits)
471
+ scales_full = compute_scales(
472
+ p, qk_clip_state
473
+ ) if qk_clip_state is not None else None
474
+ if scales_full is not None:
475
+ ratio = p.shape[0] // scales_full.shape[0]
476
+ idx0 = local_indices[0]
477
+ if isinstance(idx0, slice):
478
+ start = idx0.start or 0
479
+ idx0 = torch.arange(start,
480
+ idx0.stop,
481
+ device=scales_full.device)
482
+ row_scales = scales_full[idx0 // ratio]
483
+ p._local_tensor.mul_(row_scales.view(-1, 1))
484
+
485
+ if not batch_params:
486
+ continue
487
 
488
+ logger.info(" batched=%d, single=%d", len(batch_params),
489
+ len(single_params))
490
+
491
+ # Concat all local grad shards into a single flat buffer.
492
+ with record_function("distributed_muon::gather"):
493
+ grad_locals = [
494
+ p.grad.to_local().to(COMM_DTYPE).flatten()
495
+ for p in batch_params
496
+ ]
497
+ numels = [g.numel() for g in grad_locals]
498
+ grad_concat = torch.cat(grad_locals)
499
+ del grad_locals
500
+
501
+ # Single all-gather (replaces N separate full_tensor).
502
+ grad_gathered = torch.empty(
503
+ grad_concat.numel() * world_size,
504
+ dtype=COMM_DTYPE,
505
+ device="cuda",
506
+ )
507
+ dist.all_gather_into_tensor(grad_gathered,
508
+ grad_concat,
509
+ group=shard_pg)
510
+
511
+ total_numel = grad_concat.numel()
512
+ del grad_concat
513
+
514
+ # Precompute per-param offsets within the concat buffer.
515
+ offsets = []
516
+ off = 0
517
+ for ne in numels:
518
+ offsets.append(off)
519
+ off += ne
520
+
521
+ # Per-param: reconstruct full grad → NS → local update.
522
+ for i, (n, p) in enumerate(zip(batch_names, batch_params)):
523
+ with record_function("distributed_muon::newton_schulz"):
524
+ g_full = torch.empty(p.shape,
525
+ dtype=COMM_DTYPE,
526
+ device="cuda")
527
+ for r in range(world_size):
528
+ r_start = r * total_numel + offsets[i]
529
+ shard = grad_gathered[r_start:r_start + numels[i]]
530
+ indices = get_slices_of_dtensor(
531
+ p, r, shard_mesh, shard_placements)
532
+ g_full[indices] = shard.reshape(
533
+ g_full[indices].shape)
534
+
535
+ u_full = _zeropower_via_newtonschulz5(g_full,
536
+ steps=ns_steps)
537
+ del g_full
538
+
539
+ with record_function("distributed_muon::update"):
540
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
541
+ p._local_tensor.mul_(1 - lr * weight_decay)
542
+ local_indices = get_slices_of_dtensor(
543
+ p, rank, shard_mesh, shard_placements)
544
+ u_local = u_full[local_indices]
545
+ p._local_tensor.add_(u_local, alpha=-adjusted_lr)
546
+ del u_full
547
+
548
+ qk_clip_state = get_qk_clip_info(
549
+ self.clip_config, n, qk_logits)
550
+ scales_full = compute_scales(
551
+ p, qk_clip_state
552
+ ) if qk_clip_state is not None else None
553
+ if scales_full is not None:
554
+ ratio = p.shape[0] // scales_full.shape[0]
555
+ idx0 = local_indices[0]
556
+ if isinstance(idx0, slice):
557
+ start = idx0.start or 0
558
+ idx0 = torch.arange(start,
559
+ idx0.stop,
560
+ device=scales_full.device)
561
+ row_scales = scales_full[idx0 // ratio]
562
+ p._local_tensor.mul_(row_scales.view(-1, 1))
563
+
564
+ def _setup_parallel(self, names, params, group, qk_logits):
565
+ """Compute (or retrieve cached) parallel pipeline metadata.
566
+
567
+ Returns:
568
+ (ordered_params, param_to_state, rank, chunk_size)
569
+ """
570
+ cache_key = tuple(names)
571
 
572
+ if cache_key not in self._parallel_cache:
573
+ # First call: compute metadata and populate cache.
574
+ param_to_state, ordered_params = self.init_state_and_assign_params(
575
+ names, params, group, qk_logits)
576
 
577
+ shard_pg = param_to_state[id(ordered_params[0])].process_group
578
+ rank = dist.get_rank(group=shard_pg)
 
 
 
 
 
579
 
580
+ if self.chunk_size == -1:
581
+ shard_ranks = dist.get_world_size(shard_pg)
582
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
583
+ elif self.chunk_size > 0:
584
+ chunk_size = self.chunk_size
585
+ else:
586
+ raise ValueError(
587
+ "chunk_size must be -1 or a positive integer.")
588
+
589
+ ordered_names = [
590
+ param_to_state[id(p)].name for p in ordered_params
591
+ ]
592
+ name_to_state = {
593
+ param_to_state[id(p)].name: param_to_state[id(p)]
594
+ for p in ordered_params
595
+ }
596
+ self._parallel_cache[cache_key] = {
597
+ 'ordered_names': ordered_names,
598
+ 'name_to_state': name_to_state,
599
+ 'rank': rank,
600
+ 'chunk_size': chunk_size,
601
+ }
602
+ else:
603
+ # Cached path: rebuild param_to_state with current id(p) keys.
604
+ cache = self._parallel_cache[cache_key]
605
+ rank = cache['rank']
606
+ chunk_size = cache['chunk_size']
607
+
608
+ name_to_param = dict(zip(names, params))
609
+ ordered_params = [name_to_param[n] for n in cache['ordered_names']]
610
+
611
+ param_to_state = {}
612
+ for p, n in zip(ordered_params, cache['ordered_names']):
613
+ cached_state = cache['name_to_state'][n]
614
+ param_to_state[id(p)] = _muon_state(
615
+ worker_rank=cached_state.worker_rank,
616
+ process_group=cached_state.process_group,
617
+ rank_indices=cached_state.rank_indices,
618
+ rank_numels=cached_state.rank_numels,
619
+ name=n,
620
+ qk_clip_state=get_qk_clip_info(self.clip_config, n,
621
+ qk_logits),
622
  )
623
 
624
+ return ordered_params, param_to_state, rank, chunk_size
625
 
626
+ def parallel(self,
627
+ names,
628
+ params,
629
+ group,
630
+ lr,
631
+ weight_decay,
632
+ qk_logits,
633
+ prelaunch_gather=None):
634
  """
635
  Perform a parallel optimization step using Muon.
636
 
 
639
  interleaves multiple chunks so that communication and computation
640
  overlap across chunks (the same overlap previously achieved by the
641
  warmup + main-loop index scheduling).
642
+
643
+ If ``prelaunch_gather`` is provided, it is passed to the first
644
+ chunk's generator to skip re-launching the already in-flight
645
+ A2A gather.
646
  """
647
 
648
  # Momentum is already applied by _step_muon before this method.
649
 
650
+ ordered_params, param_to_state, rank, chunk_size = (
651
+ self._setup_parallel(names, params, group, qk_logits))
 
 
 
 
 
 
 
 
 
 
 
 
 
652
 
653
  def pipelines():
654
+ first = True
655
  for start in range(0, len(ordered_params), chunk_size):
656
  chunk = ordered_params[start:start + chunk_size]
657
  if chunk:
658
+ kwargs = dict(
659
  params=chunk,
660
  param_to_state=param_to_state,
661
  rank=rank,
 
664
  weight_decay=weight_decay,
665
  none_grad=group["none_grad"],
666
  )
667
+ if first and prelaunch_gather is not None:
668
+ kwargs['prelaunch_gather'] = prelaunch_gather
669
+ first = False
670
+ yield muon_chunk_pipeline(**kwargs)
671
 
 
 
672
  with record_function("muon::pipeline"):
673
  run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
674
 
 
680
  names = group["names"]
681
 
682
  # Apply momentum to all params before routing/expansion.
683
+ # Batched using _foreach_* ops (compiled, fullgraph=True).
684
  with record_function("muon::momentum"):
685
+ active_params = [p for p in params if p.grad is not None]
686
+ if active_params:
687
+ # Ensure momentum buffers exist (avoid zeros_like when already present).
688
+ for p in active_params:
689
+ if "momentum_buffer" not in self.state[p]:
690
+ self.state[p]["momentum_buffer"] = torch.zeros_like(
691
+ p.grad)
692
+
693
+ # Extract local tensors for compiled batch function.
694
+ local_grads = [
695
+ p.grad._local_tensor
696
+ if isinstance(p.grad, DTensor) else p.grad
697
+ for p in active_params
698
+ ]
699
+ local_bufs = [
700
+ self.state[p]["momentum_buffer"]._local_tensor
701
+ if isinstance(self.state[p]["momentum_buffer"], DTensor)
702
+ else self.state[p]["momentum_buffer"]
703
+ for p in active_params
704
+ ]
705
+
706
+ # Wrap momentum as tensor for torch.compile.
707
+ batch_pre_ortho(local_grads, local_bufs,
708
+ torch.tensor(momentum), group["nesterov"])
709
+
710
+ # For non-nesterov, the result is the momentum buffer.
711
+ if not group["nesterov"]:
712
+ for p in active_params:
713
+ p.grad = self.state[p]["momentum_buffer"]
714
+
715
+ # Identify batched experts for deferred NS.
716
+ # Detection is cheap (condition checks only); actual NS compute is
717
+ # deferred so it can overlap with the first chunk's A2A gather.
718
+ deferred_expert_work = []
719
+ if self.expert_keys:
720
+ batched_expert_indices = []
721
+ for i, (n, p) in enumerate(zip(names, params)):
722
+ if not (is_expert_param(n, self.expert_keys)
723
+ and p.grad is not None):
724
  continue
725
+ # Eligible: plain tensor, or DTensor with no non-dim-0 shards.
726
+ if isinstance(p.data, DTensor):
727
+ has_tp = any(
728
+ _is_shard(pl) and pl.dim != 0 for pl in p.placements)
729
+ if has_tp:
730
+ continue
731
+ batched_expert_indices.append(i)
732
+
733
+ if batched_expert_indices:
734
+ # Save refs for deferred NS; free grads from param list.
735
+ for i in batched_expert_indices:
736
+ p = params[i]
737
+ g = p.grad
738
+ local_g = (g._local_tensor
739
+ if isinstance(g, DTensor) else g)
740
+ local_data = (p.data._local_tensor if isinstance(
741
+ p.data, DTensor) else p.data)
742
+ deferred_expert_work.append((local_data, local_g))
743
+ p.grad = None
744
+
745
+ # Remove batched experts from lists before expansion.
746
+ keep = sorted(
747
+ set(range(len(params))) - set(batched_expert_indices))
748
+ names = [names[i] for i in keep]
749
+ params = [params[i] for i in keep]
750
+
751
+ def _run_deferred_expert_ns():
752
+ """Execute deferred batched expert NS."""
753
+ if not deferred_expert_work:
754
+ return
755
+ with record_function("muon::batched_expert_ns"):
756
+ ns_steps = group["ns_steps"]
757
+ for local_data, local_g in deferred_expert_work:
758
+ u = zeropower_via_newtonschulz5_batched(
759
+ local_g.to(COMM_DTYPE), steps=ns_steps)
760
+ adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
761
+ local_data.mul_(1 - lr * weight_decay)
762
+ local_data.add_(u, alpha=-adjusted_lr)
763
 
764
  # Expand expert params by splitting on dim 0.
765
+ logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
766
+ len(params), self.expert_keys)
767
+ if self.expert_keys:
768
+ cache_key = tuple(id(p) for p in params)
769
+ cache = self._expert_expand_cache.get(cache_key)
770
+
771
+ if cache is None:
772
+ # Cold path: full expansion + build cache metadata.
773
+ exp_names, exp_params = _expand_expert_params(
774
+ names, params, self.expert_keys)
775
+
776
+ # Build per-expert-group info for hot-path grad updates.
777
+ grad_info = []
778
+ exp_idx = 0
779
+ for orig_idx, (n, p) in enumerate(zip(names, params)):
780
+ if not is_expert_param(n, self.expert_keys):
781
+ exp_idx += 1
782
+ continue
783
+
784
+ is_dt = isinstance(p.data, DTensor)
785
+ num_experts = (p.to_local() if is_dt else p.data).shape[0]
786
+
787
+ # Detect TP mesh from the first expanded expert param.
788
+ tp_mesh = None
789
+ tp_pls = None
790
+ sample = exp_params[exp_idx]
791
+ if isinstance(sample.data, DTensor):
792
+ tp_mesh = sample.data.device_mesh
793
+ tp_pls = list(sample.data.placements)
794
+
795
+ grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
796
+ tp_mesh, tp_pls))
797
+ exp_idx += num_experts
798
+
799
+ self._expert_expand_cache[cache_key] = {
800
+ 'names': exp_names,
801
+ 'params': exp_params,
802
+ 'grad_info': grad_info,
803
+ }
804
+ names, params = exp_names, exp_params
805
+ else:
806
+ # Hot path: reuse cached params, only update expert grads.
807
+ for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
808
+ tp_pls) in cache['grad_info']:
809
+ p = params[orig_idx]
810
+ g = p.grad
811
+ local_grad = (g.to_local()
812
+ if is_dt and isinstance(g, DTensor) else g)
813
+ for i in range(num_experts):
814
+ expert_p = cache['params'][exp_start + i]
815
+ sg = local_grad[i]
816
+ if tp_mesh is not None:
817
+ expert_p.grad = DTensor.from_local(
818
+ sg, device_mesh=tp_mesh, placements=tp_pls)
819
+ else:
820
+ expert_p.grad = sg
821
+ p.grad = None
822
+
823
+ names = cache['names']
824
+ params = cache['params']
825
+ else:
826
+ names, params = _expand_expert_params(names, params,
827
+ self.expert_keys)
828
+ logger.debug("[_step_muon] after expand: %d params", len(params))
829
 
830
  param_dtensors = []
831
  name_dtensors = []
 
833
  param_tensors = []
834
  name_tensors = []
835
 
836
+ # distributed_muon is a reference implementation for testing only.
837
+ # The parallel pipeline (all2all) path below is the production path.
 
838
  if self.use_distributed_muon:
839
+ _run_deferred_expert_ns()
840
  self.distributed_muon(names=names,
841
  params=params,
842
  group=group,
 
845
  qk_logits=qk_logits)
846
  return
847
 
 
 
848
  for n, p in zip(names, params):
849
  if p is None or p.grad is None:
850
  continue
 
852
  if all(
853
  isinstance(placement, Replicate)
854
  for placement in p.placements):
855
+ logger.debug(
856
+ "[route] %s → base (DTensor all-Replicate), "
857
+ "shape=%s, placements=%s", n, p.shape, p.placements)
858
  param_tensors.append(p)
859
  name_tensors.append(n)
 
 
 
860
  else:
861
+ logger.debug(
862
+ "[route] %s → parallel (DTensor), shape=%s, "
863
+ "placements=%s, mesh=%s", n, p.shape, p.placements,
864
+ p.device_mesh.mesh_dim_names)
865
  param_dtensors.append(p)
866
  name_dtensors.append(n)
867
  elif isinstance(p.data, torch.Tensor):
868
+ logger.debug("[route] %s → base (plain tensor), shape=%s", n,
869
+ p.data.shape)
870
  param_tensors.append(p)
871
  name_tensors.append(n)
872
  else:
873
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
874
 
875
+ logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, "
876
+ f"{len(param_tensors)} Tensors → base")
 
877
 
878
  def group_dtensors(dtensors, names):
879
  # To support different placements, we group parameters by placements
 
889
  p.device_mesh])][1].append(p)
890
  return placement_to_params
891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
892
  if len(param_dtensors) > 0:
893
  if not dist.is_initialized():
894
  raise RuntimeError(
 
896
  )
897
 
898
  dtensor_group = group_dtensors(param_dtensors, name_dtensors)
899
+
900
+ # Pre-launch the first chunk's A2A gather so that the NCCL
901
+ # communication overlaps with the (deferred) batched expert NS
902
+ # compute on the default CUDA stream.
903
+ prelaunch = None
904
+ if deferred_expert_work:
905
+ first_names, first_params = next(iter(dtensor_group.values()))
906
+ ordered, pts, rnk, csz = self._setup_parallel(
907
+ first_names, first_params, group, qk_logits)
908
+ first_chunk = ordered[:csz]
909
+ if first_chunk:
910
+ prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
911
+ group["none_grad"])
912
+
913
+ _run_deferred_expert_ns()
914
+
915
+ first_group = True
916
  for _, (names, params) in dtensor_group.items():
917
+ pg = prelaunch if first_group else None
918
+ first_group = False
919
  self.parallel(
920
  names,
921
  params,
 
923
  lr=lr,
924
  weight_decay=weight_decay,
925
  qk_logits=qk_logits,
926
+ prelaunch_gather=pg,
927
  )
928
+ else:
929
+ _run_deferred_expert_ns()
930
 
931
  if len(param_tensors) > 0:
932
  self.base(
 
955
  with torch.enable_grad():
956
  loss = closure()
957
 
958
+ logger.debug("[Muon.step] expert_keys=%s, %d param groups",
959
+ self.expert_keys, len(self.param_groups))
960
+
961
+ for i, group in enumerate(self.param_groups):
962
  if group["use_muon"]:
963
+ logger.debug("[Muon.step] group %d: use_muon=True, %d params",
964
+ i, len(group["params"]))
965
  self._step_muon(group, qk_logits=qk_logits)
966
  else:
967
+ logger.debug(
968
+ "[Muon.step] group %d: use_muon=False (AdamW), %d params",
969
+ i, len(group["params"]))
970
  step_adamw(self.state, group)
971
 
972
  return loss