wyldecat and github-actions[bot] committed
Commit 33929c0 · unverified · 1 parent: ae32572

Refactor pipeline to async generator pattern (#16)


* Refactor muon.py into modules with async generator pipeline

* Add MoE expert weight support with EP+FSDP tests

* Add built binary [skip-build]

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. CLAUDE.md +108 -0
  2. README.md +6 -0
  3. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  4. build/torch210-cxx11-cu126-x86_64-linux/{_optimizer_06a260a_dirty.abi3.so → _optimizer_7aef62f_dirty.abi3.so} +1 -1
  5. build/torch210-cxx11-cu126-x86_64-linux/adamw.py +154 -0
  6. build/torch210-cxx11-cu126-x86_64-linux/async_utils.py +77 -0
  7. build/torch210-cxx11-cu126-x86_64-linux/core.py +116 -0
  8. build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py +174 -115
  9. build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +0 -7
  10. build/torch210-cxx11-cu126-x86_64-linux/metadata.json +3 -1
  11. build/torch210-cxx11-cu126-x86_64-linux/muon.py +196 -870
  12. build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py +50 -0
  13. build/torch210-cxx11-cu126-x86_64-linux/pipeline.py +390 -0
  14. build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py +129 -0
  15. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  16. build/torch210-cxx11-cu128-x86_64-linux/{_optimizer_06a260a_dirty.abi3.so → _optimizer_7aef62f_dirty.abi3.so} +1 -1
  17. build/torch210-cxx11-cu128-x86_64-linux/adamw.py +154 -0
  18. build/torch210-cxx11-cu128-x86_64-linux/async_utils.py +77 -0
  19. build/torch210-cxx11-cu128-x86_64-linux/core.py +116 -0
  20. build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py +174 -115
  21. build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +0 -7
  22. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +3 -1
  23. build/torch210-cxx11-cu128-x86_64-linux/muon.py +196 -870
  24. build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py +50 -0
  25. build/torch210-cxx11-cu128-x86_64-linux/pipeline.py +390 -0
  26. build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py +129 -0
  27. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
  28. build/torch210-cxx11-cu130-x86_64-linux/{_optimizer_06a260a_dirty.abi3.so → _optimizer_7aef62f_dirty.abi3.so} +1 -1
  29. build/torch210-cxx11-cu130-x86_64-linux/adamw.py +154 -0
  30. build/torch210-cxx11-cu130-x86_64-linux/async_utils.py +77 -0
  31. build/torch210-cxx11-cu130-x86_64-linux/core.py +116 -0
  32. build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py +174 -115
  33. build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py +0 -7
  34. build/torch210-cxx11-cu130-x86_64-linux/metadata.json +3 -1
  35. build/torch210-cxx11-cu130-x86_64-linux/muon.py +196 -870
  36. build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py +50 -0
  37. build/torch210-cxx11-cu130-x86_64-linux/pipeline.py +390 -0
  38. build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py +129 -0
  39. build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +3 -3
  40. build/torch210-cxx11-rocm70-x86_64-linux/{_optimizer_06a260a_dirty.abi3.so → _optimizer_7aef62f_dirty.abi3.so} +1 -1
  41. build/torch210-cxx11-rocm70-x86_64-linux/adamw.py +154 -0
  42. build/torch210-cxx11-rocm70-x86_64-linux/async_utils.py +77 -0
  43. build/torch210-cxx11-rocm70-x86_64-linux/core.py +116 -0
  44. build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py +174 -115
  45. build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py +0 -7
  46. build/torch210-cxx11-rocm70-x86_64-linux/metadata.json +3 -1
  47. build/torch210-cxx11-rocm70-x86_64-linux/muon.py +196 -870
  48. build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py +50 -0
  49. build/torch210-cxx11-rocm70-x86_64-linux/pipeline.py +390 -0
  50. build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py +129 -0
CLAUDE.md ADDED
@@ -0,0 +1,108 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Optimizer is a PyTorch package implementing the **Muon optimizer** with support for N-D sharding parallelism for large-scale distributed training, based on the paper at https://arxiv.org/abs/2511.07464. It supports general N-D sharding configurations (from FSDP2 through hybrid setups such as 2 TP + 2 DP-Replicate + 2 DP-Shard).

## Commands

### Lint & Format

```bash
pre-commit run --all-files        # Run all pre-commit hooks
pre-commit run isort --all-files  # Run a specific hook (e.g., isort)
```

Hooks: yapf (Python formatter), isort (import sorter), typos (spell checker), clang-format (C++/CUDA), pymarkdown (Markdown linter), actionlint (GitHub Actions).

### Tests

Tests require **8 GPUs**, access to `Motif-Technologies/Motif-2.6B-4layer-random` on HuggingFace (`HF_TOKEN` env var), and PyTorch >= 2.8.0.

```bash
cd test && ./run_test.sh
# Equivalent to:
cd test && torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon.py
```

Useful pytest flags: `--measure-perf` (timing/memory), `--do-profile` (profiling, requires `--measure-perf`), `--skip-verify` (skip the correctness check against the sequential implementation).

### Build

Uses the kernel-builder infrastructure (`build.toml`, `flake.nix`). Pre-built binaries for various PyTorch/CUDA/ROCm combinations are stored in `build/`.

### Commit Convention

**Always append `[skip-build]` to every commit message.** This prevents CI from triggering unnecessary build jobs on development branches.

## Architecture

### Source Layout

```
torch-ext/optimizer/
├── __init__.py                 # Public API: exports Muon
├── muon.py                     # Muon optimizer class (~430 lines)
├── newton_schulz.py            # Newton-Schulz iteration (~50 lines)
├── qk_clip.py                  # QK clipping for attention heads (~130 lines)
├── core.py                     # Shared state, helpers, param grouping (~110 lines)
├── pipeline.py                 # Async generator pipeline for parallel mode (~290 lines)
├── async_utils.py              # AsyncTask / AsyncRuntime scheduling (~75 lines)
├── adamw.py                    # Fused AdamW for non-Muon parameters (~160 lines)
├── matmul_transpose_triton.py  # Triton kernel for X @ X.T (~130 lines)
└── distributed/
    └── utils.py                # Shard mesh construction, DTensor slicing (~175 lines)
```

### Optimizer Modes

The `Muon` optimizer has three execution paths, selected per parameter based on its tensor type and mesh structure:

1. **Base mode** (`base()`) — Single-device / non-sharded tensors. Standard Muon with Newton-Schulz orthogonalization.
2. **Distributed mode** (`distributed_muon()`) — Gathers full tensors via all-gather, computes updates, redistributes. Used for small parameters or as a fallback.
3. **Parallel mode** (`parallel()`) — Pipelined all-to-all communication overlapped with compute. Uses an async generator pipeline scheduled by `run_pipeline()`. This is the main advanced feature.

### Parallel Mode Pipeline

The parallel pipeline is implemented as a single generator function `muon_chunk_pipeline()` in `pipeline.py`. Parameters are split into chunks, and each chunk flows through:

```
build bufs + async all2all_gather → yield → wait + Newton-Schulz compute + async all2all_scatter → yield → wait + update_param
```

The generator yields twice (after launching the async gather and the async scatter via `async_op=True`), allowing `run_pipeline()` to interleave multiple chunks for communication overlap. `work.wait()` completes each async operation after the yield.

`warmup_step` maps to `max_concurrent_tasks = warmup_step + 1` in `run_pipeline()`.
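The two-yield chunk flow described above can be sketched as a plain Python generator. This is a hedged illustration, not the real `pipeline.py` code: `all_to_all_async`, `_FakeWork`, and the doubling "compute" step are stand-ins for the actual distributed collectives and Newton-Schulz kernel.

```python
class _FakeWork:
    """Stand-in for the work handle returned by async torch.distributed ops."""

    def wait(self):
        pass


def all_to_all_async(buf):
    # Stand-in for an all-to-all launched with async_op=True.
    return _FakeWork()


def muon_chunk_pipeline(chunk):
    work = all_to_all_async(chunk)   # launch async gather
    yield                            # 1st yield: scheduler may advance other chunks
    work.wait()
    update = [g * 2 for g in chunk]  # stand-in for Newton-Schulz compute
    work = all_to_all_async(update)  # launch async scatter
    yield                            # 2nd yield
    work.wait()
    chunk[:] = update                # stand-in for update_param


chunk = [1.0, 2.0, 3.0]
stages = sum(1 for _ in muon_chunk_pipeline(chunk))
print(stages, chunk)  # 2 [2.0, 4.0, 6.0]
```

Exhausting the generator drives all three stages; the two yields are exactly the points where `run_pipeline()` can switch to another chunk.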
For detailed implementation documentation (pipeline internals, distributed utilities, QK clipping with strided sharding, etc.), see [`docs/implementation.md`](docs/implementation.md).

### Key Abstractions

- **`get_default_muon_param_groups(model, is_muon_func)`** (`core.py`) — Separates parameters into Muon-optimizable (2D+) and AdamW groups. Skips embeddings and output layers by default.
- **`_muon_state` dataclass** (`core.py`) — Per-parameter config: rank ownership (`worker_rank`), process group, precomputed shard indices (`rank_indices`, `rank_numels`), and optional QK clip state. Config-only; no transient pipeline state.
- **`muon_chunk_pipeline()` generator** (`pipeline.py`) — Processes one chunk through the full gather → compute → scatter → update pipeline. Uses `async_op=True` for non-blocking all-to-all and yields to allow chunk interleaving. All intermediate buffers are generator-local variables.
- **`run_pipeline()`** (`async_utils.py`) — Generator-based pipeline scheduling with bounded concurrency. Interleaves multiple chunk pipelines at yield points.
- **`construct_shard_mesh()` / `get_slices_of_dtensor()`** (`distributed/utils.py`) — Utilities for building shard meshes from DTensor placements and computing per-rank local slices. Handles both `Shard` and `_StridedShard` (PyTorch 2.10+).
- **Newton-Schulz iteration** (`newton_schulz.py`) — `_zeropower_via_newtonschulz5()`: 5 quintic iterations in bfloat16 with pre-optimized coefficients for gradient orthogonalization. Uses the Triton kernel `matmul_transpose_assign` for efficient X @ X.T.
- **QK Clipping** (`qk_clip.py`) — Optional dynamic clipping of attention head projections when QK logits exceed a threshold. Configured via `q_indices`, `k_indices`, `head_dim`, `threshold`.
- **Fused AdamW** (`adamw.py`) — Uses PyTorch's `torch._fused_adamw_` for non-Muon parameters, grouping tensors by device/dtype and DTensor placement.
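The Newton-Schulz step can be illustrated with a small NumPy sketch. This is a hedged illustration: the coefficients below are the widely used Muon values and are assumed to match this repo's "pre-optimized coefficients"; the real kernel runs in bfloat16 and computes X @ X.T with the Triton kernel.

```python
import numpy as np


def zeropower_via_newtonschulz5(G, steps=5):
    # Quintic Newton-Schulz iteration for approximate orthogonalization.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + 1e-7)  # Frobenius norm bounds the spectral norm
    tall = X.shape[0] > X.shape[1]
    if tall:
        X = X.T  # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if tall else X


rng = np.random.default_rng(0)
G = rng.standard_normal((4, 6))
U = zeropower_via_newtonschulz5(G)
# Singular values of U are pushed toward ~1 (approximate orthogonalization).
print(np.linalg.svd(U, compute_uv=False))
```

Five iterations do not converge singular values to exactly 1; they land in a band around 1, which is sufficient for the Muon update direction.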
### Dependency Graph

```
matmul_transpose_triton.py (leaf)

newton_schulz.py (leaf + triton)

core.py ──── qk_clip.py (leaf, distributed/utils)
│ │ │
│ pipeline.py ─── async_utils.py
│ │
│ adamw.py
│ │
muon.py (all above)

__init__.py
```
README.md CHANGED
@@ -45,7 +45,13 @@ optim = optimizer.Muon(
 )
 ```
 
+## Documentation
+
+- [Implementation Guide](./docs/implementation.md) — Detailed walkthrough of the internal architecture, parallel pipeline, distributed utilities, and QK clipping. Recommended for code reviewers and new contributors.
+- [PyTorch 2.10 TP Fix](./docs/pytorch-2.10-tp-fix.md) — Root cause analysis and fixes for `_StridedShard` compatibility with PyTorch 2.10+.
+
 ## Test
+
 - Check [test/README.md](./test/README.md) for how to run the tests.
 
 ## Pre-commit Hooks
build/torch210-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_06a260a_dirty
-ops = torch.ops._optimizer_06a260a_dirty
+from . import _optimizer_7aef62f_dirty
+ops = torch.ops._optimizer_7aef62f_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_06a260a_dirty::{op_name}"
+    return f"_optimizer_7aef62f_dirty::{op_name}"
build/torch210-cxx11-cu126-x86_64-linux/{_optimizer_06a260a_dirty.abi3.so → _optimizer_7aef62f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5384da54f22f488e0646e09915b821b3235cb404b163a570aa377967f853e3cf
+oid sha256:f095be87ff6185010a3cff4175abbde0b2e50fe1e435dc1db4eaf5bf1f6199ca
 size 1940944
build/torch210-cxx11-cu126-x86_64-linux/adamw.py ADDED
@@ -0,0 +1,154 @@
```python
from collections import defaultdict
from typing import cast

import torch
from torch.distributed.tensor import DTensor


def fused_adamw(
    params: list[torch.Tensor],
    grads: list[torch.Tensor],
    exp_avgs: list[torch.Tensor],
    exp_avg_sqs: list[torch.Tensor],
    max_exp_avg_sqs: list[torch.Tensor],
    state_steps: list[torch.Tensor],
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float | torch.Tensor,
    weight_decay: float,
    eps: float,
    maximize: bool,
) -> None:
    if not params:
        return

    # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
    # treating it as a scalar.
    lr_dict: dict | None = ({
        lr.device: lr
    } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None)
    grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
         state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            device_params_,
            device_grads_,
            device_exp_avgs_,
            device_exp_avg_sqs_,
            device_max_exp_avg_sqs,
            device_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        device_params = cast(list[torch.Tensor], device_params_)
        device_grads = cast(list[torch.Tensor], device_grads_)
        device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(list[torch.Tensor], device_state_steps_)

        if lr_dict is not None and device not in lr_dict:
            lr_dict[device] = lr.to(
                device=device, non_blocking=True)  # type: ignore[union-attr]
            lr = lr_dict[device]
        torch._foreach_add_(device_state_steps, 1)
        func = torch._fused_adamw_
        func(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,  # type: ignore[arg-type]
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,  # type: ignore[arg-type]
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
        )


def step_adamw_params(optimizer_state, params, group):
    """Run fused AdamW on a list of parameters sharing the same placement.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        params: List of parameters to update.
        group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
    """
    params_with_grads = []
    grads = []
    moment1 = []
    moment2 = []
    max_exp_avg_sqs = []
    state_steps = []
    lr = group["lr"]
    beta1, beta2 = group["adamw_betas"]
    eps = group["adamw_eps"]
    weight_decay = group["weight_decay"]

    for p in params:
        g = p.grad
        if g is None:
            continue
        state = optimizer_state[p]
        params_with_grads.append(p)
        grads.append(g)
        if "step" not in state:
            state["step"] = (torch.zeros((),
                                         dtype=torch.float32,
                                         device=p.device))
            state["moment1"] = torch.zeros_like(g)
            state["moment2"] = torch.zeros_like(g)
        moment1.append(state["moment1"])
        moment2.append(state["moment2"])
        if not isinstance(state["step"], torch.Tensor):
            step_tensor = torch.tensor(state["step"],
                                       dtype=torch.float32,
                                       device=p.device)
        else:
            step_tensor = state["step"]
        state_steps.append(step_tensor)

    fused_adamw(
        params_with_grads,
        grads,
        moment1,
        moment2,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=False,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=False,
    )


def step_adamw(optimizer_state, group):
    """Dispatch AdamW step, grouping parameters by type and placement.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        group: Parameter group dict.
    """
    params = group["params"]

    # group params with its type and placement
    placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list)
    for p in params:
        match p:
            case DTensor():
                placement_to_params[tuple([p.placements,
                                           p.device_mesh])].append(p)
            case torch.Tensor():
                placement_to_params[tuple([torch.Tensor, None])].append(p)

    for group_params in placement_to_params.values():
        step_adamw_params(optimizer_state, group_params, group)
```
build/torch210-cxx11-cu126-x86_64-linux/async_utils.py ADDED
@@ -0,0 +1,77 @@
```python
import logging
from typing import Generator

logger = logging.getLogger(__name__)


class _Task:
    """Internal: wraps a generator, advances one yield at a time."""

    def __init__(self, generator: Generator[None, None, None], index: int):
        self._generator = generator
        self._index = index
        self._steps_completed = 0
        self.step()  # run to first yield

    def step(self) -> bool:
        try:
            next(self._generator)
            self._steps_completed += 1
            logger.debug("pipeline[%d] completed stage %d", self._index,
                         self._steps_completed)
            return True
        except StopIteration:
            logger.debug("pipeline[%d] finished after %d stages", self._index,
                         self._steps_completed)
            return False

    def close(self):
        self._generator.close()


def run_pipeline(
    pipelines: Generator[Generator[None, None, None], None, None],
    max_concurrent: int,
) -> None:
    """Run generator-based pipelines with bounded concurrency.

    Each pipeline is a generator that yields at stage boundaries.
    The runtime interleaves pipelines so communication and computation
    overlap across chunks.
    """
    if max_concurrent <= 0:
        raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}")

    have_new = True
    task_index = 0
    previous_tasks: list[_Task] = []

    try:
        while have_new or previous_tasks:
            running_tasks: list[_Task] = []

            # Admit one new pipeline per iteration (staggered admission).
            # Admitting one at a time ensures that while chunk N does NS
            # compute on the default stream, chunk N+1's NCCL all-to-all
            # runs concurrently on the NCCL stream — creating real
            # communication/computation overlap on the GPU.
            if have_new and len(previous_tasks) < max_concurrent:
                try:
                    gen = next(pipelines)
                    task = _Task(gen, task_index)
                    task_index += 1
                    running_tasks.append(task)
                except StopIteration:
                    have_new = False

            # Advance every previously-yielded task by one step.
            for task in previous_tasks:
                if task.step():
                    running_tasks.append(task)

            previous_tasks = running_tasks
    except BaseException:
        # Clean up all in-flight generators to release GPU resources.
        for task in previous_tasks:
            task.close()
        raise
```
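The staggered admission can be traced with a self-contained toy. This minimally re-implements the scheduling loop (without `_Task` or logging) and uses illustrative stage names; note how chunk 1's "gather" is logged before chunk 0's "compute", which is the communication/computation overlap the scheduler creates:

```python
log = []


def chunk(i):
    log.append(f"gather{i}")   # launch async all-to-all gather
    yield
    log.append(f"compute{i}")  # wait + Newton-Schulz + launch scatter
    yield
    log.append(f"update{i}")   # wait + parameter update


def run_pipeline(pipelines, max_concurrent):
    # Minimal re-implementation of the scheduler: admit at most one new
    # generator per loop iteration, then advance every in-flight one.
    running, have_new = [], True
    while have_new or running:
        nxt = []
        if have_new and len(running) < max_concurrent:
            try:
                g = next(pipelines)
                next(g)          # run the new task to its first yield
                nxt.append(g)
            except StopIteration:
                have_new = False
        for g in running:
            try:
                next(g)
                nxt.append(g)
            except StopIteration:
                pass
        running = nxt


run_pipeline((chunk(i) for i in range(3)), max_concurrent=2)
print(log)
# gather1 appears before compute0: chunk 1's communication overlaps chunk 0's compute
```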
build/torch210-cxx11-cu126-x86_64-linux/core.py ADDED
@@ -0,0 +1,116 @@
```python
import math
from dataclasses import dataclass

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.tensor import DTensor


@dataclass
class _muon_state:
    worker_rank: int
    process_group: ProcessGroup
    rank_indices: dict[int, tuple]  # local_rank -> per-dim indices
    rank_numels: dict[int, int]  # local_rank -> numel
    name: str
    qk_clip_state: torch.Tensor | None = None


def update_g(optimizer_state, p, g, group, momentum):
    """Apply momentum update to gradient.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        p: Parameter tensor.
        g: Gradient tensor.
        group: Parameter group dict.
        momentum: Momentum coefficient.

    Returns:
        Momentum-updated gradient tensor.
    """
    state = optimizer_state[p]
    buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
    torch.add(g, buf, alpha=momentum, out=buf)
    if group["nesterov"]:
        g.add_(buf, alpha=momentum)
        return g
    return buf


def update_p(p, u, lr, adjusted_lr, weight_decay):
    """Apply weight decay and orthogonalized update to parameter.

    Args:
        p: Parameter (torch.nn.Parameter or DTensor).
        u: Orthogonalized update tensor.
        lr: Base learning rate.
        adjusted_lr: Size-adjusted learning rate.
        weight_decay: Weight decay coefficient.
    """
    if isinstance(p, torch.nn.Parameter):
        # apply weight decay
        p.data.mul_(1 - lr * weight_decay)
        # apply update
        p.data.add_(u, alpha=-adjusted_lr)
    else:
        p.mul_(1 - lr * weight_decay)
        p.add_(u, alpha=-adjusted_lr)


def adjust_lr_for_muon(lr, param_shape):
    """Scale learning rate based on parameter matrix dimensions.

    Args:
        lr: Base learning rate.
        param_shape: Shape of the parameter tensor.

    Returns:
        Adjusted learning rate.
    """
    A, B = param_shape[:2]
    # We adjust the learning rate and weight decay based on the size of the parameter matrix
    # as described in the paper
    adjusted_ratio = 0.2 * math.sqrt(max(A, B))
    adjusted_lr = lr * adjusted_ratio
    return adjusted_lr


def default_is_muon(name, x, expert_keys=None):
    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
    if any(key in name for key in skip_keys):
        return False
    effective_ndim = x.ndim
    if expert_keys and any(key in name for key in expert_keys):
        effective_ndim -= 1
    return effective_ndim >= 2


def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
    if is_muon_func is None:
        is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)

    muon_params, muon_names = [], []
    non_muon_params = []

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if is_muon_func(n, p):
            muon_params.append(p)
            muon_names.append(n)
        else:
            non_muon_params.append(p)

    return [
        {
            "params": muon_params,
            "names": muon_names,
            "use_muon": True,
        },
        {
            "params": non_muon_params,
            "use_muon": False,
        },
    ]
```
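The size-based scaling in `adjust_lr_for_muon` reduces to `lr * 0.2 * sqrt(max(A, B))`. A worked example (the 4096x1024 shape and base lr are illustrative values, not from this repo):

```python
import math


def adjust_lr_for_muon(lr, param_shape):
    # Same formula as above: scale by 0.2 * sqrt(max of the first two dims).
    A, B = param_shape[:2]
    return lr * 0.2 * math.sqrt(max(A, B))


# A 4096x1024 projection with base lr 0.02:
print(adjust_lr_for_muon(0.02, (4096, 1024)))  # ~0.256 (= 0.02 * 0.2 * 64)
```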
build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py CHANGED
@@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard,
                                                       _StridedShard)
 
 
 def get_slices_of_dtensor(
     target: DTensor | torch.Tensor,
     local_rank: int,
     shard_mesh: DeviceMesh,
     shard_placements: tuple[Placement],
-) -> tuple[slice]:
     """
-    Get the slice of local tensor for a given rank from a tensor.
     Args:
-        target (DTensor | torch.Tensor): The target tensor.
-        rank (int): The local rank of the shard group.
-        shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
         shard_placements (tuple[Placement]): The shard placements.
-    """
 
-    slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
 
     # find the global rank of the local rank in the shard mesh
     rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
@@ -34,34 +52,75 @@ def get_slices_of_dtensor(
 
     assert len(rank_coords) == len(shard_placements)
 
     # Caution: Assuming replicate-to-shard of the shard mesh goes with
     # left-to-right sharding. This is ensured by the sorting logic of
     # construct_shard_mesh function.
-    for i, (rank_coord,
-            placement) in enumerate(zip(rank_coords, shard_placements)):
-        assert isinstance(placement, Shard)
 
-        num_ranks = shard_mesh.mesh.shape[i]
 
-        dim = placement.dim
-        dim_size = (slices[dim].stop - slices[dim].start)
 
-        if dim_size % num_ranks != 0:
             raise NotImplementedError(
-                f"Dimension size {dim_size} is not divisible "
-                f"by number of ranks {num_ranks} for shard "
-                f"placement on dim {dim}. (shape: {target.shape})")
-
-        shard_size = dim_size // num_ranks
-
-        start = slices[dim].start + rank_coord * shard_size
-        end = start + shard_size
-
-        assert start < end <= slices[dim].stop
-
-        slices[dim] = slice(start, end)
 
-    return tuple(slices)
 
 
 _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
@@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
 
 def construct_shard_mesh(
     placements: tuple[Placement],
     mesh: DeviceMesh,
-) -> (DeviceMesh, ProcessGroup, tuple[Placement]):
-    """
-    Construct Shard Mesh and Placements for unsharding.
-    It removes Replicate placements and constructs a new Mesh and ProcessGroup.
-    """
-    my_rank = dist.get_rank()
 
-    assert mesh.mesh.device.type == 'cpu'
 
-    # Copy mesh to avoid modifying the original mesh
-    mesh = mesh.mesh.clone()
-
-    # 1. Sort placements. Replicate first, then Shard by dim ascending.
-
-    # For Shard, strided shard comes after regular shard on the same dim
-    # to preserve left-to-right order of replicate-to-shard.
-    # This is because that strided shard is using stride to represent
-    # more fine-grained sharding on the same dim.
-    # Please check the URL below for _StridedShard.
-    # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
-
-    def placement_sort_key(
-        placement_with_index: tuple[float, Placement]
-    ) -> tuple[int, float, int]:  # (dim, split factor, original index)
-        index, placement = placement_with_index
-        is_replicate = placement.is_replicate()
-        is_shard = placement.is_shard()
-        is_partial = placement.is_partial()
-
-        assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
-        assert not is_partial, "Partial placement is not supported."
-
-        if is_replicate:
-            return (-1.0, 0, index)
-        elif is_shard:
-            if isinstance(placement, _StridedShard):
-                return (placement.dim, 1 / placement.split_factor, index)
-            return (placement.dim, 0, index)
-        else:
-            raise TypeError(f"Unknown placement type: {type(placement)}")
 
-    placements_with_index: list[tuple[int,
-                                      Placement]] = list(enumerate(placements))
-    placements_with_index = sorted(placements_with_index,
-                                   key=placement_sort_key)
 
-    sorted_indices, sorted_placements = zip(*placements_with_index)
 
-    # 2. Permute mesh according to sorted placements.
-    sorted_mesh = mesh.permute(sorted_indices)
 
-    # 3. Collect list of shard meshes by removing replicate dims
-    # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
-    # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
-    num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
 
-    # merge replicate dims
-    # shard_meshes became a list of shard meshes with a length of replicate degree
-    if num_replicates > 0:
-        sorted_mesh = sorted_mesh.flatten(
-            0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
         shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
     else:
         shard_meshes = [sorted_mesh]
-    shard_placements = sorted_placements[num_replicates:]
-
-    # assume all shard placements are different
     assert len(shard_placements) == len(set(shard_placements))
 
-    # 4. Construct ProcessGroups
-    # Caution: all groups should be created in the same order in all processes,
-    # even though each process only needs its own group.
-
-    # To use tensor as dict key, convert it to tuple
-    def tensor_to_tuple(t):
-        if isinstance(t, torch.Tensor):
-            t = t.tolist()
-        if isinstance(t, list):
-            return tuple(tensor_to_tuple(x) for x in t)
-        return t
-
-    my_shard_mesh_as_tuple = None
-    for shard_mesh in shard_meshes:
-        assert isinstance(shard_mesh, torch.Tensor)
-        shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
-
-        if (my_rank == shard_mesh).any().item():
-            assert my_shard_mesh_as_tuple is None
-            my_shard_mesh_as_tuple = shard_mesh_as_tuple
-
-        # update global cache
-        if shard_mesh_as_tuple not in _ranks_to_dist_cache:
-            shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
-            _ranks_to_dist_cache[shard_mesh_as_tuple] = (
-                DeviceMesh(device_type="cuda", mesh=shard_mesh),
-                shard_process_group,
         )
 
-    my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
-        my_shard_mesh_as_tuple]
-
-    return my_shard_mesh, my_shard_process_group, shard_placements
                                                       _StridedShard)
 
 
+def _is_shard(placement: Placement) -> bool:
+    """Check if a placement is a shard type (Shard or _StridedShard).
+
+    In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so
+    ``placement.is_shard()`` returns False for _StridedShard. This helper
+    handles both old and new hierarchies.
+    """
+    return isinstance(placement, (Shard, _StridedShard))
+
+
 def get_slices_of_dtensor(
     target: DTensor | torch.Tensor,
     local_rank: int,
     shard_mesh: DeviceMesh,
     shard_placements: tuple[Placement],
+) -> tuple[slice | torch.Tensor, ...]:
     """
+    Get per-dimension indices for a given rank's shard of the target tensor.
+
+    Uses ``Shard.local_shard_size_and_offset`` and
+    ``_StridedShard.local_shard_size_and_offset`` for correct handling of
+    both contiguous and strided (non-contiguous) sharding.
+
     Args:
+        target (DTensor | torch.Tensor): The target tensor (for its shape).
+        local_rank (int): The local rank within the shard group.
+        shard_mesh (DeviceMesh): The shard mesh (only shard dimensions).
         shard_placements (tuple[Placement]): The shard placements.
 
+    Returns:
+        A tuple of indices (one per tensor dim). Each element is either:
+        - A ``slice`` (for contiguous or unsharded dims)
+        - A 1-D ``torch.LongTensor`` of indices (for strided sharding)
+    """
 
     # find the global rank of the local rank in the shard mesh
     rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
 
     assert len(rank_coords) == len(shard_placements)
 
+    # Track per-shard-dim indices.
+    # None means "not yet sharded on this dim".
+    dim_indices: dict[int, torch.Tensor] = {}
+
     # Caution: Assuming replicate-to-shard of the shard mesh goes with
     # left-to-right sharding. This is ensured by the sorting logic of
     # construct_shard_mesh function.
+    for mesh_dim_idx, (rank_coord, placement) in enumerate(
+            zip(rank_coords, shard_placements)):
+        assert _is_shard(placement)
 
+        num_chunks = shard_mesh.mesh.shape[mesh_dim_idx]
+        shard_dim = placement.dim
 
+        # Current effective size on this dim (may already be sub-sharded)
+        if shard_dim in dim_indices:
+            curr_size = len(dim_indices[shard_dim])
+        else:
+            curr_size = target.size()[shard_dim]
 
+        if curr_size % num_chunks != 0:
             raise NotImplementedError(
+                f"Dimension size {curr_size} is not divisible "
+                f"by number of ranks {num_chunks} for shard "
+                f"placement on dim {shard_dim}. (shape: {target.shape})")
+
+        # Compute indices for this level of sharding
+        if isinstance(placement, _StridedShard):
+            _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
+                placement,
+                curr_size,
+                num_chunks,
+                rank_coord,
+                return_first_offset=False)
+            new_indices = torch.tensor(offsets, dtype=torch.long)
+        else:
+            shard_size, offset = Shard.local_shard_size_and_offset(
+                curr_size, num_chunks, rank_coord)
+            new_indices = torch.arange(offset,
+                                       offset + shard_size,
+                                       dtype=torch.long)
+
+        # Compose with previous indices on this dim
+        if shard_dim in dim_indices:
+            dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]
+        else:
+            dim_indices[shard_dim] = new_indices
 
+    # Build result tuple
+    result: list[slice | torch.Tensor] = []
+    for d in range(len(target.size())):
+        if d not in dim_indices:
+            result.append(slice(None))
+        else:
+            indices = dim_indices[d]
+            # Convert contiguous indices to slice for efficiency
+            if len(indices) > 0:
+                start = indices[0].item()
+                expected = torch.arange(start,
+                                        start + len(indices),
+                                        dtype=torch.long)
+                if torch.equal(indices, expected):
+                    result.append(slice(start, start + len(indices)))
+                else:
+                    result.append(indices)
+            else:
+ result.append(slice(0, 0))
122
+
123
+ return tuple(result)
124
 
125
 
126
  _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
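For the common contiguous, evenly divisible case, the index composition above can be sketched in plain Python. The helper names here (`shard_slice`, `compose`) are illustrative only, not part of the module:

```python
def shard_slice(dim_size, num_chunks, chunk_idx):
    """Contiguous shard: chunk `chunk_idx` of `num_chunks` equal pieces."""
    assert dim_size % num_chunks == 0
    step = dim_size // num_chunks
    return slice(chunk_idx * step, (chunk_idx + 1) * step)

def compose(outer, inner):
    """Apply `inner` within the window already selected by `outer`."""
    return slice(outer.start + inner.start, outer.start + inner.stop)

# Dim of size 16: outer 2-way sharding picks rows 8..16 for coordinate 1,
# inner 2-way sharding then picks the first half of that window.
outer = shard_slice(16, 2, 1)   # slice(8, 16)
inner = shard_slice(8, 2, 0)    # slice(0, 4)
print(compose(outer, inner))    # slice(8, 12, None)
```

Strided (`_StridedShard`) levels cannot be composed this way, which is why the real function falls back to explicit index tensors for them.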
 
 def construct_shard_mesh(
     placements: tuple[Placement],
     mesh: DeviceMesh,
+) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
+    """Construct shard sub-mesh and ProcessGroup for all-to-all communication.
+
+    Given a DTensor's placements and device mesh, extracts the "shard group"
+    (the set of ranks that together hold all shards of the same replica)
+    and creates a ProcessGroup for all-to-all among them.
+
+    Steps:
+      1. Sort placements: Replicate first, then Shard by (dim, granularity).
+      2. Permute the mesh tensor to match the sorted order.
+      3. Collapse Replicate dims into a list of shard sub-meshes (one per
+         replica).
+      4. Create/retrieve a cached ProcessGroup for the current rank's
+         sub-mesh.
+
+    Example: 8 GPUs, mesh shape (2, 2, 2),
+    placements ``[Shard(0), Replicate, _StridedShard(0)]``::
+
+        Step 1  Sort: [Replicate, _StridedShard(0), Shard(0)]
+                Permutation: [1, 2, 0]
+
+        Step 2  Permute mesh dims by [1, 2, 0]:
+                Original:           Permuted:
+                [[[0,1],[2,3]],     [[[0,4],[1,5]],
+                 [[4,5],[6,7]]]      [[2,6],[3,7]]]
+
+        Step 3  Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
+                sub-mesh 0 = [[0,4],[1,5]]  (replica group 0)
+                sub-mesh 1 = [[2,6],[3,7]]  (replica group 1)
+                shard_placements = (_StridedShard(0), Shard(0))
+
+        Step 4  Rank 0 -> ProcessGroup([0,1,4,5])
+                Rank 2 -> ProcessGroup([2,3,6,7])
+
+    Returns:
+        ``(shard_mesh, process_group, shard_placements)``
+    """
+    my_rank = dist.get_rank()
+    assert mesh.mesh.device.type == 'cpu'
+
+    # -- Fast path: 1D all-shard mesh -> reuse existing PG. -----------------
+    # This avoids a non-collective dist.new_group() call, which would
+    # deadlock when only a subset of ranks call this function (e.g. expert
+    # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
+    if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
+        key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
+        if key not in _ranks_to_dist_cache:
+            _ranks_to_dist_cache[key] = (mesh, mesh.get_group())
+        return (*_ranks_to_dist_cache[key], tuple(placements))
+
+    mesh_tensor = mesh.mesh.clone()
+
+    # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------
+    # _StridedShard comes BEFORE regular Shard on the same dim so that
+    # get_slices_of_dtensor applies the outer sharding first, matching
+    # DTensor's left-to-right (outer-to-inner) composition order.
+    def _sort_key(item):
+        index, placement = item
+        assert not placement.is_partial(), "Partial placement not supported"
+        if placement.is_replicate():
+            return (-1, 0, index)
+        assert _is_shard(placement), f"Unsupported: {type(placement)}"
+        split = (-1 / placement.split_factor if isinstance(
+            placement, _StridedShard) else 0)
+        return (placement.dim, split, index)
+
+    indexed = sorted(enumerate(placements), key=_sort_key)
+    perm, sorted_placements = zip(*indexed)
+
+    # -- Step 2: Permute mesh to match sorted placement order. --------------
+    sorted_mesh = mesh_tensor.permute(perm)
+
+    # -- Step 3: Collapse replicate dims -> list of shard sub-meshes. -------
+    # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] -> 6 sub-meshes of (4, 4)
+    num_rep = sum(1 for p in sorted_placements if p.is_replicate())
+    if num_rep > 0:
+        if num_rep > 1:
+            sorted_mesh = sorted_mesh.flatten(0, num_rep - 1)
         shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
     else:
         shard_meshes = [sorted_mesh]
+    shard_placements = sorted_placements[num_rep:]
 
     assert len(shard_placements) == len(set(shard_placements))
 
+    # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
+    # All ranks must call dist.new_group in the same order, even though each
+    # rank only joins one group.
+    def _cache_key(t: torch.Tensor) -> tuple:
+        return (*t.shape, *t.flatten().tolist())
+
+    my_key = None
+    for sm in shard_meshes:
+        key = _cache_key(sm)
+        if (my_rank == sm).any().item():
+            assert my_key is None, "Rank appears in multiple shard groups"
+            my_key = key
+        if key not in _ranks_to_dist_cache:
+            pg = dist.new_group(sm.flatten().tolist())
+            _ranks_to_dist_cache[key] = (
+                DeviceMesh(device_type="cuda", mesh=sm),
+                pg,
+            )
 
+    return (*_ranks_to_dist_cache[my_key], shard_placements)
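The Step 1 sort key can be illustrated standalone. The placement classes below are minimal stand-ins for the real DTensor types, modeling only the fields the key reads; they are not the library's API:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class R:            # stand-in for Replicate
    pass

@dataclass(frozen=True)
class S:            # stand-in for Shard
    dim: int

@dataclass(frozen=True)
class SS:           # stand-in for _StridedShard
    dim: int
    split_factor: int

def sort_key(item):
    index, p = item
    if isinstance(p, R):
        return (-1, 0, index)            # Replicate sorts first
    # _StridedShard gets a negative tiebreaker, so it precedes Shard
    # on the same dim (outer-to-inner composition order)
    split = -1 / p.split_factor if isinstance(p, SS) else 0
    return (p.dim, split, index)

placements = [S(0), R(), SS(0, 2)]       # the docstring's example
indexed = sorted(enumerate(placements), key=sort_key)
perm = [i for i, _ in indexed]
print(perm)  # [1, 2, 0]
```

This reproduces the permutation `[1, 2, 0]` from the docstring example above.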
 
 
 
build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py CHANGED
@@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out):
     with torch.cuda.device(d_in.device.index):
         mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
                          d_out.stride(0), d_out.stride(1))
-
-
-def matmul_transpose(d_in):
-    M, _ = d_in.shape
-    d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
-    matmul_transpose_assign(d_in, d_out)
-    return d_out
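The removed `matmul_transpose` wrapper simply allocated an `(M, M)` output and filled it with `d_in @ d_in.T`. A stdlib-only sketch of that product (list-of-lists in place of tensors):

```python
def matmul_transpose(x):
    """Return x @ x.T for a list-of-lists matrix."""
    return [[sum(a * b for a, b in zip(ri, rj)) for rj in x] for ri in x]

print(matmul_transpose([[1, 2], [3, 4]]))  # [[5, 11], [11, 25]]
```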
build/torch210-cxx11-cu126-x86_64-linux/metadata.json CHANGED
@@ -1 +1,3 @@
-{"python-depends":[]}
+{
+  "python-depends": []
+}
build/torch210-cxx11-cu126-x86_64-linux/muon.py CHANGED
@@ -1,536 +1,121 @@
 import logging
-import math
 import types
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Any, cast
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.tensor import DTensor, Replicate
-from torch.distributed.tensor.placement_types import Placement
-
-from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
-from .matmul_transpose_triton import matmul_transpose_assign
 
 logger = logging.getLogger(__name__)
 
-COMM_DTYPE = torch.bfloat16
-DEFAULT_CHUNK_SIZE_RATIO = 4
-
-
-# This code snippet is a modified version adapted from the following GitHub
-# repository: https://github.com/KellerJordan/Muon/blob/master/muon.py
-# Muon's Newton-Schulz iteration causes high variance in singular values.
-# Idea: give each iteration its own 3 coefficients and optimize them via
-# gradient descent.
-# matmul_transpose_assign from: https://github.com/nil0x9/flash-muon
-@torch.no_grad()
-def _zeropower_via_newtonschulz5(G, steps):
-    """
-    Newton-Schulz iteration to compute the zeroth power / orthogonalization
-    of G. We opt to use a quintic iteration whose coefficients are selected
-    to maximize the slope at zero. For the purpose of minimizing steps, it
-    turns out to be empirically effective to keep increasing the slope at
-    zero even beyond the point where the iteration no longer converges all
-    the way to one everywhere on the interval. This iteration therefore does
-    not produce UV^T but rather something like US'V^T where S' is diagonal
-    with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-    performance at all relative to UV^T, where USV^T = G is the SVD.
-    """
-    assert len(G.shape) == 2
-    assert G.dtype == COMM_DTYPE
-    X = G  # no manual typecast
-
-    if G.size(0) > G.size(1):
-        X = X.T
-    # Ensure spectral norm is at most 1
-    X = X / (X.norm() + 1e-7)
-    buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-    buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-    # Perform the NS iterations
-    for a, b, c in [
-        (4.0848, -6.8946, 2.9270),
-        (3.9505, -6.3029, 2.6377),
-        (3.7418, -5.5913, 2.3037),
-        (2.8769, -3.1427, 1.2046),
-        (2.8366, -3.0525, 1.2012),
-    ]:
-        matmul_transpose_assign(X, buf1)
-        matmul_transpose_assign(buf1, buf2)
-        buf1.mul_(b).add_(buf2, alpha=c)
-        X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
-
-    if G.size(0) > G.size(1):
-        X = X.T
-    return X
-
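The quintic update acts independently on each singular value of the normalized input. Tracking one scalar singular value through the five coefficient triples shows how both tiny and near-1 values are pushed toward 1 (a pure-Python sketch of the math, not the tensor code above):

```python
COEFFS = [(4.0848, -6.8946, 2.9270),
          (3.9505, -6.3029, 2.6377),
          (3.7418, -5.5913, 2.3037),
          (2.8769, -3.1427, 1.2046),
          (2.8366, -3.0525, 1.2012)]

def ns_singular_value(s):
    """Apply the quintic update s -> a*s + b*s**3 + c*s**5 five times."""
    for a, b, c in COEFFS:
        s = a * s + b * s**3 + c * s**5
    return s

# Per the docstring, the result does not converge exactly to 1 but lands
# roughly in (0.5, 1.5), which is good enough for the optimizer.
for s0 in (0.05, 0.9):
    assert 0.5 < ns_singular_value(s0) < 1.5
```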
-
-@dataclass
-class _muon_state:
-    # TODO: use Optional
-    worker_rank: int
-    process_group: ProcessGroup
-    shard_mesh: DeviceMesh
-    shard_placements: tuple[Placement, ...]
-    name: str
-    qk_clip_state: torch.Tensor | None = None
-    gathered_grad: torch.Tensor | None = None
-    scattered_u: DTensor | None = None
-    computed_u: torch.Tensor | None = None
-    gather_event: torch.cuda.Event | None = None
-    compute_event: torch.cuda.Event | None = None
-    scatter_event: torch.cuda.Event | None = None
-
-
-def numel_for_rank(
-    param: DTensor,
-    local_rank: int,
-    state: _muon_state,
-) -> int:
-    slices = get_slices_of_dtensor(
-        param,
-        local_rank,
-        state.shard_mesh,
-        state.shard_placements,
-    )
-
-    numel = 1
-    for s, dim in zip(slices, param.shape):
-        start, stop, step = s.indices(dim)
-        length = max(0, (stop - start + (step - 1)) // step)
-        numel *= length
-
-    return numel
-
-
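The slice-to-element-count arithmetic in the removed `numel_for_rank` stands alone; a stdlib-only sketch (the helper name is illustrative):

```python
def numel_from_slices(slices, shape):
    """Element count selected by per-dimension slices of `shape`."""
    numel = 1
    for s, dim in zip(slices, shape):
        # slice.indices(dim) resolves None/negative bounds against dim
        start, stop, step = s.indices(dim)
        numel *= max(0, (stop - start + (step - 1)) // step)
    return numel

# a rank owning rows 4..8 of a (16, 32) weight, all columns
print(numel_from_slices((slice(4, 8), slice(None)), (16, 32)))  # 128
```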
-@torch.no_grad()
-def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
-    """
-    Pre-allocate the gathered_grad buffer on compute_stream
-    before launching the all2all gather.
-    """
-    with torch.cuda.stream(compute_stream):
-        for p in params:
-            state = param_to_state[id(p)]
-            if rank == state.worker_rank:
-                state.gathered_grad = torch.empty(p.shape,
-                                                  dtype=COMM_DTYPE,
-                                                  device="cuda")
-            else:
-                state.gathered_grad = None
-
-    alloc_event = torch.cuda.Event()
-    alloc_event.record(compute_stream)
-    return alloc_event
-
-
-@torch.no_grad()
-def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
-                    alloc_event):
-    """
-    All2all-gathers shards so each owner rank reconstructs its full gradient.
-    """
-    with torch.cuda.stream(comm_stream):
-        process_group = param_to_state[id(params[0])].process_group
-        num_ranks = dist.get_world_size(group=process_group)
-
-        # Construct sending buffers
-        per_dst = [[] for _ in range(num_ranks)]
-        send_counts = [0] * num_ranks
-
-        for p in params:
-            state = param_to_state[id(p)]
-            dst = state.worker_rank
-            assert dst < num_ranks
-            shard_elems = numel_for_rank(p, rank, state)
-            g = p.grad
-            g = g.to_local().to(COMM_DTYPE).contiguous()
-            assert g.numel() == shard_elems
-            per_dst[dst].append(g.view(-1))
-            send_counts[dst] += shard_elems
-
-        assert any(
-            len(v) > 0 for v in per_dst
-        ), "At least one destination rank must receive a sharded tensor"
-        # list[list[Tensor]] -> list[Tensor]
-        per_dst = [t for dst in per_dst for t in dst]
-
-        send_buf = torch.cat(per_dst, dim=0)
-
-        owned_params = [
-            p for p in params if param_to_state[id(p)].worker_rank == rank
-        ]
-
-        # Compute receive sizes and allocate receiving buffers
-        recv_counts = [0] * num_ranks
-
-        for src in range(num_ranks):
-            total = 0
-            for p in owned_params:
-                state = param_to_state[id(p)]
-                assert state.worker_rank == rank
-                total += numel_for_rank(p, src, state)
-            recv_counts[src] = total
-
-        recv_total = sum(recv_counts)
-        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-        # All2All
-        logger.debug(f"send_buf size: {send_buf.numel()}, "
-                     f"recv_buf size: {recv_buf.numel()}, "
-                     f"recv_counts: {recv_counts}, "
-                     f"send_counts: {send_counts}, "
-                     f"process_group: {str(process_group)}")
-        dist.all_to_all_single(
-            recv_buf,
-            send_buf,
-            output_split_sizes=recv_counts,
-            input_split_sizes=send_counts,
-            group=process_group,
-        )
-
-        # Reconstruct the gathered grad from the received buffer
-        #
-        # recv_buf (num ranks = 3)
-        #
-        #   From rank 0         From rank 1         From rank 2
-        # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
-        #
-        # Outer loop:
-        #   rank 0 -> rank 1 -> rank 2
-        #
-        # Inner loop:
-        #   p1_n -> p2_n -> p3_n
-
-        comm_stream.wait_event(alloc_event)
-
-        off = 0
-        for src in range(num_ranks):
-            if recv_counts[src] == 0:
-                continue
-
-            block = recv_counts[src]
-            inner_off = 0
-            for p in owned_params:
-                state = param_to_state[id(p)]
-                assert state.worker_rank == rank
-
-                # get the slice of the full dtensor corresponding to rank src.
-                slices = get_slices_of_dtensor(state.gathered_grad, src,
-                                               state.shard_mesh,
-                                               state.shard_placements)
-
-                dst = state.gathered_grad[slices]
-                assert dst._base is state.gathered_grad
-
-                n = dst.numel()
-                assert n > 0
-
-                sg = recv_buf.narrow(0, off + inner_off, n)
-                sg = sg.reshape_as(dst)
-                dst.copy_(sg)
-
-                inner_off += n
-            off += block
-
-        for p in params:
-            state = param_to_state[id(p)]
-            if state.worker_rank == rank:
-                state.gather_event = torch.cuda.Event()
-                state.gather_event.record(comm_stream)
-            else:
-                state.gathered_grad = None
-                state.gather_event = None
-            if none_grad:
-                p.grad = None
-
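The two nested offsets that walk `recv_buf` in the comment above can be sketched with plain lists (`split_recv_buf` is an illustrative name, not part of the module):

```python
def split_recv_buf(recv_buf, recv_counts, shard_numels):
    """shard_numels[src] lists per-parameter element counts received
    from rank src; recv_counts[src] is their sum (the outer block)."""
    out, off = [], 0
    for src, block in enumerate(recv_counts):
        shards, inner_off = [], 0
        for n in shard_numels[src]:
            shards.append(recv_buf[off + inner_off:off + inner_off + n])
            inner_off += n
        assert inner_off == block
        out.append(shards)
        off += block
    return out

buf = list(range(9))  # 3 ranks, 3 elements from each
print(split_recv_buf(buf, [3, 3, 3], [[1, 2], [3], [1, 2]]))
# [[[0], [1, 2]], [[3, 4, 5]], [[6], [7, 8]]]
```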
-@torch.no_grad()
-def _compute_u(p, state, steps, rank, compute_stream):
-    """
-    On worker_rank, compute the orthogonalized update using the
-    Newton-Schulz iteration.
-    """
-    with torch.cuda.stream(compute_stream):
-        if rank == state.worker_rank:
-            if state.gather_event is None:
-                raise RuntimeError("Gather event must be set before compute.")
-            compute_stream.wait_event(state.gather_event)
-            u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
-            state.gathered_grad = None
-            state.computed_u = u
-            state.compute_event = torch.cuda.Event()
-            state.compute_event.record()
-        else:
-            state.computed_u = None
-            state.compute_event = None
-
-
-@torch.no_grad()
-def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
-    """
-    Pre-allocate the scattered_u buffer on compute_stream
-    before launching the all2all scatter.
-    """
-    with torch.cuda.stream(compute_stream):
-        for p in params:
-            state = param_to_state[id(p)]
-            state.scattered_u = torch.empty_like(p.to_local(),
-                                                 dtype=COMM_DTYPE)
-
-    alloc_event = torch.cuda.Event()
-    alloc_event.record(compute_stream)
-    return alloc_event
-
-def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
-    """
-    All2all-scatters the computed updates back to all ranks.
-    """
-    with torch.cuda.stream(comm_stream):
-        process_group = param_to_state[id(params[0])].process_group
-        num_ranks = dist.get_world_size(group=process_group)
-        owned_params = [
-            p for p in params if param_to_state[id(p)].worker_rank == rank
-        ]
-
-        # Construct sending buffer
-        per_dst = [[] for _ in range(num_ranks)]
-        send_counts = [0] * num_ranks
-
-        if owned_params:
-            for p in owned_params:
-                state = param_to_state[id(p)]
-                if state.compute_event is None:
-                    raise RuntimeError(
-                        "Compute event must be set before scatter.")
-                comm_stream.wait_event(state.compute_event)
-                state.gathered_grad = None
-
-                assert state.computed_u is not None
-
-                u_full = state.computed_u.to(COMM_DTYPE).contiguous()
-
-                offset = 0
-                for dst in range(num_ranks):
-                    # get the slice of the full tensor corresponding to
-                    # rank dst.
-                    slices = get_slices_of_dtensor(u_full, dst,
                                                    state.shard_mesh,
                                                    state.shard_placements)
-                    su = u_full[slices].flatten()
-
-                    n = su.numel()
-                    assert n > 0
-
-                    per_dst[dst].append(su)
-                    send_counts[dst] += n
-                    offset += n
-
-                assert offset == u_full.numel()
-
-        lengths = [len(v) for v in per_dst]
-        if all(l > 0 for l in lengths):
-            assert all(
-                l == lengths[0] for l in lengths
-            ), "All destination ranks must have the same number of sharded tensors"
-            # list[list[Tensor]] -> list[Tensor]
-            per_dst = [t for dst in per_dst for t in dst]
-            send_buf = torch.cat(per_dst, dim=0)
-        else:
-            # all_to_all requires participation from all ranks
-            # Even non-owner ranks must join the collective call
-            send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
-
-        # Compute receive sizes and allocate receiving buffers
-        recv_counts = [0] * num_ranks
-
-        for src in range(num_ranks):
-            total = 0
-            for p in params:
-                state = param_to_state[id(p)]
-                if state.worker_rank != src:
-                    continue
-                total += numel_for_rank(p, rank, state)
-            recv_counts[src] = total
-
-        recv_total = sum(recv_counts)
-        assert recv_total > 0
-        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-        # All2All
-        dist.all_to_all_single(
-            recv_buf,
-            send_buf,
-            output_split_sizes=recv_counts,
-            input_split_sizes=send_counts,
-            group=process_group,
-        )
-
-        # Copy to the pre-allocated scattered_u buffer from the received
-        # buffer
-        #
-        # recv_buf (num ranks = 3, local_rank = 0)
-        #
-        #   From rank 0         From rank 1   From rank 2
-        # | p1_0, p2_0, p3_0 | p4_0        | p5_0, p6_0 |
-        #
-        # Outer loop:
-        #   rank 0 -> rank 1 -> rank 2
-        #
-        # Inner loop:
-        #   src(0): p1_0 -> p2_0 -> p3_0
-        #   src(1): p4_0
-        #   src(2): p5_0 -> p6_0
-
-        comm_stream.wait_event(alloc_event)
-
-        off = 0
-        for src in range(num_ranks):
-            block = recv_counts[src]
-            if block == 0:
-                continue
-
-            inner_off = 0
-            for p in params:
-                state = param_to_state[id(p)]
-                if state.worker_rank != src:
-                    continue
-                n = numel_for_rank(p, rank, state)
-                assert n > 0
-
-                flat_local = recv_buf.narrow(0, off + inner_off,
-                                             n).view_as(p.to_local())
-                state.scattered_u.copy_(flat_local)
-
-                state.scatter_event = torch.cuda.Event()
-                state.scatter_event.record(comm_stream)
-                inner_off += n
-
-            assert inner_off == block
-            off += block
 
 
-def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
-                  compute_stream):
-    """
-    Update sharded parameter p with the scattered_u.
-    Only worker_rank frees computed_u.
-    """
-    with torch.cuda.stream(compute_stream):
-        if state.scatter_event is None:
-            raise RuntimeError("Scatter event must be set before update")
-        compute_stream.wait_event(state.scatter_event)
-        u_dtensor = DTensor.from_local(
-            state.scattered_u,
-            placements=p.placements,
-            device_mesh=p.device_mesh,
-        )
-
-        state.scattered_u = u_dtensor
-
-        if rank == state.worker_rank:
-            # Free computed_u
-            state.computed_u = None
-
-        Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
-        state.scattered_u = None
-        u_dtensor = None
-
-        scales_full = Muon._compute_scales(
-            p,
-            state.qk_clip_state) if state.qk_clip_state is not None else None
-        if scales_full is not None:
-            # Have to slice scales_full along dim 0
-            weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
-                                                  state.shard_placements)
-            ratio = p.shape[0] // scales_full.shape[0]
-            scales_slice = slice(
-                None if weight_slices[0].start is None else
-                weight_slices[0].start // ratio,
-                None if weight_slices[0].stop is None else
-                weight_slices[0].stop // ratio,
-                None,
-            )
-
-            scales_local = scales_full[scales_slice]
-            scales_local = DTensor.from_local(
-                scales_local,
-                placements=p.placements,
-                device_mesh=p.device_mesh,
-            )
-            Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
-
- def default_is_muon(name, x):
463
- skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
464
- return x.ndim >= 2 and not any(key in name for key in skip_keys)
465
-
466
-
467
- def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
468
- muon_params, muon_names = [], []
469
- non_muon_params = []
470
-
471
- for n, p in model.named_parameters():
472
- if not p.requires_grad:
473
  continue
474
- if is_muon_func(n, p):
475
- muon_params.append(p)
476
- muon_names.append(n)
477
- else:
478
- non_muon_params.append(p)
479
-
480
- return [
481
- {
482
- "params": muon_params,
483
- "names": muon_names,
484
- "use_muon": True,
485
- },
486
- {
487
- "params": non_muon_params,
488
- "use_muon": False,
489
- },
490
- ]
491
-
492
-
493
- def parse_qk_layer(name: str) -> tuple[str | None, int]:
494
- """
495
- Parse a parameter name to check if it is a query/key projection layer
496
- ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
497
-
498
- Returns:
499
- (kind, layer_idx) or (None, -1) if not matched.
500
-
501
- Example:
502
- 'model.3.attn.wq.weight' -> ('wq', 3)
503
- 'model.5.attn.wk.weight' -> ('wk', 5)
504
- 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
505
- 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
506
- 'model.4.attn.v_proj.weight' -> (None, -1)
507
- """
508
- parts = name.split('.')
509
- if len(parts) < 3:
510
- return None, -1
511
-
512
- kind = parts[-2]
513
-
514
- layer_idx = -1
515
- for part in reversed(parts):
516
- if part.isdigit():
517
- layer_idx = int(part)
518
- break
519
 
520
- if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
521
- return kind, layer_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
 
523
- return None, -1
 
524
 
 
525
 
526
- @dataclass
527
- class QKClipInfo:
528
- """Per-parameter dynamic info computed from config + runtime logits."""
529
- kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
- indices: list[int] # which heads to consider for clipping
531
- head_dim: int # from config
532
- threshold: float # from config
533
- logit: torch.Tensor | None
534
 
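The removed `parse_qk_layer` is self-contained string logic and can be exercised directly; this sketch copies that logic, stdlib only:

```python
def parse_qk_layer(name):
    """Return (kind, layer_idx) for q/k projection weights, else (None, -1)."""
    parts = name.split('.')
    if len(parts) < 3:
        return None, -1
    kind = parts[-2]                 # module name right before 'weight'
    layer_idx = -1
    for part in reversed(parts):     # last all-digit component is the layer
        if part.isdigit():
            layer_idx = int(part)
            break
    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
        return kind, layer_idx
    return None, -1

print(parse_qk_layer('model.3.attn.wq.weight'))      # ('wq', 3)
print(parse_qk_layer('model.4.attn.v_proj.weight'))  # (None, -1)
```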
535
 
536
 class Muon(torch.optim.Optimizer):
@@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer):
         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
         ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
         weight_decay: The weight decay for Muon and AdamW.
-            Parameters that are {0, 1}-D or detected as the embed or lm_head will be optimized by AdamW as well.
         adamw_lr: The learning rate for the internal AdamW.
         adamw_betas: The betas for the internal AdamW.
         adamw_eps: The epsilon for the internal AdamW.
@@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer):
             - "q_indices" (list[int]): Indices of query heads to consider.
             - "k_indices" (list[int]): Indices of key heads to consider.
             - "head_dim" (int): Dimensionality of each attention head.
-            - "threshold" (float): Threshold value; heads whose QK logits exceed
              this value will be scaled down.
            Default is:
            {
@@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer):
         use_distributed_muon: Use distributed muon by Liu et al. (2024).
             For testing purposes only.
         small_param_numel_threshold: Threshold for classifying parameters as
             small and falling back to distributed Muon.
     """
 
     def __init__(self,
@@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer):
                 adamw_eps=1e-8,
                 none_grad=True,
                 debug=False,
-                clip_config={
-                    "q_indices": [],
-                    "k_indices": [],
-                    "head_dim": 128,
-                    "threshold": 100
-                },
                 warmup_step=5,
                 chunk_size=-1,
                 use_distributed_muon=False,
-                small_param_numel_threshold=65536):
         defaults = dict(
             lr=lr,
             weight_decay=weight_decay,
@@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer):
 
         super().__init__(params, defaults)
 
-        self.rank = None
-
-        self.comm_stream = torch.cuda.Stream()
-        self.compute_stream = torch.cuda.Stream()
         self.debug = debug
-        self.clip_config = clip_config
         self.warmup_step = warmup_step
         self.chunk_size = chunk_size
         self.use_distributed_muon = use_distributed_muon
         self.small_param_numel_threshold = small_param_numel_threshold
 
     def _calc_flops(self, G, steps):
         assert len(G.shape) == 2
@@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer):
 
         return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
 
-    def adjust_lr_for_muon(self, lr, param_shape):
-        A, B = param_shape[:2]
-        # We adjust the learning rate and weight decay based on the size of
-        # the parameter matrix as described in the paper
-        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
-        adjusted_lr = lr * adjusted_ratio
-        return adjusted_lr
-
-    def set_rank_once(self, rank):
-        if self.rank is None:
-            self.rank = rank
-        else:
-            assert self.rank == rank
-
     def get_shard_mesh(self, p):
         """
         Get the shard mesh for a parameter p on the given rank.
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer):
         shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
             p.placements, p.device_mesh)
 
-        # set rank with the local rank in the shard process group
-        self.set_rank_once(dist.get_rank(group=shard_pg))
-
         return shard_mesh, shard_pg, shard_placements
 
     def init_state_and_assign_params(self, names, params, group, qk_logits):
694
  total_flops += flops
695
 
696
  if self.debug:
697
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
698
- flush=True)
699
 
700
  paired = list(zip(names, params))
701
 
@@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer):
 
             worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
             round_robin = (round_robin + 1) % len(shard_mesh_flattened)
-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
 
             param_to_state[id(p)] = _muon_state(
                 worker_rank=worker_rank,
                 process_group=shard_pg,
-                shard_mesh=shard_mesh,
-                shard_placements=shard_placements,
                 name=n,
                 qk_clip_state=qk_clip_state,
             )
 
         return param_to_state, ordered_params
 
-    def base(self, names, params, group, lr, weight_decay, momentum,
-             qk_logits):
-        # generate weight updates in distributed fashion
         for n, p in zip(names, params):
             g = p.grad
             if g is None:
                 continue
-            if g.ndim > 2:
-                g = g.view(g.size(0), -1)
-            assert g is not None
-
-            g = self._update_g(p, g, group, momentum)
 
             u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
                                              steps=group["ns_steps"])
 
-            adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-            Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
 
-            scales_full = self._compute_scales(
                 p, qk_clip_state) if qk_clip_state is not None else None
             if scales_full is not None:
-                Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
 
     def distributed_muon(
         self,
@@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer):
         group: dict[str, Any],
         lr: float,
         weight_decay: float,
-        momentum: float,
         qk_logits: list[torch.Tensor | DTensor] | None,
     ):
         """ Implementation of Distributed Muon by Liu et al. """
 
         for n, p in zip(names, params):
             g = p.grad
             if g is None:
                 continue
-            if g.ndim > 2:
-                g = g.view(g.size(0), -1)
-            assert g is not None
-
-            g = self._update_g(p, g, group, momentum)
 
             # Gather G
             if isinstance(p.data, DTensor):
@@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer):
             u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
                                                   steps=group["ns_steps"])
 
-            adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
-            Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
 
-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
 
-            scales_full = self._compute_scales(
                 p_full, qk_clip_state) if qk_clip_state is not None else None
 
             if scales_full is not None:
-                Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
 
             if isinstance(p.data, DTensor):
                 ndims = len(p.device_mesh.mesh.shape)
@@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer):
 
             p.copy_(p_sharded)
 
-    def _update_g(self, p, g, group, momentum):
-        # calc update
-        state = self.state[p]
-        buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
-        torch.add(g, buf, alpha=momentum, out=buf)
-        if group["nesterov"]:
-            g.add_(buf, alpha=momentum)
-            return g
-        return buf
-
-    @staticmethod
-    def _update_p(p, u, lr, adjusted_lr, weight_decay):
-        if isinstance(p, torch.nn.Parameter):
-            # apply weight decay
-            p.data.mul_(1 - lr * weight_decay)
-            # apply update
-            p.data.add_(u, alpha=-adjusted_lr)
-        else:
-            p.mul_(1 - lr * weight_decay)
-            p.add_(u, alpha=-adjusted_lr)
-
-    def get_qk_clip_info(self, n, qk_logits):
-        if self.clip_config is None:
-            return None
-
-        head_dim = self.clip_config.get('head_dim')
-        threshold = self.clip_config.get('threshold')
-        kind, layer_idx = parse_qk_layer(n)
-
-        logit, indices = None, []
-        if qk_logits is not None and kind is not None:
-            logit = qk_logits[layer_idx]
-            indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-            indices = self.clip_config.get(indices_key, []) or []
-
-        if isinstance(logit, DTensor):
-            # In TP settings, qk_logits may be a DTensor.
-            # We convert it to a full tensor here for simplicity.
-            logit = logit.full_tensor()
-
-        return QKClipInfo(
-            kind=kind,
-            indices=indices,
-            head_dim=head_dim,
-            threshold=threshold,
-            logit=logit,
-        )
-
-    @staticmethod
-    def _compute_scales(p, qk_clip_state):
-        kind = qk_clip_state.kind
-        indices = qk_clip_state.indices
-        head_dim = qk_clip_state.head_dim
-        threshold = qk_clip_state.threshold
-        logit = qk_clip_state.logit
-
-        H_global = p.shape[0] // head_dim
-        scales_full = torch.ones(H_global, device=p.data.device)
-        scaling = 0
-
-        for logit_idx, head_idx in enumerate(indices):
-            v_ele = float(logit[logit_idx])
-            if v_ele > threshold:
-                new_scale = math.sqrt(threshold / v_ele)
-                if new_scale < scales_full[head_idx]:
-                    scales_full[head_idx] = new_scale
-                    logger.info(
-                        f"[{kind}] Head {head_idx} exceeded threshold "
-                        f"(value={v_ele:.4f}, threshold={threshold:.4f}) "
-                        f"-> applying scale={new_scale:.4f}")
-                    scaling += 1
-
-        return scales_full if scaling > 0 else None
-
-    @staticmethod
-    def _qk_clip(p, scales, head_dim):
-        if isinstance(p, torch.nn.Parameter):
-            W = p.data.view(-1, head_dim, p.data.shape[1])
-        else:
-            W = p.view(-1, head_dim, p.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
-
-    def parallel(self, names, params, group, lr, weight_decay, momentum,
-                 qk_logits):
-        """
-        Perform a parallel optimization step using Muon.
-        """
-
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-            if g.ndim > 2:
-                g = g.view(g.size(0), -1)
-
-            # Update g in the local rank
922
- g = self._update_g(
923
- p,
924
- g,
925
- group,
926
- momentum=momentum,
927
- )
928
- p.grad = g
929
 
930
  param_to_state, ordered_params = self.init_state_and_assign_params(
931
  names, params, group, qk_logits)
932
 
933
- assert self.rank is not None
934
-
935
- def enqueue_all2all_gather(start_idx, chunk_size):
936
- target_params = ordered_params[start_idx:start_idx + chunk_size]
937
- if target_params:
938
- alloc_event = _alloc_gathered_grad(target_params,
939
- param_to_state, self.rank,
940
- self.compute_stream)
941
- _all2all_gather(target_params, param_to_state, self.rank,
942
- self.comm_stream, group["none_grad"],
943
- alloc_event)
944
-
945
- def enqueue_computes(start_idx, chunk_size):
946
- for p in ordered_params[start_idx:start_idx + chunk_size]:
947
- state = param_to_state[id(p)]
948
- _compute_u(p, state, group["ns_steps"], self.rank,
949
- self.compute_stream)
950
-
951
- def enqueue_all2all_scatter(start_idx, chunk_size):
952
- target_params = ordered_params[start_idx:start_idx + chunk_size]
953
- if target_params:
954
- alloc_event = _alloc_scattered_u(target_params, param_to_state,
955
- self.rank,
956
- self.compute_stream)
957
- _all2all_scatter(target_params, param_to_state, self.rank,
958
- self.comm_stream, alloc_event)
959
-
960
- def enqueue_update_param(start_idx, chunk_size):
961
- for p in ordered_params[start_idx:start_idx + chunk_size]:
962
- state = param_to_state[id(p)]
963
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
964
- _update_param(p, state, lr, adjusted_lr, weight_decay,
965
- self.rank, self.compute_stream)
966
 
967
  if self.chunk_size == -1:
968
  shard_ranks = dist.get_world_size(param_to_state[id(
969
- params[0])].process_group)
970
  chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
971
  elif self.chunk_size > 0:
972
  chunk_size = self.chunk_size
973
  else:
974
  raise ValueError("chunk_size must be -1 or a positive integer.")
975
 
976
- # Wait grad update
977
- self.comm_stream.wait_stream(torch.cuda.current_stream())
978
-
979
- warmup_step = self.warmup_step
980
- for i in range(0, warmup_step):
981
- enqueue_all2all_gather(i * chunk_size, chunk_size)
982
- enqueue_computes(i * chunk_size, chunk_size)
983
-
984
- for i in range(0, len(params) + chunk_size - 1, chunk_size):
985
- enqueue_all2all_scatter(i, chunk_size)
986
- enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
987
- enqueue_update_param(i, chunk_size)
988
- enqueue_computes(i + warmup_step * chunk_size, chunk_size)
989
-
990
- # Wait the last update_param to finish
991
- torch.cuda.current_stream().wait_stream(self.compute_stream)
992
-
993
- @staticmethod
994
- def _fused_adamw(
995
- params: list[torch.Tensor],
996
- grads: list[torch.Tensor],
997
- exp_avgs: list[torch.Tensor],
998
- exp_avg_sqs: list[torch.Tensor],
999
- max_exp_avg_sqs: list[torch.Tensor],
1000
- state_steps: list[torch.Tensor],
1001
- amsgrad: bool,
1002
- beta1: float,
1003
- beta2: float,
1004
- lr: float | torch.Tensor,
1005
- weight_decay: float,
1006
- eps: float,
1007
- maximize: bool,
1008
- ) -> None:
1009
- if not params:
1010
- return
1011
 
1012
- # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1013
- # treating it as a scalar.
1014
- lr_dict: DeviceDict | None = ({
1015
- lr.device: lr
1016
- } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1017
- None)
1018
- grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1019
- [
1020
- params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
1021
- state_steps
1022
- ] # type: ignore[list-item]
1023
- )
1024
- for (device, _), (
1025
- (
1026
- device_params_,
1027
- device_grads_,
1028
- device_exp_avgs_,
1029
- device_exp_avg_sqs_,
1030
- device_max_exp_avg_sqs,
1031
- device_state_steps_,
1032
- ),
1033
- _,
1034
- ) in grouped_tensors.items():
1035
- device_params = cast(list[torch.Tensor], device_params_)
1036
- device_grads = cast(list[torch.Tensor], device_grads_)
1037
- device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
1038
- device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
1039
- device_state_steps = cast(list[torch.Tensor], device_state_steps_)
1040
-
1041
- if lr_dict is not None and device not in lr_dict:
1042
- lr_dict[device] = lr.to(
1043
- device=device,
1044
- non_blocking=True) # type: ignore[union-attr]
1045
- lr = lr_dict[device]
1046
- torch._foreach_add_(device_state_steps, 1)
1047
- func = torch._fused_adamw_
1048
- func(
1049
- device_params,
1050
- device_grads,
1051
- device_exp_avgs,
1052
- device_exp_avg_sqs,
1053
- device_max_exp_avg_sqs, # type: ignore[arg-type]
1054
- device_state_steps,
1055
- amsgrad=amsgrad,
1056
- lr=lr, # type: ignore[arg-type]
1057
- beta1=beta1,
1058
- beta2=beta2,
1059
- weight_decay=weight_decay,
1060
- eps=eps,
1061
- maximize=maximize,
1062
- )
1063
 
1064
  def _step_muon(self, group, qk_logits=None):
1065
  params = group["params"]
@@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer):
1068
  momentum = group["momentum"]
1069
  names = group["names"]
1070
 
1071
  param_dtensors = []
1072
  name_dtensors = []
1073
 
@@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer):
1083
  group=group,
1084
  lr=lr,
1085
  weight_decay=weight_decay,
1086
- momentum=momentum,
1087
  qk_logits=qk_logits)
1088
  return
1089
 
@@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer):
1119
  # and run parallel Muon on each group.
1120
 
1121
  placement_to_params = defaultdict(lambda: ([], []))
1122
- # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1123
 
1124
  assert len(dtensors) == len(names)
1125
  for p, n in zip(dtensors, names):
@@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer):
1141
  group=group,
1142
  lr=lr,
1143
  weight_decay=weight_decay,
1144
- momentum=momentum,
1145
  qk_logits=qk_logits,
1146
  )
1147
 
@@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer):
1159
  group,
1160
  lr=lr,
1161
  weight_decay=weight_decay,
1162
- momentum=momentum,
1163
  qk_logits=qk_logits,
1164
  )
1165
 
@@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer):
1170
  group,
1171
  lr=lr,
1172
  weight_decay=weight_decay,
1173
- momentum=momentum,
1174
  qk_logits=qk_logits,
1175
  )
1176
 
1177
- def _step_adamw_params(self, params, group):
1178
- params_with_grads = []
1179
- grads = []
1180
- moment1 = []
1181
- moment2 = []
1182
- max_exp_avg_sqs = []
1183
- state_steps = []
1184
- lr = group["lr"]
1185
- beta1, beta2 = group["adamw_betas"]
1186
- eps = group["adamw_eps"]
1187
- weight_decay = group["weight_decay"]
1188
-
1189
- for p in params:
1190
- g = p.grad
1191
- if g is None:
1192
- continue
1193
- state = self.state[p]
1194
- params_with_grads.append(p)
1195
- grads.append(g)
1196
- if "step" not in state:
1197
- state["step"] = (torch.zeros((),
1198
- dtype=torch.float32,
1199
- device=p.device))
1200
- state["moment1"] = torch.zeros_like(g)
1201
- state["moment2"] = torch.zeros_like(g)
1202
- moment1.append(state["moment1"])
1203
- moment2.append(state["moment2"])
1204
- if not isinstance(state["step"], torch.Tensor):
1205
- step_tensor = torch.tensor(state["step"],
1206
- dtype=torch.float32,
1207
- device=p.device)
1208
- else:
1209
- step_tensor = state["step"]
1210
- state_steps.append(step_tensor)
1211
-
1212
- self._fused_adamw(
1213
- params_with_grads,
1214
- grads,
1215
- moment1,
1216
- moment2,
1217
- max_exp_avg_sqs,
1218
- state_steps,
1219
- amsgrad=False,
1220
- beta1=beta1,
1221
- beta2=beta2,
1222
- lr=lr,
1223
- weight_decay=weight_decay,
1224
- eps=eps,
1225
- maximize=False,
1226
- )
1227
-
1228
- def _step_adamw(self, group):
1229
- params = group["params"]
1230
-
1231
- # group params with it's type and placement
1232
- placement_to_params: dict[tuple[Placement | type,
1233
- DeviceMesh | None]] = defaultdict(list)
1234
- for p in params:
1235
- match p:
1236
- case DTensor():
1237
- placement_to_params[tuple([p.placements,
1238
- p.device_mesh])].append(p)
1239
- case torch.Tensor():
1240
- placement_to_params[tuple([torch.Tensor, None])].append(p)
1241
-
1242
- for params in placement_to_params.values():
1243
- self._step_adamw_params(params, group)
1244
-
1245
  @torch.no_grad
1246
  def step(self, closure=None, qk_logits=None):
1247
  """Perform a single optimization step.
@@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer):
1249
  Args:
1250
  closure (Callable, optional): A closure that reevaluates the model
1251
  and returns the loss.
1252
- qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
1253
- to 1D tensors of shape (num_heads,), representing the maximum
1254
- QK logits across all tokens, computed as
1255
  (1 / sqrt(head_dim)) * (Q @ K^T).
1256
  """
1257
  loss = None
@@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer):
1263
  if group["use_muon"]:
1264
  self._step_muon(group, qk_logits=qk_logits)
1265
  else:
1266
- self._step_adamw(group)
1267
 
1268
  return loss
 
1
  import logging
 
2
  import types
3
  from collections import defaultdict
4
+ from typing import Any
 
5
 
6
  import torch
7
  import torch.distributed as dist
8
+ from torch.distributed.tensor import DTensor, Replicate, Shard
9
+ from torch.profiler import record_function
10
+
11
+ from .adamw import step_adamw
12
+ from .async_utils import run_pipeline
13
+ from .core import (_muon_state, adjust_lr_for_muon,
14
+ get_default_muon_param_groups, update_g, update_p)
15
+ from .distributed.utils import (_is_shard, construct_shard_mesh,
16
+ get_slices_of_dtensor)
17
+ from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
18
+ _zeropower_via_newtonschulz5)
19
+ from .pipeline import muon_chunk_pipeline
20
+ from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
 
25
+ def _expand_expert_params(names, params, expert_keys):
26
+ """Expand expert params by splitting on dim 0 (expert dimension).
 
27
 
28
+ Params whose name matches any key in ``expert_keys`` are treated as
29
+ expert-parallel tensors. Their outermost dimension is the expert
30
+ dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D
31
+ ``nn.Parameter`` views so that in-place updates propagate back to
32
+ the original storage.
33
 
34
+ Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` —
35
+ if they are expert params, their key must be added to ``expert_keys``.
36
 
37
+ The grad must already be set on each expert param (e.g. after momentum).
38
 
39
+ For DTensor expert params, placements that shard on dim 0 (expert dim)
40
+ are consumed by the split. Non-dim-0 shard placements (e.g. TP) are
41
+ preserved: each 2D slice is wrapped as a DTensor on the corresponding
42
+ submesh so the parallel pipeline handles the TP communication.
 
43
  """
44
+ expanded_names = []
45
+ expanded_params = []
46
+
47
+ for n, p in zip(names, params):
48
+ is_expert = expert_keys and any(key in n for key in expert_keys)
49
+ is_dtensor = isinstance(p.data, DTensor)
50
+
51
+ if not is_expert:
52
+ assert p.data.ndim <= 2, (
53
+ f"Param {n} has ndim={p.data.ndim} but does not match "
54
+ f"expert_keys={expert_keys}. If this is an expert param, "
55
+ f"add its key to expert_keys.")
56
+ expanded_names.append(n)
57
+ expanded_params.append(p)
58
  continue
59
 
60
+ g = p.grad
61
+ assert g is not None, (
62
+ f"Expert param {n} must have grad set before expansion")
63
+
64
+ tp_mesh = None
65
+ tp_placements_2d = None
66
+
67
+ if is_dtensor:
68
+ local_data = p.to_local()
69
+ local_grad = g.to_local() if isinstance(g, DTensor) else g
70
+
71
+ # Find non-dim-0 shard placements (e.g. TP sharding).
72
+ # After splitting on dim 0, Shard(k) becomes Shard(k-1).
73
+ tp_dim_indices = []
74
+ tp_placements_2d = []
75
+ for i, pl in enumerate(p.placements):
76
+ if _is_shard(pl) and pl.dim != 0:
77
+ tp_dim_indices.append(i)
78
+ tp_placements_2d.append(Shard(pl.dim - 1))
79
+
80
+ if tp_dim_indices:
81
+ tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i]
82
+ for i in tp_dim_indices)
83
+ if len(tp_dim_names) == 1:
84
+ tp_mesh = p.device_mesh[tp_dim_names[0]]
85
+ else:
86
+ tp_mesh = p.device_mesh[tp_dim_names]
87
+ else:
88
+ local_data = p.data
89
+ local_grad = g
90
+
91
+ # Expand: split dim 0, reshape each slice to 2D.
92
+ num_local_experts = local_data.shape[0]
93
+ for i in range(num_local_experts):
94
+ slice_data = local_data[i]
95
+ slice_grad = local_grad[i]
96
+
97
+ if tp_mesh is not None:
98
+ # Wrap as DTensor on TP submesh so the pipeline handles
99
+ # TP communication (gather/scatter across TP ranks).
100
+ dt_data = DTensor.from_local(slice_data,
101
+ device_mesh=tp_mesh,
102
+ placements=tp_placements_2d)
103
+ dt_grad = DTensor.from_local(slice_grad,
104
+ device_mesh=tp_mesh,
105
+ placements=tp_placements_2d)
106
+ expert_param = torch.nn.Parameter(dt_data, requires_grad=False)
107
+ expert_param.grad = dt_grad
108
+ else:
109
+ expert_param = torch.nn.Parameter(slice_data,
110
+ requires_grad=False)
111
+ expert_param.grad = slice_grad
112
 
113
+ expanded_names.append(f"{n}[{i}]")
114
+ expanded_params.append(expert_param)
115
 
116
+ p.grad = None # allow expert grad storage to be freed after pipeline
117
 
118
+ return expanded_names, expanded_params
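The key property `_expand_expert_params` relies on is that indexing a tensor on dim 0 returns a view, so wrapping each 2D slice as an `nn.Parameter` lets in-place optimizer updates reach the original 3D storage. A minimal sketch of the non-DTensor path, with hypothetical shapes (E=2 experts, each a 4×3 matrix):

```python
import torch

# Hypothetical 3D expert weight: E=2 experts, each a (4, 3) matrix.
expert_weight = torch.nn.Parameter(torch.zeros(2, 4, 3))
expert_weight.grad = torch.ones(2, 4, 3)

# Split on dim 0 into per-expert 2D Parameter views. Indexing returns
# a view, so in-place updates propagate back to the 3D storage.
views = []
for i in range(expert_weight.shape[0]):
    p = torch.nn.Parameter(expert_weight.data[i], requires_grad=False)
    p.grad = expert_weight.grad[i]
    views.append(p)

# In-place update on expert 1 only (like update_p would do).
views[1].data.add_(views[1].grad, alpha=-0.1)
```

Expert 1's slice of the original tensor now holds the updated values, while expert 0 is untouched.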
119
 
120
 
121
  class Muon(torch.optim.Optimizer):
 
139
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
140
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
141
  weight_decay: The weight decay for Muon and AdamW.
142
+ Parameters that are 0- or 1-D, or detected as the embedding or lm_head, are optimized by AdamW instead.
143
  adamw_lr: The learning rate for the internal AdamW.
144
  adamw_betas: The betas for the internal AdamW.
145
  adamw_eps: The epsilon for the internal AdamW.
 
149
  - "q_indices" (list[int]): Indices of query heads to consider.
150
  - "k_indices" (list[int]): Indices of key heads to consider.
151
  - "head_dim" (int): Dimensionality of each attention head.
152
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
153
  this value will be scaled down.
154
  Default is:
155
  {
 
169
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
170
  For testing purpose only.
171
  small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
172
+ expert_keys: List of strings to identify expert-parallel parameters.
173
+ If any key appears in a parameter's name, its outermost
174
+ dimension is treated as the expert dimension and expanded
175
+ into per-expert 2D params for Muon. For example,
176
+ ``expert_keys=["experts"]`` matches any param whose name
177
+ contains "experts". 3D+ params not matched by any key
178
+ will raise an error.
179
  """
180
 
181
  def __init__(self,
 
189
  adamw_eps=1e-8,
190
  none_grad=True,
191
  debug=False,
192
+ clip_config=None,
193
  warmup_step=5,
194
  chunk_size=-1,
195
  use_distributed_muon=False,
196
+ small_param_numel_threshold=65536,
197
+ expert_keys=None):
198
  defaults = dict(
199
  lr=lr,
200
  weight_decay=weight_decay,
 
218
 
219
  super().__init__(params, defaults)
220
 
221
  self.debug = debug
222
+ self.clip_config = clip_config if clip_config is not None else {
223
+ "q_indices": [],
224
+ "k_indices": [],
225
+ "head_dim": 128,
226
+ "threshold": 100,
227
+ }
228
  self.warmup_step = warmup_step
229
  self.chunk_size = chunk_size
230
  self.use_distributed_muon = use_distributed_muon
231
  self.small_param_numel_threshold = small_param_numel_threshold
232
+ self.expert_keys = expert_keys
233
 
234
  def _calc_flops(self, G, steps):
235
  assert len(G.shape) == 2
 
239
 
240
  return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
241
 
242
  def get_shard_mesh(self, p):
243
  """
244
  Get the shard mesh for a parameter p on the given rank.
 
249
  shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
250
  p.placements, p.device_mesh)
251
 
 
252
  return shard_mesh, shard_pg, shard_placements
253
 
254
  def init_state_and_assign_params(self, names, params, group, qk_logits):
 
267
  total_flops += flops
268
 
269
  if self.debug:
270
+ logger.debug("Total TFLOPs for Muon: %.2f TFLOPs",
271
+ total_flops / 1e12)
272
 
273
  paired = list(zip(names, params))
274
 
 
297
 
298
  worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
299
  round_robin = (round_robin + 1) % len(shard_mesh_flattened)
300
+ qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
301
+
302
+ # Precompute per-rank indices and numels for all-to-all.
303
+ rank_indices: dict[int, tuple] = {}
304
+ rank_numels: dict[int, int] = {}
305
+ for r in range(num_ranks):
306
+ indices = get_slices_of_dtensor(p, r, shard_mesh,
307
+ shard_placements)
308
+ rank_indices[r] = indices
309
+ numel = 1
310
+ for idx, dim_size in zip(indices, p.shape):
311
+ if isinstance(idx, slice):
312
+ start, stop, step = idx.indices(dim_size)
313
+ numel *= max(0, (stop - start + (step - 1)) // step)
314
+ else:
315
+ numel *= len(idx)
316
+ rank_numels[r] = numel
317
 
318
  param_to_state[id(p)] = _muon_state(
319
  worker_rank=worker_rank,
320
  process_group=shard_pg,
321
+ rank_indices=rank_indices,
322
+ rank_numels=rank_numels,
323
  name=n,
324
  qk_clip_state=qk_clip_state,
325
  )
326
 
327
  return param_to_state, ordered_params
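The per-rank `numel` bookkeeping above counts how many elements a `slice` selects along one dimension via `slice.indices` and a ceiling division. As a standalone sketch of that formula:

```python
def slice_numel(idx: slice, dim_size: int) -> int:
    # Same ceiling-division formula used for rank_numels above:
    # number of elements the slice selects along a dim of dim_size.
    start, stop, step = idx.indices(dim_size)
    return max(0, (stop - start + (step - 1)) // step)

full = slice_numel(slice(None), 10)         # whole dimension
strided = slice_numel(slice(0, 10, 3), 10)  # selects 0, 3, 6, 9
empty = slice_numel(slice(4, 4), 10)        # empty range
```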
328
 
329
+ def base(self, names, params, group, lr, weight_decay, qk_logits):
330
+ # Momentum is already applied by _step_muon before this method.
 
331
  for n, p in zip(names, params):
332
  g = p.grad
333
  if g is None:
334
  continue
335
 
336
  u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
337
  steps=group["ns_steps"])
338
 
339
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
340
+ update_p(p, u, lr, adjusted_lr, weight_decay)
341
 
342
+ qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
343
 
344
+ scales_full = compute_scales(
345
  p, qk_clip_state) if qk_clip_state is not None else None
346
  if scales_full is not None:
347
+ qk_clip(p, scales_full, qk_clip_state.head_dim)
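As the helpers (formerly `_compute_scales`/`_qk_clip`, now in `qk_clip.py`) show, each head whose max QK logit exceeds the threshold gets a multiplier `sqrt(threshold / logit)`, applied to that head's rows of the weight. A self-contained sketch with hypothetical shapes (3 heads, head_dim=4):

```python
import math
import torch

def compute_head_scales(logits, threshold):
    # One scale per head: sqrt(threshold / logit) when the head's max
    # QK logit exceeds the threshold, 1.0 otherwise (no change).
    scales = torch.ones_like(logits)
    mask = logits > threshold
    scales[mask] = torch.sqrt(threshold / logits[mask])
    return scales

def qk_clip_weight(W, scales, head_dim):
    # W has shape (num_heads * head_dim, in_features); scale each
    # head's block of rows in place.
    W3 = W.view(-1, head_dim, W.shape[1])
    W3.mul_(scales.view(-1, 1, 1))
    return W

logits = torch.tensor([50.0, 200.0, 100.0])  # per-head max QK logits
scales = compute_head_scales(logits, threshold=100.0)
W = torch.ones(3 * 4, 8)                     # 3 heads, head_dim=4
qk_clip_weight(W, scales, head_dim=4)
```

Only head 1 exceeds the threshold, so only its four rows are scaled down by `sqrt(100 / 200)`.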
348
 
349
  def distributed_muon(
350
  self,
 
353
  group: dict[str, Any],
354
  lr: float,
355
  weight_decay: float,
 
356
  qk_logits: list[torch.Tensor | DTensor] | None,
357
  ):
358
  """ Implementation of Distributed Muon by Liu et al. """
359
 
360
+ # Momentum is already applied by _step_muon before this method.
361
  for n, p in zip(names, params):
362
  g = p.grad
363
  if g is None:
364
  continue
365
 
366
  # Gather G
367
  if isinstance(p.data, DTensor):
 
374
  u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
375
  steps=group["ns_steps"])
376
 
377
+ adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
378
+ update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
379
 
380
+ qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
381
 
382
+ scales_full = compute_scales(
383
  p_full, qk_clip_state) if qk_clip_state is not None else None
384
 
385
  if scales_full is not None:
386
+ qk_clip(p_full, scales_full, qk_clip_state.head_dim)
387
 
388
  if isinstance(p.data, DTensor):
389
  ndims = len(p.device_mesh.mesh.shape)
 
400
 
401
  p.copy_(p_sharded)
402
 
403
+ def parallel(self, names, params, group, lr, weight_decay, qk_logits):
404
  """
405
  Perform a parallel optimization step using Muon.
 
406
 
407
+ Parameters are chunked and each chunk is processed by a
408
+ :func:`muon_chunk_pipeline` generator. :func:`run_pipeline`
409
+ interleaves multiple chunks so that communication and computation
410
+ overlap across chunks (the same overlap previously achieved by the
411
+ warmup + main-loop index scheduling).
412
+ """
413
 
414
+ # Momentum is already applied by _step_muon before this method.
415
 
416
  param_to_state, ordered_params = self.init_state_and_assign_params(
417
  names, params, group, qk_logits)
418
 
419
+ # Compute local rank for this group's shard process group.
420
+ shard_pg = param_to_state[id(ordered_params[0])].process_group
421
+ rank = dist.get_rank(group=shard_pg)
422
 
423
  if self.chunk_size == -1:
424
  shard_ranks = dist.get_world_size(param_to_state[id(
425
+ ordered_params[0])].process_group)
426
  chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
427
  elif self.chunk_size > 0:
428
  chunk_size = self.chunk_size
429
  else:
430
  raise ValueError("chunk_size must be -1 or a positive integer.")
431
 
432
+ def pipelines():
433
+ for start in range(0, len(ordered_params), chunk_size):
434
+ chunk = ordered_params[start:start + chunk_size]
435
+ if chunk:
436
+ yield muon_chunk_pipeline(
437
+ params=chunk,
438
+ param_to_state=param_to_state,
439
+ rank=rank,
440
+ ns_steps=group["ns_steps"],
441
+ lr=lr,
442
+ weight_decay=weight_decay,
443
+ none_grad=group["none_grad"],
444
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
+ with record_function("muon::barrier"):
447
+ dist.barrier()
448
+ with record_function("muon::pipeline"):
449
+ run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
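The interleaving that `run_pipeline` performs can be sketched with plain generators (an illustrative scheduler, not the actual `run_pipeline` from `async_utils`): up to `max_concurrent` chunk pipelines are kept in flight and advanced round-robin, one stage per turn, so one chunk's compute stage overlaps another chunk's communication stage.

```python
from collections import deque

def run_pipeline_sketch(pipelines, max_concurrent):
    # Round-robin scheduler: keep up to max_concurrent generators
    # active and advance each one stage per turn.
    pipelines = iter(pipelines)
    active = deque()
    log = []
    while True:
        while len(active) < max_concurrent:
            gen = next(pipelines, None)
            if gen is None:
                break
            active.append(gen)
        if not active:
            break
        gen = active.popleft()
        try:
            log.append(next(gen))  # run one stage of this chunk
            active.append(gen)
        except StopIteration:
            pass  # this chunk's pipeline is finished
    return log

def chunk_pipeline(idx):
    yield f"gather:{idx}"
    yield f"compute:{idx}"
    yield f"scatter:{idx}"

order = run_pipeline_sketch((chunk_pipeline(i) for i in range(3)),
                            max_concurrent=2)
```

With `max_concurrent=2`, chunk 1's gather is issued before chunk 0's compute runs, which is the overlap the warmup + main-loop index scheduling previously produced by hand.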
450
 
451
  def _step_muon(self, group, qk_logits=None):
452
  params = group["params"]
 
455
  momentum = group["momentum"]
456
  names = group["names"]
457
 
458
+ # Apply momentum to all params before routing/expansion.
459
+ with record_function("muon::momentum"):
460
+ for n, p in zip(names, params):
461
+ g = p.grad
462
+ if g is None:
463
+ continue
464
+ g = update_g(self.state, p, g, group, momentum)
465
+ p.grad = g
466
+
467
+ # Expand expert params by splitting on dim 0.
468
+ names, params = _expand_expert_params(names, params, self.expert_keys)
469
+
470
  param_dtensors = []
471
  name_dtensors = []
472
 
 
482
  group=group,
483
  lr=lr,
484
  weight_decay=weight_decay,
 
485
  qk_logits=qk_logits)
486
  return
487
 
 
517
  # and run parallel Muon on each group.
518
 
519
  placement_to_params = defaultdict(lambda: ([], []))
 
520
 
521
  assert len(dtensors) == len(names)
522
  for p, n in zip(dtensors, names):
 
538
  group=group,
539
  lr=lr,
540
  weight_decay=weight_decay,
 
541
  qk_logits=qk_logits,
542
  )
543
 
 
555
  group,
556
  lr=lr,
557
  weight_decay=weight_decay,
 
558
  qk_logits=qk_logits,
559
  )
560
 
 
565
  group,
566
  lr=lr,
567
  weight_decay=weight_decay,
 
568
  qk_logits=qk_logits,
569
  )
570
 
571
  @torch.no_grad
572
  def step(self, closure=None, qk_logits=None):
573
  """Perform a single optimization step.
 
575
  Args:
576
  closure (Callable, optional): A closure that reevaluates the model
577
  and returns the loss.
578
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
579
+ to 1D tensors of shape (num_heads,), representing the maximum
580
+ QK logits across all tokens, computed as
581
  (1 / sqrt(head_dim)) * (Q @ K^T).
582
  """
583
  loss = None
 
589
  if group["use_muon"]:
590
  self._step_muon(group, qk_logits=qk_logits)
591
  else:
592
+ step_adamw(self.state, group)
593
 
594
  return loss
build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+
3
+ from .matmul_transpose_triton import matmul_transpose_assign
4
+
5
+ COMM_DTYPE = torch.bfloat16
6
+ DEFAULT_CHUNK_SIZE_RATIO = 4
7
+
8
+
9
+ # This code snippet is a modified version adapted from the following GitHub repositories:
10
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
11
+ # Muon's Newton–Schulz iteration causes high variance in singular values
12
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
13
+ @torch.no_grad()
14
+ # matmul_transpose_assign from: https://github.com/nil0x9/flash-muon
15
+ def _zeropower_via_newtonschulz5(G, steps):
16
+ """
17
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
18
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
19
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
20
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
21
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
22
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
23
+ performance at all relative to UV^T, where USV^T = G is the SVD.
24
+ """
25
+ assert len(G.shape) == 2
26
+ assert G.dtype == COMM_DTYPE
27
+ X = G # no manual typecast
28
+
29
+ if G.size(0) > G.size(1):
30
+ X = X.T
31
+ # Ensure spectral norm is at most 1
32
+ X = X / (X.norm() + 1e-7)
33
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
34
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
35
+ # Perform the NS iterations
36
+ for a, b, c in [
37
+ (4.0848, -6.8946, 2.9270),
38
+ (3.9505, -6.3029, 2.6377),
39
+ (3.7418, -5.5913, 2.3037),
40
+ (2.8769, -3.1427, 1.2046),
41
+ (2.8366, -3.0525, 1.2012),
42
+ ]:
43
+ matmul_transpose_assign(X, buf1)
44
+ matmul_transpose_assign(buf1, buf2)
45
+ buf1.mul_(b).add_(buf2, alpha=c)
46
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
47
+
48
+ if G.size(0) > G.size(1):
49
+ X = X.T
50
+ return X
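For reference, the same iteration can be written without the Triton `matmul_transpose_assign` kernel (a float32 verification sketch, not a drop-in replacement): since `A = X @ X.T` is symmetric, the buffered `buf2` equals `A @ A`, and the fused update is `X <- a*X + (b*A + c*A@A) @ X`.

```python
import torch

@torch.no_grad()
def newtonschulz5_reference(G):
    # Plain-matmul version of the quintic Newton-Schulz iteration.
    assert G.ndim == 2
    X = G.float()
    if G.size(0) > G.size(1):
        X = X.T
    X = X / (X.norm() + 1e-7)  # ensure spectral norm is at most 1
    for a, b, c in [
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ]:
        A = X @ X.T            # symmetric, so buf2 == A @ A
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X

torch.manual_seed(0)
G = torch.randn(16, 32)
U = newtonschulz5_reference(G)
svals = torch.linalg.svdvals(U)
```

As the docstring above notes, the result is not exactly `UV^T`: the singular values land near 1 but not on it, which is tolerable for the optimizer.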
build/torch210-cxx11-cu126-x86_64-linux/pipeline.py ADDED
@@ -0,0 +1,390 @@
1
+ import logging
2
+ from typing import Generator
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed.tensor import DTensor
7
+ from torch.profiler import record_function
8
+
9
+from .core import _muon_state, adjust_lr_for_muon, update_p
+from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5
+from .qk_clip import compute_scales
+
+logger = logging.getLogger(__name__)
+
+# ======================================================================
+# Stage helpers
+# ======================================================================
+
+
+def _launch_gather(
+    params: list[DTensor],
+    owned_params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    num_ranks: int,
+    process_group: dist.ProcessGroup,
+) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
+    """Allocate gather buffers, build send/recv, and launch async all-to-all.
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_gather``).
+        gathered_grads: ``{id(p): empty_tensor}`` for owned params,
+            ``None`` for non-owned.
+        recv_counts: Per-source-rank element counts.
+    """
+    # Allocate gathered-grad buffers
+    gathered_grads: dict[int, torch.Tensor | None] = {}
+    for p in params:
+        state = param_to_state[id(p)]
+        if rank == state.worker_rank:
+            gathered_grads[id(p)] = torch.empty(p.shape,
+                                                dtype=COMM_DTYPE,
+                                                device="cuda")
+        else:
+            gathered_grads[id(p)] = None
+
+    # Build send buffer
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    for p in params:
+        state = param_to_state[id(p)]
+        dst = state.worker_rank
+        assert dst < num_ranks
+        shard_elems = state.rank_numels[rank]
+        g = p.grad
+        g = g.to_local().to(COMM_DTYPE).contiguous()
+        assert g.numel() == shard_elems
+        per_dst[dst].append(g.view(-1))
+        send_counts[dst] += shard_elems
+
+    assert any(
+        len(v) > 0 for v in
+        per_dst), "At least one destination rank must receive a sharded tensor"
+    per_dst_flat = [t for dst in per_dst for t in dst]
+    send_buf = torch.cat(per_dst_flat, dim=0)
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+            total += state.rank_numels[src]
+        recv_counts[src] = total
+
+    recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    logger.debug(f"send_buf size: {send_buf.numel()}, "
+                 f"recv_buf size: {recv_buf.numel()}, "
+                 f"recv_counts: {recv_counts}, "
+                 f"send_counts: {send_counts}, "
+                 f"process_group: {str(process_group)}")
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, gathered_grads, recv_counts
+
+
+def _complete_gather(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+) -> None:
+    """Reconstruct gathered grads from the recv buffer (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        if recv_counts[src] == 0:
+            continue
+
+        block = recv_counts[src]
+        inner_off = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+
+            indices = state.rank_indices[src]
+
+            shard_view = gathered_grads[id(p)][indices]
+            n = shard_view.numel()
+            assert n > 0
+
+            sg = recv_buf.narrow(0, off + inner_off, n)
+            sg = sg.reshape(shard_view.shape)
+            gathered_grads[id(p)][indices] = sg
+
+            inner_off += n
+        assert inner_off == block
+        off += block
+
+
+def _compute_ns(
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    ns_steps: int,
+) -> dict[int, torch.Tensor | None]:
+    """Run Newton-Schulz orthogonalization on owned parameters.
+
+    Returns:
+        computed_us: ``{id(p): orthogonalized_update}`` for owned params.
+    """
+    computed_us: dict[int, torch.Tensor | None] = {}
+    for p in owned_params:
+        u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
+        gathered_grads[id(p)] = None  # free gathered grad
+        computed_us[id(p)] = u
+    return computed_us
+
+
+def _launch_scatter(
+    params: list[DTensor],
+    owned_params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    num_ranks: int,
+    process_group: dist.ProcessGroup,
+    computed_us: dict[int, torch.Tensor | None],
+) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
+    """Allocate scatter buffers, build send/recv, and launch async all-to-all.
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
+        scattered_us: ``{id(p): empty_local_tensor}`` for all params.
+        recv_counts: Per-source-rank element counts.
+    """
+    # Allocate scattered-u buffers
+    scattered_us: dict[int, torch.Tensor] = {}
+    for p in params:
+        scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
+
+    # Build send buffer (from computed_us on owner ranks)
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    if owned_params:
+        for p in owned_params:
+            state = param_to_state[id(p)]
+
+            assert computed_us[id(p)] is not None
+            u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
+
+            total_sent = 0
+            for dst_rank in range(num_ranks):
+                indices = state.rank_indices[dst_rank]
+                su = u_full[indices].flatten()
+
+                n = su.numel()
+                assert n > 0
+
+                per_dst[dst_rank].append(su)
+                send_counts[dst_rank] += n
+                total_sent += n
+
+            assert total_sent == u_full.numel()
+
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensors"
+        per_dst_flat = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst_flat, dim=0)
+    else:
+        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            total += state.rank_numels[rank]
+        recv_counts[src] = total
+
+    recv_total = sum(recv_counts)
+    assert recv_total > 0
+    recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, scattered_us, recv_counts
+
+
+def _complete_scatter(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+) -> None:
+    """Copy recv buffer into scattered_us (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        block = recv_counts[src]
+        if block == 0:
+            continue
+
+        inner_off = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            n = state.rank_numels[rank]
+            assert n > 0
+
+            flat_local = recv_buf.narrow(0, off + inner_off,
+                                         n).view_as(p.to_local())
+            scattered_us[id(p)].copy_(flat_local)
+
+            inner_off += n
+
+        assert inner_off == block
+        off += block
+
+
+def _update_params(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+    lr: float,
+    weight_decay: float,
+) -> None:
+    """Apply weight decay, Muon update, and optional QK clipping."""
+    for p in params:
+        state = param_to_state[id(p)]
+        u_dtensor = DTensor.from_local(
+            scattered_us[id(p)],
+            placements=p.placements,
+            device_mesh=p.device_mesh,
+        )
+
+        adjusted_lr = adjust_lr_for_muon(lr, p.shape)
+        update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
+
+        # QK clipping – applied directly on the local tensor to
+        # avoid DTensor sharding-propagation issues with _StridedShard.
+        scales_full = compute_scales(
+            p,
+            state.qk_clip_state) if state.qk_clip_state is not None else None
+        if scales_full is not None:
+            ratio = p.shape[0] // scales_full.shape[0]
+            idx0 = state.rank_indices[rank][0]
+            if isinstance(idx0, slice):
+                start = idx0.start or 0
+                idx0 = torch.arange(start,
+                                    idx0.stop,
+                                    device=scales_full.device)
+            row_scales = scales_full[idx0 // ratio]
+            p._local_tensor.mul_(row_scales.view(-1, 1))
+
+
+# ======================================================================
+# Main generator – thin orchestrator that wires stages together.
+# ======================================================================
+
+
+@torch.no_grad()
+def muon_chunk_pipeline(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    ns_steps: int,
+    lr: float,
+    weight_decay: float,
+    none_grad: bool,
+) -> Generator[None, None, None]:
+    """Process one chunk of parameters through the full Muon pipeline.
+
+    Stages: gather -> compute (Newton-Schulz) -> scatter -> update.
+
+    Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
+    that communication and computation overlap across chunks. Async
+    communication is launched via ``async_op=True`` and completed after
+    the yield with ``work.wait()``.
+
+    Overlap happens because :func:`run_pipeline` admits one new chunk
+    per iteration (staggered admission). While chunk *N* does NS
+    compute on the default CUDA stream, chunk *N+1*'s async all-to-all
+    runs concurrently on the NCCL stream — no separate ``comm_stream``
+    is required.
+
+    Yields exactly **2** times:
+
+    1. After launching async all-to-all gather.
+    2. After launching async all-to-all scatter.
+    """
+    process_group = param_to_state[id(params[0])].process_group
+    num_ranks = dist.get_world_size(group=process_group)
+    owned_params = [
+        p for p in params if param_to_state[id(p)].worker_rank == rank
+    ]
+
+    # Stages 1-2: launch async gather.
+    with record_function("muon::launch_gather"):
+        work, recv_buf, gathered_grads, recv_counts = _launch_gather(
+            params, owned_params, param_to_state, rank, num_ranks,
+            process_group)
+
+    if none_grad:
+        for p in params:
+            p.grad = None
+
+    yield  # --- YIELD 1: other chunks can launch their gather ---
+
+    with record_function("muon::wait_gather"):
+        work.wait()
+        _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
+                         param_to_state, rank)
+        del recv_buf
+
+    # Stage 3: Newton-Schulz orthogonalization.
+    with record_function("muon::newton_schulz"):
+        computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
+        gathered_grads.clear()
+
+    # Stages 4-5: launch async scatter.
+    with record_function("muon::launch_scatter"):
+        work, recv_buf, scattered_us, recv_counts = _launch_scatter(
+            params, owned_params, param_to_state, rank, num_ranks,
+            process_group, computed_us)
+        computed_us.clear()
+
+    yield  # --- YIELD 2: other chunks can launch their scatter ---
+
+    with record_function("muon::wait_scatter"):
+        work.wait()
+        _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
+                          scattered_us)
+        del recv_buf
+
+    # Stage 6: apply parameter updates.
+    with record_function("muon::update_params"):
+        _update_params(params, param_to_state, rank, scattered_us, lr,
+                       weight_decay)
+        scattered_us.clear()
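The correctness of the gather/scatter stages hinges on an invariant of `dist.all_to_all_single`: the element count rank *r* sends to rank *d* (`send_counts[d]` on *r*) must equal the count *d* expects from *r* (`recv_counts[r]` on *d*). A pure-Python sketch with hypothetical two-rank shard sizes shows how the bookkeeping above satisfies it:

```python
# Sketch of the send/recv count bookkeeping used by the gather stage.
# Hypothetical setup: 2 ranks; each parameter is owned by one worker rank
# and split into per-rank shards of known numel.
params = [
    {"id": "A", "worker_rank": 0, "rank_numels": {0: 6, 1: 6}},
    {"id": "B", "worker_rank": 1, "rank_numels": {0: 4, 1: 4}},
]
num_ranks = 2

def gather_send_counts(rank):
    # Each rank sends its local shard of every param to that param's owner.
    counts = [0] * num_ranks
    for p in params:
        counts[p["worker_rank"]] += p["rank_numels"][rank]
    return counts

def gather_recv_counts(rank):
    # The owner receives one shard per source rank for each param it owns.
    counts = [0] * num_ranks
    for src in range(num_ranks):
        for p in params:
            if p["worker_rank"] == rank:
                counts[src] += p["rank_numels"][src]
    return counts

send = [gather_send_counts(r) for r in range(num_ranks)]
recv = [gather_recv_counts(r) for r in range(num_ranks)]
# all_to_all_single invariant: what r sends to d is what d expects from r.
for r in range(num_ranks):
    for d in range(num_ranks):
        assert send[r][d] == recv[d][r]
print(send, recv)
```

The scatter stage uses the mirror image of the same invariant, with owners sending and every rank receiving its local shard.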
build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py ADDED
@@ -0,0 +1,129 @@
+import logging
+import math
+from dataclasses import dataclass
+
+import torch
+from torch.distributed.tensor import DTensor
+
+logger = logging.getLogger(__name__)
+
+
+def parse_qk_layer(name: str) -> tuple[str | None, int]:
+    """
+    Parse a parameter name to check if it is a query/key projection layer
+    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+
+    Returns:
+        (kind, layer_idx) or (None, -1) if not matched.
+
+    Example:
+        'model.3.attn.wq.weight' -> ('wq', 3)
+        'model.5.attn.wk.weight' -> ('wk', 5)
+        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
+        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.4.attn.v_proj.weight' -> (None, -1)
+    """
+    parts = name.split('.')
+    if len(parts) < 3:
+        return None, -1
+
+    kind = parts[-2]
+
+    layer_idx = -1
+    for part in reversed(parts):
+        if part.isdigit():
+            layer_idx = int(part)
+            break
+
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+        return kind, layer_idx
+
+    return None, -1
+
+
+@dataclass
+class QKClipInfo:
+    """Per-parameter dynamic info computed from config + runtime logits."""
+    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    indices: list[int]  # which heads to consider for clipping
+    head_dim: int  # from config
+    threshold: float  # from config
+    logit: torch.Tensor | None
+
+
+def get_qk_clip_info(clip_config, n, qk_logits):
+    """Extract QK clipping info for a named parameter.
+
+    Args:
+        clip_config: QK clipping configuration dict (or None).
+        n: Parameter name string.
+        qk_logits: Dict mapping layer indices to logit tensors (or None).
+
+    Returns:
+        QKClipInfo instance with clipping configuration for this parameter.
+    """
+    if clip_config is None:
+        return None
+
+    head_dim = clip_config.get('head_dim')
+    threshold = clip_config.get('threshold')
+    kind, layer_idx = parse_qk_layer(n)
+
+    logit, indices = None, []
+    if qk_logits is not None and kind is not None:
+        logit = qk_logits[layer_idx]
+        indices_key = 'q_indices' if 'q' in kind else 'k_indices'
+        indices = clip_config.get(indices_key, []) or []
+
+    if isinstance(logit, DTensor):
+        # In TP settings, qk_logits may be a DTensor.
+        # We convert it to a full tensor here for simplicity.
+        logit = logit.full_tensor()
+
+    return QKClipInfo(
+        kind=kind,
+        indices=indices,
+        head_dim=head_dim,
+        threshold=threshold,
+        logit=logit,
+    )
+
+
+def compute_scales(p, qk_clip_state):
+    """Compute per-head scaling factors for QK clipping.
+
+    Returns scales tensor if any head exceeds threshold, else None.
+    """
+    kind = qk_clip_state.kind
+    indices = qk_clip_state.indices
+    head_dim = qk_clip_state.head_dim
+    threshold = qk_clip_state.threshold
+    logit = qk_clip_state.logit
+
+    H_global = p.shape[0] // head_dim
+    scales_full = torch.ones(H_global, device=p.data.device)
+    scaling = 0
+
+    for logit_idx, head_idx in enumerate(indices):
+        v_ele = float(logit[logit_idx])
+        if v_ele > threshold:
+            new_scale = math.sqrt(threshold / v_ele)
+            if new_scale < scales_full[head_idx]:
+                scales_full[head_idx] = new_scale
+                logger.info(
+                    f"[{kind}] Head {head_idx} exceeded threshold "
+                    f"(value={v_ele:.4f}, threshold={threshold:.4f}) "
+                    f"-> applying scale={new_scale:.4f}")
+            scaling += 1
+
+    return scales_full if scaling > 0 else None
+
+
+def qk_clip(p, scales, head_dim):
+    """Apply per-head scaling to a Q/K projection weight matrix."""
+    if isinstance(p, torch.nn.Parameter):
+        W = p.data.view(-1, head_dim, p.data.shape[1])
+        W.mul_(scales.view(-1, 1, 1))
+    else:
+        W = p.view(-1, head_dim, p.shape[1])
+        W.mul_(scales.view(-1, 1, 1))
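The name parsing and the clipping factor are both easy to sanity-check in isolation. Below is a standalone sketch of `parse_qk_layer` (same logic, no torch dependency), plus the scale arithmetic behind `compute_scales`: clipping by `sqrt(threshold / value)` on both the query and key projections rescales the offending q·k logit to exactly the threshold.

```python
import math

# Standalone copy of parse_qk_layer: the projection kind is the
# second-to-last dotted component; the layer index is the last numeric one.
def parse_qk_layer(name):
    parts = name.split('.')
    if len(parts) < 3:
        return None, -1
    kind = parts[-2]
    layer_idx = -1
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break
    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
        return kind, layer_idx
    return None, -1

print(parse_qk_layer('model.3.attn.wq.weight'))      # ('wq', 3)
print(parse_qk_layer('model.2.attn.q_proj.weight'))  # ('q_proj', 2)
print(parse_qk_layer('model.4.attn.v_proj.weight'))  # (None, -1)

# Scale arithmetic: with a hypothetical max logit of 150 and threshold 100,
# scaling q and k each by sqrt(threshold/value) brings the logit back to
# the threshold, since the logit is bilinear in q and k.
value, threshold = 150.0, 100.0
scale = math.sqrt(threshold / value)
assert abs((scale * scale) * value - threshold) < 1e-9
```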
build/torch210-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_06a260a_dirty
-ops = torch.ops._optimizer_06a260a_dirty
+from . import _optimizer_7aef62f_dirty
+ops = torch.ops._optimizer_7aef62f_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_06a260a_dirty::{op_name}"
+    return f"_optimizer_7aef62f_dirty::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/{_optimizer_06a260a_dirty.abi3.so → _optimizer_7aef62f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:976df6a1ec3ec4c462dea18477b56dfb75bcff76f504d55b592ce417931597c0
+oid sha256:4919c48c77c6223dbf668f1461bcec175ef1bd6ea4cec8c2509de12ca7200a62
 size 2004144
build/torch210-cxx11-cu128-x86_64-linux/adamw.py ADDED
@@ -0,0 +1,154 @@
+from collections import defaultdict
+from typing import cast
+
+import torch
+from torch.distributed.tensor import DTensor
+
+
+def fused_adamw(
+    params: list[torch.Tensor],
+    grads: list[torch.Tensor],
+    exp_avgs: list[torch.Tensor],
+    exp_avg_sqs: list[torch.Tensor],
+    max_exp_avg_sqs: list[torch.Tensor],
+    state_steps: list[torch.Tensor],
+    amsgrad: bool,
+    beta1: float,
+    beta2: float,
+    lr: float | torch.Tensor,
+    weight_decay: float,
+    eps: float,
+    maximize: bool,
+) -> None:
+    if not params:
+        return
+
+    # We only shuffle the lr around when it is a Tensor on a non-CPU device;
+    # otherwise we prefer treating it as a scalar.
+    lr_dict: dict | None = ({
+        lr.device: lr
+    } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None)
+    grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
+         state_steps]  # type: ignore[list-item]
+    )
+    for (device, _), (
+        (
+            device_params_,
+            device_grads_,
+            device_exp_avgs_,
+            device_exp_avg_sqs_,
+            device_max_exp_avg_sqs,
+            device_state_steps_,
+        ),
+        _,
+    ) in grouped_tensors.items():
+        device_params = cast(list[torch.Tensor], device_params_)
+        device_grads = cast(list[torch.Tensor], device_grads_)
+        device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
+        device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
+        device_state_steps = cast(list[torch.Tensor], device_state_steps_)
+
+        if lr_dict is not None and device not in lr_dict:
+            lr_dict[device] = lr.to(
+                device=device, non_blocking=True)  # type: ignore[union-attr]
+            lr = lr_dict[device]
+        torch._foreach_add_(device_state_steps, 1)
+        func = torch._fused_adamw_
+        func(
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_max_exp_avg_sqs,  # type: ignore[arg-type]
+            device_state_steps,
+            amsgrad=amsgrad,
+            lr=lr,  # type: ignore[arg-type]
+            beta1=beta1,
+            beta2=beta2,
+            weight_decay=weight_decay,
+            eps=eps,
+            maximize=maximize,
+        )
+
+
+def step_adamw_params(optimizer_state, params, group):
+    """Run fused AdamW on a list of parameters sharing the same placement.
+
+    Args:
+        optimizer_state: The optimizer's state dict (self.state in Muon).
+        params: List of parameters to update.
+        group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
+    """
+    params_with_grads = []
+    grads = []
+    moment1 = []
+    moment2 = []
+    max_exp_avg_sqs = []
+    state_steps = []
+    lr = group["lr"]
+    beta1, beta2 = group["adamw_betas"]
+    eps = group["adamw_eps"]
+    weight_decay = group["weight_decay"]
+
+    for p in params:
+        g = p.grad
+        if g is None:
+            continue
+        state = optimizer_state[p]
+        params_with_grads.append(p)
+        grads.append(g)
+        if "step" not in state:
+            state["step"] = torch.zeros((),
+                                        dtype=torch.float32,
+                                        device=p.device)
+            state["moment1"] = torch.zeros_like(g)
+            state["moment2"] = torch.zeros_like(g)
+        moment1.append(state["moment1"])
+        moment2.append(state["moment2"])
+        if not isinstance(state["step"], torch.Tensor):
+            step_tensor = torch.tensor(state["step"],
+                                       dtype=torch.float32,
+                                       device=p.device)
+        else:
+            step_tensor = state["step"]
+        state_steps.append(step_tensor)
+
+    fused_adamw(
+        params_with_grads,
+        grads,
+        moment1,
+        moment2,
+        max_exp_avg_sqs,
+        state_steps,
+        amsgrad=False,
+        beta1=beta1,
+        beta2=beta2,
+        lr=lr,
+        weight_decay=weight_decay,
+        eps=eps,
+        maximize=False,
+    )
+
+
+def step_adamw(optimizer_state, group):
+    """Dispatch AdamW step, grouping parameters by type and placement.
+
+    Args:
+        optimizer_state: The optimizer's state dict (self.state in Muon).
+        group: Parameter group dict.
+    """
+    params = group["params"]
+
+    # Group params by type and placement.
+    placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list)
+    for p in params:
+        match p:
+            case DTensor():
+                placement_to_params[tuple([p.placements,
+                                           p.device_mesh])].append(p)
+            case torch.Tensor():
+                placement_to_params[tuple([torch.Tensor, None])].append(p)
+
+    for group_params in placement_to_params.values():
+        step_adamw_params(optimizer_state, group_params, group)
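`torch._fused_adamw_` implements the standard decoupled-weight-decay AdamW update. As a reference for what one fused step computes per element, here is a scalar pure-Python sketch (not the fused kernel; the hyperparameter defaults are illustrative):

```python
import math

def adamw_step(p, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    # One scalar AdamW step with decoupled weight decay (Loshchilov & Hutter).
    step += 1
    p *= 1 - lr * weight_decay               # decoupled weight decay
    m = beta1 * m + (1 - beta1) * g          # first moment (exp_avg)
    v = beta2 * v + (1 - beta2) * g * g      # second moment (exp_avg_sq)
    m_hat = m / (1 - beta1 ** step)          # bias correction
    v_hat = v / (1 - beta2 ** step)
    p -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v, step

p, m, v, step = 1.0, 0.0, 0.0, 0
p, m, v, step = adamw_step(p, 0.5, m, v, step)
print(p)  # weight decay shrinks p slightly, then the update subtracts ~lr
```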
build/torch210-cxx11-cu128-x86_64-linux/async_utils.py ADDED
@@ -0,0 +1,77 @@
+import logging
+from typing import Generator
+
+logger = logging.getLogger(__name__)
+
+
+class _Task:
+    """Internal: wraps a generator, advances one yield at a time."""
+
+    def __init__(self, generator: Generator[None, None, None], index: int):
+        self._generator = generator
+        self._index = index
+        self._steps_completed = 0
+        self.step()  # run to first yield
+
+    def step(self) -> bool:
+        try:
+            next(self._generator)
+            self._steps_completed += 1
+            logger.debug("pipeline[%d] completed stage %d", self._index,
+                         self._steps_completed)
+            return True
+        except StopIteration:
+            logger.debug("pipeline[%d] finished after %d stages", self._index,
+                         self._steps_completed)
+            return False
+
+    def close(self):
+        self._generator.close()
+
+
+def run_pipeline(
+    pipelines: Generator[Generator[None, None, None], None, None],
+    max_concurrent: int,
+) -> None:
+    """Run generator-based pipelines with bounded concurrency.
+
+    Each pipeline is a generator that yields at stage boundaries.
+    The runtime interleaves pipelines so communication and computation
+    overlap across chunks.
+    """
+    if max_concurrent <= 0:
+        raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}")
+
+    have_new = True
+    task_index = 0
+    previous_tasks: list[_Task] = []
+
+    try:
+        while have_new or previous_tasks:
+            running_tasks: list[_Task] = []
+
+            # Admit one new pipeline per iteration (staggered admission).
+            # Admitting one at a time ensures that while chunk N does NS
+            # compute on the default stream, chunk N+1's NCCL all-to-all
+            # runs concurrently on the NCCL stream — creating real
+            # communication/computation overlap on the GPU.
+            if have_new and len(previous_tasks) < max_concurrent:
+                try:
+                    gen = next(pipelines)
+                    task = _Task(gen, task_index)
+                    task_index += 1
+                    running_tasks.append(task)
+                except StopIteration:
+                    have_new = False
+
+            # Advance every previously-yielded task by one step.
+            for task in previous_tasks:
+                if task.step():
+                    running_tasks.append(task)
+
+            previous_tasks = running_tasks
+    except BaseException:
+        # Clean up all in-flight generators to release GPU resources.
+        for task in previous_tasks:
+            task.close()
+        raise
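The staggered-admission schedule is easiest to see with plain generators. The sketch below reproduces `run_pipeline`'s scheduling loop (logging and the `_Task` wrapper stripped out) and drives three dummy two-yield chunks, recording the interleaved stage order:

```python
from typing import Generator

log = []

def chunk(i) -> Generator[None, None, None]:
    # Mimics muon_chunk_pipeline's two yields: gather, then scatter.
    log.append(f"{i}:gather"); yield
    log.append(f"{i}:compute+scatter"); yield
    log.append(f"{i}:update")

def run_pipeline(pipelines, max_concurrent):
    # Same scheduling loop as async_utils.run_pipeline, minus logging.
    have_new, previous = True, []
    while have_new or previous:
        running = []
        # Admit at most one new chunk per iteration (staggered admission).
        if have_new and len(previous) < max_concurrent:
            gen = next(pipelines, None)
            if gen is None:
                have_new = False
            else:
                next(gen)  # run to first yield (as _Task.__init__ does)
                running.append(gen)
        # Advance every previously-yielded chunk by one stage.
        for gen in previous:
            try:
                next(gen)
                running.append(gen)
            except StopIteration:
                pass
        previous = running

run_pipeline(iter(chunk(i) for i in range(3)), max_concurrent=2)
print(log)
```

Note how chunk 1's gather is launched while chunk 0 sits between its gather and compute stage; in the real optimizer that gap is where the NCCL all-to-all of one chunk overlaps the Newton-Schulz compute of another.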
build/torch210-cxx11-cu128-x86_64-linux/core.py ADDED
@@ -0,0 +1,116 @@
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from torch.distributed.tensor import DTensor
+
+
+@dataclass
+class _muon_state:
+    worker_rank: int
+    process_group: ProcessGroup
+    rank_indices: dict[int, tuple]  # local_rank -> per-dim indices
+    rank_numels: dict[int, int]  # local_rank -> numel
+    name: str
+    qk_clip_state: torch.Tensor | None = None
+
+
+def update_g(optimizer_state, p, g, group, momentum):
+    """Apply momentum update to gradient.
+
+    Args:
+        optimizer_state: The optimizer's state dict (self.state in Muon).
+        p: Parameter tensor.
+        g: Gradient tensor.
+        group: Parameter group dict.
+        momentum: Momentum coefficient.
+
+    Returns:
+        Momentum-updated gradient tensor.
+    """
+    state = optimizer_state[p]
+    buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
+    torch.add(g, buf, alpha=momentum, out=buf)
+    if group["nesterov"]:
+        g.add_(buf, alpha=momentum)
+        return g
+    return buf
+
+
+def update_p(p, u, lr, adjusted_lr, weight_decay):
+    """Apply weight decay and orthogonalized update to parameter.
+
+    Args:
+        p: Parameter (torch.nn.Parameter or DTensor).
+        u: Orthogonalized update tensor.
+        lr: Base learning rate.
+        adjusted_lr: Size-adjusted learning rate.
+        weight_decay: Weight decay coefficient.
+    """
+    if isinstance(p, torch.nn.Parameter):
+        # apply weight decay
+        p.data.mul_(1 - lr * weight_decay)
+        # apply update
+        p.data.add_(u, alpha=-adjusted_lr)
+    else:
+        p.mul_(1 - lr * weight_decay)
+        p.add_(u, alpha=-adjusted_lr)
+
+
+def adjust_lr_for_muon(lr, param_shape):
+    """Scale learning rate based on parameter matrix dimensions.
+
+    Args:
+        lr: Base learning rate.
+        param_shape: Shape of the parameter tensor.
+
+    Returns:
+        Adjusted learning rate.
+    """
+    A, B = param_shape[:2]
+    # We adjust the learning rate and weight decay based on the size of the
+    # parameter matrix, as described in the paper.
+    adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+    adjusted_lr = lr * adjusted_ratio
+    return adjusted_lr
+
+
+def default_is_muon(name, x, expert_keys=None):
+    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
+    if any(key in name for key in skip_keys):
+        return False
+    effective_ndim = x.ndim
+    if expert_keys and any(key in name for key in expert_keys):
+        effective_ndim -= 1
+    return effective_ndim >= 2
+
+
+def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
+    if is_muon_func is None:
+        is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
+
+    muon_params, muon_names = [], []
+    non_muon_params = []
+
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+        if is_muon_func(n, p):
+            muon_params.append(p)
+            muon_names.append(n)
+        else:
+            non_muon_params.append(p)
+
+    return [
+        {
+            "params": muon_params,
+            "names": muon_names,
+            "use_muon": True,
+        },
+        {
+            "params": non_muon_params,
+            "use_muon": False,
+        },
+    ]
build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py CHANGED
@@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
  _StridedShard)
8
 
9
 
 
 
 
 
 
 
 
 
 
 
10
  def get_slices_of_dtensor(
11
  target: DTensor | torch.Tensor,
12
  local_rank: int,
13
  shard_mesh: DeviceMesh,
14
  shard_placements: tuple[Placement],
15
- ) -> tuple[slice]:
16
  """
17
- Get the slice of local tensor for a given rank from a tensor.
 
 
 
 
 
18
  Args:
19
- target (DTensor | torch.Tensor): The target tensor.
20
- rank (int): The local rank of the shard group.
21
- shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
  shard_placements (tuple[Placement]): The shard placements.
23
- """
24
 
25
- slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
 
 
 
 
26
 
27
  # find the global rank of the local rank in the shard mesh
28
  rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
@@ -34,34 +52,75 @@ def get_slices_of_dtensor(
34
 
35
  assert len(rank_coords) == len(shard_placements)
36
 
 
 
 
 
37
  # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
  # left-to-right sharding. This is ensured by the sorting logic of
39
  # construct_shard_mesh function.
40
- for i, (rank_coord,
41
- placement) in enumerate(zip(rank_coords, shard_placements)):
42
- assert isinstance(placement, Shard)
43
 
44
- num_ranks = shard_mesh.mesh.shape[i]
 
45
 
46
- dim = placement.dim
47
- dim_size = (slices[dim].stop - slices[dim].start)
 
 
 
48
 
49
- if dim_size % num_ranks != 0:
50
  raise NotImplementedError(
51
- f"Dimension size {dim_size} is not divisible "
52
- f"by number of ranks {num_ranks} for shard "
53
- f"placement on dim {dim}. (shape: {target.shape})")
54
-
55
- shard_size = dim_size // num_ranks
56
-
57
- start = slices[dim].start + rank_coord * shard_size
58
- end = start + shard_size
59
-
60
- assert start < end <= slices[dim].stop
61
-
62
- slices[dim] = slice(start, end)
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- return tuple(slices)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
 
67
  _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
@@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
71
  def construct_shard_mesh(
72
  placements: tuple[Placement],
73
  mesh: DeviceMesh,
74
- ) -> (DeviceMesh, ProcessGroup, tuple[Placement]):
75
- """
76
- Construct Shard Mesh and Placements for unsharding.
77
- It removes Replicate placements and constructs a new Mesh and ProcessGroup.
78
- """
79
- my_rank = dist.get_rank()
80
 
81
- assert mesh.mesh.device.type == 'cpu'
 
 
82
 
83
- # Copy mesh to avoid modifying the original mesh
84
- mesh = mesh.mesh.clone()
85
-
86
- # 1. Sort placements. Replicate first, then Shard by dim ascending.
87
-
88
- # For Shard, strided shard comes after regular shard on the same dim
89
- # to preserve left-to-right order of replicate-to-shard.
90
- # This is because that strided shard is using stride to represent
91
- # more fine-grained sharding on the same dim.
92
- # Please check the URL below for _StridedShard.
93
- # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
94
-
95
- def placement_sort_key(
96
- placement_with_index: tuple[float, Placement]
97
- ) -> tuple[int, float, int]: # (dim, split factor, original index)
98
- index, placement = placement_with_index
99
- is_replicate = placement.is_replicate()
100
- is_shard = placement.is_shard()
101
- is_partial = placement.is_partial()
102
-
103
- assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
104
- assert not is_partial, "Partial placement is not supported."
105
-
106
- if is_replicate:
107
- return (-1.0, 0, index)
108
- elif is_shard:
109
- if isinstance(placement, _StridedShard):
110
- return (placement.dim, 1 / placement.split_factor, index)
111
- return (placement.dim, 0, index)
112
- else:
113
- raise TypeError(f"Unknown placement type: {type(placement)}")
114
 
115
- placements_with_index: list[tuple[int,
116
- Placement]] = list(enumerate(placements))
117
- placements_with_index = sorted(placements_with_index,
118
- key=placement_sort_key)
119
 
120
- sorted_indices, sorted_placements = zip(*placements_with_index)
 
121
 
122
- # 2. Permute mesh according to sorted placements.
123
- sorted_mesh = mesh.permute(sorted_indices)
 
 
124
 
125
- # 3. Collect list of shard meshes by removing replicate dims
126
- # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
127
- # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
128
- num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
129
 
130
- # merge replicate dims
131
- # shard_meshes became a list of shard meshes with a length of replicate degree
132
- if num_replicates > 0:
133
- sorted_mesh = sorted_mesh.flatten(
134
- 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
135
  shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
136
  else:
137
  shard_meshes = [sorted_mesh]
138
- shard_placements = sorted_placements[num_replicates:]
139
-
140
- # assume all shard placements are different
141
  assert len(shard_placements) == len(set(shard_placements))
142
 
143
- # 4. Construct ProcessGroups
144
- # Caution: all groups should be created in the same order in all processes,
145
- # even though each process only needs its own group.
146
-
147
- # To use tensor as dict key, convert it to tuple
148
- def tensor_to_tuple(t):
149
- if isinstance(t, torch.Tensor):
150
- t = t.tolist()
151
- if isinstance(t, list):
152
- return tuple(tensor_to_tuple(x) for x in t)
153
- return t
154
-
155
- my_shard_mesh_as_tuple = None
156
- for shard_mesh in shard_meshes:
157
- assert isinstance(shard_mesh, torch.Tensor)
158
- shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
159
-
160
- if (my_rank == shard_mesh).any().item():
161
- assert my_shard_mesh_as_tuple is None
162
- my_shard_mesh_as_tuple = shard_mesh_as_tuple
163
-
164
- # update global cache
165
- if shard_mesh_as_tuple not in _ranks_to_dist_cache:
166
- shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
167
- _ranks_to_dist_cache[shard_mesh_as_tuple] = (
168
- DeviceMesh(device_type="cuda", mesh=shard_mesh),
169
- shard_process_group,
170
  )
171
 
172
- my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
173
- my_shard_mesh_as_tuple]
174
-
175
- return my_shard_mesh, my_shard_process_group, shard_placements
 
7
  _StridedShard)
8
 
9
 
10
+ def _is_shard(placement: Placement) -> bool:
11
+ """Check if a placement is a shard type (Shard or _StridedShard).
12
+
13
+ In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so
14
+ ``placement.is_shard()`` returns False for _StridedShard. This helper
15
+ handles both old and new hierarchies.
16
+ """
17
+ return isinstance(placement, (Shard, _StridedShard))
18
+
19
+
20
  def get_slices_of_dtensor(
21
  target: DTensor | torch.Tensor,
22
  local_rank: int,
23
  shard_mesh: DeviceMesh,
24
  shard_placements: tuple[Placement],
25
+ ) -> tuple[slice | torch.Tensor, ...]:
26
  """
27
+ Get per-dimension indices for a given rank's shard of the target tensor.
28
+
29
+ Uses ``Shard.local_shard_size_and_offset`` and
30
+ ``_StridedShard.local_shard_size_and_offset`` for correct handling of
31
+ both contiguous and strided (non-contiguous) sharding.
32
+
33
  Args:
34
+ target (DTensor | torch.Tensor): The target tensor (for its shape).
35
+ local_rank (int): The local rank within the shard group.
36
+ shard_mesh (DeviceMesh): The shard mesh (only shard dimensions).
37
  shard_placements (tuple[Placement]): The shard placements.
 
38
 
39
+ Returns:
40
+ A tuple of indices (one per tensor dim). Each element is either:
41
+ - A ``slice`` (for contiguous or unsharded dims)
42
+ - A 1-D ``torch.LongTensor`` of indices (for strided sharding)
43
+ """
44
 
45
  # find the global rank of the local rank in the shard mesh
46
  rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
 
52
 
53
  assert len(rank_coords) == len(shard_placements)
54
 
55
+ # Track per-shard-dim indices.
56
+ # A dim absent from the dict means "not yet sharded on this dim".
57
+ dim_indices: dict[int, torch.Tensor] = {}
58
+
59
  # Caution: Assuming replicate-to-shard of the shard mesh goes with
60
  # left-to-right sharding. This is ensured by the sorting logic of
61
  # construct_shard_mesh function.
62
+ for mesh_dim_idx, (rank_coord, placement) in enumerate(
63
+ zip(rank_coords, shard_placements)):
64
+ assert _is_shard(placement)
65
 
66
+ num_chunks = shard_mesh.mesh.shape[mesh_dim_idx]
67
+ shard_dim = placement.dim
68
 
69
+ # Current effective size on this dim (may already be sub-sharded)
70
+ if shard_dim in dim_indices:
71
+ curr_size = len(dim_indices[shard_dim])
72
+ else:
73
+ curr_size = target.size()[shard_dim]
74
 
75
+ if curr_size % num_chunks != 0:
76
  raise NotImplementedError(
77
+ f"Dimension size {curr_size} is not divisible "
78
+ f"by number of ranks {num_chunks} for shard "
79
+ f"placement on dim {shard_dim}. (shape: {target.shape})")
80
+
81
+ # Compute indices for this level of sharding
82
+ if isinstance(placement, _StridedShard):
83
+ _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
84
+ placement,
85
+ curr_size,
86
+ num_chunks,
87
+ rank_coord,
88
+ return_first_offset=False)
89
+ new_indices = torch.tensor(offsets, dtype=torch.long)
90
+ else:
91
+ shard_size, offset = Shard.local_shard_size_and_offset(
92
+ curr_size, num_chunks, rank_coord)
93
+ new_indices = torch.arange(offset,
94
+ offset + shard_size,
95
+ dtype=torch.long)
96
+
97
+ # Compose with previous indices on this dim
98
+ if shard_dim in dim_indices:
99
+ dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]
100
+ else:
101
+ dim_indices[shard_dim] = new_indices
102
 
103
+ # Build result tuple
104
+ result: list[slice | torch.Tensor] = []
105
+ for d in range(len(target.size())):
106
+ if d not in dim_indices:
107
+ result.append(slice(None))
108
+ else:
109
+ indices = dim_indices[d]
110
+ # Convert contiguous indices to slice for efficiency
111
+ if len(indices) > 0:
112
+ start = indices[0].item()
113
+ expected = torch.arange(start,
114
+ start + len(indices),
115
+ dtype=torch.long)
116
+ if torch.equal(indices, expected):
117
+ result.append(slice(start, start + len(indices)))
118
+ else:
119
+ result.append(indices)
120
+ else:
121
+ result.append(slice(0, 0))
122
+
123
+ return tuple(result)
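The index composition above (each sharding level selects from the indices produced by the previous level, via `dim_indices[shard_dim][new_indices]`) can be sketched in plain Python. `chunk_indices` below is a hypothetical stand-in for `Shard.local_shard_size_and_offset`, assuming an evenly divisible dim:

```python
# Sketch of composing two levels of sharding on the same tensor dim.
# chunk_indices stands in for Shard.local_shard_size_and_offset: it
# returns the index range owned by `rank` out of `num_chunks` even chunks.
def chunk_indices(size, num_chunks, rank):
    assert size % num_chunks == 0
    shard = size // num_chunks
    return list(range(rank * shard, (rank + 1) * shard))

def compose(size, shardings):
    """shardings: list of (num_chunks, rank_coord), outer level first."""
    indices = list(range(size))  # start unsharded
    for num_chunks, rank in shardings:
        inner = chunk_indices(len(indices), num_chunks, rank)
        # same composition step as dim_indices[shard_dim][new_indices]
        indices = [indices[i] for i in inner]
    return indices

# dim of size 8, sharded 2-ways (outer), then 2-ways again (inner):
# rank coords (1, 0) own the first half of the second outer chunk.
print(compose(8, [(2, 1), (2, 0)]))  # -> [4, 5]
```

The contiguous result is exactly why the function converts runs of consecutive indices back into a `slice` at the end.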
124
 
125
 
126
  _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
 
130
  def construct_shard_mesh(
131
  placements: tuple[Placement],
132
  mesh: DeviceMesh,
133
+ ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
134
+ """Construct shard sub-mesh and ProcessGroup for all-to-all communication.
 
135
 
136
+ Given a DTensor's placements and device mesh, extracts the "shard group"
137
+ — the set of ranks that together hold all shards of the same replica —
138
+ and creates a ProcessGroup for all-to-all among them.
139
 
140
+ Steps:
141
+ 1. Sort placements: Replicate first, then Shard by (dim, granularity).
142
+ 2. Permute the mesh tensor to match the sorted order.
143
+ 3. Collapse Replicate dims into a list of shard sub-meshes (one per replica).
144
+ 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh.
145
 
146
+ Example — 8 GPUs, mesh shape (2, 2, 2),
147
+ placements ``[Shard(0), Replicate, _StridedShard(0)]``::
 
 
148
 
149
+ Step 1 Sort: [Replicate, _StridedShard(0), Shard(0)]
150
+ Permutation: [1, 2, 0]
151
 
152
+ Step 2 Permute mesh dims by [1, 2, 0]:
153
+ Original: Permuted:
154
+ [[[0,1],[2,3]], [[[0,4],[1,5]],
155
+ [[4,5],[6,7]]] [[2,6],[3,7]]]
156
 
157
+ Step 3 Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
158
+ sub-mesh 0 = [[0,4],[1,5]] (replica group 0)
159
+ sub-mesh 1 = [[2,6],[3,7]] (replica group 1)
160
+ shard_placements = (_StridedShard(0), Shard(0))
161
 
162
+ Step 4 Rank 0 → ProcessGroup([0,1,4,5])
163
+ Rank 2 → ProcessGroup([2,3,6,7])
164
+
165
+ Returns:
166
+ ``(shard_mesh, process_group, shard_placements)``
167
+ """
168
+ my_rank = dist.get_rank()
169
+ assert mesh.mesh.device.type == 'cpu'
170
+
171
+ # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
172
+ # This avoids a non-collective dist.new_group() call, which would
173
+ # deadlock when only a subset of ranks call this function (e.g. expert
174
+ # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
175
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
176
+ key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
177
+ if key not in _ranks_to_dist_cache:
178
+ _ranks_to_dist_cache[key] = (mesh, mesh.get_group())
179
+ return (*_ranks_to_dist_cache[key], tuple(placements))
180
+
181
+ mesh_tensor = mesh.mesh.clone()
182
+
183
+ # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------
184
+ # _StridedShard comes BEFORE regular Shard on the same dim so that
185
+ # get_slices_of_dtensor applies the outer sharding first, matching
186
+ # DTensor's left-to-right (outer-to-inner) composition order.
187
+ def _sort_key(item):
188
+ index, placement = item
189
+ assert not placement.is_partial(), "Partial placement not supported"
190
+ if placement.is_replicate():
191
+ return (-1, 0, index)
192
+ assert _is_shard(placement), f"Unsupported: {type(placement)}"
193
+ split = (-1 / placement.split_factor if isinstance(
194
+ placement, _StridedShard) else 0)
195
+ return (placement.dim, split, index)
196
+
197
+ indexed = sorted(enumerate(placements), key=_sort_key)
198
+ perm, sorted_placements = zip(*indexed)
199
+
200
+ # -- Step 2: Permute mesh to match sorted placement order. --------------
201
+ sorted_mesh = mesh_tensor.permute(perm)
202
+
203
+ # -- Step 3: Collapse replicate dims → list of shard sub-meshes. --------
204
+ # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4)
205
+ num_rep = sum(1 for p in sorted_placements if p.is_replicate())
206
+ if num_rep > 0:
207
+ if num_rep > 1:
208
+ sorted_mesh = sorted_mesh.flatten(0, num_rep - 1)
209
  shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
210
  else:
211
  shard_meshes = [sorted_mesh]
212
+ shard_placements = sorted_placements[num_rep:]
 
 
213
  assert len(shard_placements) == len(set(shard_placements))
214
 
215
+ # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
216
+ # All ranks must call dist.new_group in the same order, even though each
217
+ # rank only joins one group.
218
+ def _cache_key(t: torch.Tensor) -> tuple:
219
+ return (*t.shape, *t.flatten().tolist())
220
+
221
+ my_key = None
222
+ for sm in shard_meshes:
223
+ key = _cache_key(sm)
224
+ if (my_rank == sm).any().item():
225
+ assert my_key is None, "Rank appears in multiple shard groups"
226
+ my_key = key
227
+ if key not in _ranks_to_dist_cache:
228
+ pg = dist.new_group(sm.flatten().tolist())
229
+ _ranks_to_dist_cache[key] = (
230
+ DeviceMesh(device_type="cuda", mesh=sm),
231
+ pg,
232
  )
233
 
234
+ return (*_ranks_to_dist_cache[my_key], shard_placements)
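The sort-then-permute-then-unbind steps in the docstring can be checked with a small pure-Python mock. The nested-list `permute_120` below mirrors `torch.Tensor.permute(1, 2, 0)` semantics (`out[a][b][c] == mesh[c][a][b]`) for the 8-GPU example, and the resulting groups match the ProcessGroups listed in Step 4:

```python
# Pure-Python check of the 8-GPU docstring example: permute a (2, 2, 2)
# rank mesh by perm = (1, 2, 0), then unbind the leading (replicate) dim
# into one shard sub-mesh per replica.
mesh = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

def permute_120(m):
    # out[a][b][c] = m[c][a][b], i.e. torch's x.permute(1, 2, 0)
    n = len(m)
    return [[[m[c][a][b] for c in range(n)] for b in range(n)]
            for a in range(n)]

permuted = permute_120(mesh)
sub_meshes = list(permuted)  # "unbind" dim 0
groups = [sorted(x for row in sm for x in row) for sm in sub_meshes]
print(groups)  # -> [[0, 1, 4, 5], [2, 3, 6, 7]]
```

Rank 0 falls in the first group and rank 2 in the second, matching Step 4.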
 
 
 
build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py CHANGED
@@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out):
119
  with torch.cuda.device(d_in.device.index):
120
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
  d_out.stride(0), d_out.stride(1))
122
-
123
-
124
- def matmul_transpose(d_in):
125
- M, _ = d_in.shape
126
- d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
- matmul_transpose_assign(d_in, d_out)
128
- return d_out
 
119
  with torch.cuda.device(d_in.device.index):
120
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
  d_out.stride(0), d_out.stride(1))
 
build/torch210-cxx11-cu128-x86_64-linux/metadata.json CHANGED
@@ -1 +1,3 @@
1
- {"python-depends":[]}
 
 
 
1
+ {
2
+ "python-depends": []
3
+ }
build/torch210-cxx11-cu128-x86_64-linux/muon.py CHANGED
@@ -1,536 +1,121 @@
1
  import logging
2
- import math
3
  import types
4
  from collections import defaultdict
5
- from dataclasses import dataclass
6
- from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
- from torch.distributed import ProcessGroup
11
- from torch.distributed.device_mesh import DeviceMesh
12
- from torch.distributed.tensor import DTensor, Replicate
13
- from torch.distributed.tensor.placement_types import Placement
14
-
15
- from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
- from .matmul_transpose_triton import matmul_transpose_assign
 
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
- COMM_DTYPE = torch.bfloat16
21
- DEFAULT_CHUNK_SIZE_RATIO = 4
22
-
23
-
24
- # This code snippet is a modified version adapted from the following GitHub repositories:
25
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
26
- # Muon's Newton–Schulz iteration causes high variance in singular values
27
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
28
- @torch.no_grad()
29
- # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
30
- def _zeropower_via_newtonschulz5(G, steps):
31
- """
32
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
33
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
34
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
35
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
36
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
37
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
38
- performance at all relative to UV^T, where USV^T = G is the SVD.
39
- """
40
- assert len(G.shape) == 2
41
- assert G.dtype == COMM_DTYPE
42
- X = G # no manual typecast
43
-
44
- if G.size(0) > G.size(1):
45
- X = X.T
46
- # Ensure spectral norm is at most 1
47
- X = X / (X.norm() + 1e-7)
48
- buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
49
- buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
50
- # Perform the NS iterations
51
- for a, b, c in [
52
- (4.0848, -6.8946, 2.9270),
53
- (3.9505, -6.3029, 2.6377),
54
- (3.7418, -5.5913, 2.3037),
55
- (2.8769, -3.1427, 1.2046),
56
- (2.8366, -3.0525, 1.2012),
57
- ]:
58
- matmul_transpose_assign(X, buf1)
59
- matmul_transpose_assign(buf1, buf2)
60
- buf1.mul_(b).add_(buf2, alpha=c)
61
- X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
62
-
63
- if G.size(0) > G.size(1):
64
- X = X.T
65
- return X
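The iteration body acts independently on each singular value of `X`. Writing the SVD as X_k = U Σ_k Vᵀ, one step reduces to a quintic polynomial applied entrywise to Σ_k:

```latex
% Action of one Newton--Schulz step on the singular values of X_k
X_{k+1} = a\,X_k + \bigl(b\,(X_k X_k^{\top}) + c\,(X_k X_k^{\top})^{2}\bigr)\,X_k
\quad\Longrightarrow\quad
\sigma_{k+1} = a\,\sigma_k + b\,\sigma_k^{3} + c\,\sigma_k^{5}.
```

Since `X` is pre-normalized by its norm, every σ starts in (0, 1], and the five tuned (a, b, c) triples drive each σ into roughly (0.5, 1.5) rather than exactly 1, as the docstring notes.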
66
-
67
-
68
- @dataclass
69
- class _muon_state:
70
- # TODO: use Optional
71
- worker_rank: int
72
- process_group: ProcessGroup
73
- shard_mesh: DeviceMesh
74
- shard_placements: tuple[Placement, ...]
75
- name: str
76
- qk_clip_state: torch.Tensor | None = None
77
- gathered_grad: torch.Tensor | None = None
78
- scattered_u: DTensor | None = None
79
- computed_u: torch.Tensor | None = None
80
- gather_event: torch.cuda.Event | None = None
81
- compute_event: torch.cuda.Event | None = None
82
- scatter_event: torch.cuda.Event | None = None
83
-
84
-
85
- def numel_for_rank(
86
- param: DTensor,
87
- local_rank: int,
88
- state: _muon_state,
89
- ) -> int:
90
- slices = get_slices_of_dtensor(
91
- param,
92
- local_rank,
93
- state.shard_mesh,
94
- state.shard_placements,
95
- )
96
-
97
- numel = 1
98
- for s, dim in zip(slices, param.shape):
99
- start, stop, step = s.indices(dim)
100
- length = max(0, (stop - start + (step - 1)) // step)
101
- numel *= length
102
-
103
- return numel
104
-
105
-
106
- @torch.no_grad()
107
- def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
108
- """
109
- Pre-allocate gathered_grad buffer on compute_stream
110
- before launching all2all gather
111
- """
112
- with torch.cuda.stream(compute_stream):
113
- for p in params:
114
- state = param_to_state[id(p)]
115
- if rank == state.worker_rank:
116
- state.gathered_grad = torch.empty(p.shape,
117
- dtype=COMM_DTYPE,
118
- device="cuda")
119
- else:
120
- state.gathered_grad = None
121
-
122
- alloc_event = torch.cuda.Event()
123
- alloc_event.record(compute_stream)
124
- return alloc_event
125
-
126
-
127
- @torch.no_grad()
128
- def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
129
- alloc_event):
130
- """
131
- All2all gathers shards so each owner rank reconstructs its full gradient
132
- """
133
- with torch.cuda.stream(comm_stream):
134
- process_group = param_to_state[id(params[0])].process_group
135
- num_ranks = dist.get_world_size(group=process_group)
136
-
137
- # Construct sending buffers
138
- per_dst = [[] for _ in range(num_ranks)]
139
- send_counts = [0] * num_ranks
140
-
141
- for p in params:
142
- state = param_to_state[id(p)]
143
- dst = state.worker_rank
144
- assert dst < num_ranks
145
- shard_elems = numel_for_rank(p, rank, state)
146
- g = p.grad
147
- g = g.to_local().to(COMM_DTYPE).contiguous()
148
- assert g.numel() == shard_elems
149
- per_dst[dst].append(g.view(-1))
150
- send_counts[dst] += shard_elems
151
-
152
- assert any(
153
- len(v) > 0 for v in per_dst
154
- ), "At least one destination rank must receive a sharded tensor"
155
- # list[list[Tensor]] -> list[Tensor]
156
- per_dst = [t for dst in per_dst for t in dst]
157
-
158
- send_buf = torch.cat(per_dst, dim=0)
159
-
160
- owned_params = [
161
- p for p in params if param_to_state[id(p)].worker_rank == rank
162
- ]
163
-
164
- # Compute receive sizes and allocate receiving buffers
165
- recv_counts = [0] * num_ranks
166
-
167
- for src in range(num_ranks):
168
- total = 0
169
- for p in owned_params:
170
- state = param_to_state[id(p)]
171
- assert state.worker_rank == rank
172
- total += numel_for_rank(p, src, state)
173
- recv_counts[src] = total
174
-
175
- recv_total = sum(recv_counts)
176
- recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
-
178
- #All2All
179
- logger.debug(f"send_buf size: {send_buf.numel()}, "
180
- f"recv_buf size: {recv_buf.numel()}, "
181
- f"recv_counts: {recv_counts}, "
182
- f"send_counts: {send_counts}, "
183
- f"process_group: {str(process_group)}")
184
- dist.all_to_all_single(
185
- recv_buf,
186
- send_buf,
187
- output_split_sizes=recv_counts,
188
- input_split_sizes=send_counts,
189
- group=process_group,
190
- )
191
-
192
- # Reconstructs gathered grad from the received buffer
193
- #
194
- # recv_buf (num ranks = 3)
195
- #
196
- # From rank 0 From rank 1 From rank 2
197
- # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
198
- #
199
- # Outer loop:
200
- # rank 0 -> rank 1 -> rank2
201
- #
202
- # Inner loop:
203
- # p1_n -> p2_n -> p3_n
204
-
205
- comm_stream.wait_event(alloc_event)
206
-
207
- off = 0
208
- for src in range(num_ranks):
209
- if recv_counts[src] == 0:
210
- continue
211
-
212
- block = recv_counts[src]
213
- inner_off = 0
214
- for p in owned_params:
215
- state = param_to_state[id(p)]
216
- assert state.worker_rank == rank
217
-
218
- # get the slice of the full dtensor corresponding to rank src.
219
- slices = get_slices_of_dtensor(state.gathered_grad, src,
220
- state.shard_mesh,
221
- state.shard_placements)
222
-
223
- dst = state.gathered_grad[slices]
224
- assert dst._base is state.gathered_grad
225
-
226
- n = dst.numel()
227
- assert n > 0
228
-
229
- sg = recv_buf.narrow(0, off + inner_off, n)
230
- sg = sg.reshape_as(dst)
231
- dst.copy_(sg)
232
-
233
- inner_off += n
234
- off += block
235
-
236
- for p in params:
237
- state = param_to_state[id(p)]
238
- if state.worker_rank == rank:
239
- state.gather_event = torch.cuda.Event()
240
- state.gather_event.record(comm_stream)
241
- else:
242
- state.gathered_grad = None
243
- state.gather_event = None
244
- if none_grad:
245
- p.grad = None
246
-
247
-
248
- @torch.no_grad()
249
- def _compute_u(p, state, steps, rank, compute_stream):
250
- """
251
- On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
252
- """
253
- with torch.cuda.stream(compute_stream):
254
- if rank == state.worker_rank:
255
- if state.gather_event is None:
256
- raise RuntimeError("Gather event must be set before compute.")
257
- compute_stream.wait_event(state.gather_event)
258
- u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
259
- state.gathered_grad = None
260
- state.computed_u = u
261
- state.compute_event = torch.cuda.Event()
262
- state.compute_event.record()
263
- else:
264
- state.computed_u = None
265
- state.compute_event = None
266
-
267
-
268
- @torch.no_grad()
269
- def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
270
- """
271
- Pre-allocate scattered_u buffer on compute_stream
272
- before launching all2all gather
273
- """
274
- with torch.cuda.stream(compute_stream):
275
- for p in params:
276
- state = param_to_state[id(p)]
277
- state.scattered_u = torch.empty_like(p.to_local(),
278
- dtype=COMM_DTYPE)
279
-
280
- alloc_event = torch.cuda.Event()
281
- alloc_event.record(compute_stream)
282
- return alloc_event
283
-
284
-
285
- def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
286
- """
287
- All2all scatters full gradients to all ranks
288
- """
289
- with torch.cuda.stream(comm_stream):
290
- process_group = param_to_state[id(params[0])].process_group
291
- num_ranks = dist.get_world_size(group=process_group)
292
- owned_params = [
293
- p for p in params if param_to_state[id(p)].worker_rank == rank
294
- ]
295
-
296
- # Construct sending buffer
297
- per_dst = [[] for _ in range(num_ranks)]
298
- send_counts = [0] * num_ranks
299
-
300
- if owned_params:
301
- for p in owned_params:
302
- state = param_to_state[id(p)]
303
- if state.compute_event is None:
304
- raise RuntimeError(
305
- "Compute event must be set before scatter.")
306
- comm_stream.wait_event(state.compute_event)
307
- state.gathered_grad = None
308
-
309
- assert state.computed_u is not None
310
-
311
- u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
-
313
- offset = 0
314
- for dst in range(num_ranks):
315
- # get the slice of the full tensor corresponding to rank dst.
316
- slices = get_slices_of_dtensor(u_full, dst,
317
- state.shard_mesh,
318
- state.shard_placements)
319
- su = u_full[slices].flatten()
320
-
321
- n = su.numel()
322
- assert n > 0
323
-
324
- per_dst[dst].append(su)
325
- send_counts[dst] += n
326
- offset += n
327
-
328
- assert offset == u_full.numel()
329
-
330
- lengths = [len(v) for v in per_dst]
331
- if all(l > 0 for l in lengths):
332
- assert all(
333
- l == lengths[0] for l in lengths
334
- ), "All destination ranks must have the same number of sharded tensor"
335
- # list[list[Tensor]] -> list[Tensor]
336
- per_dst = [t for dst in per_dst for t in dst]
337
- send_buf = torch.cat(per_dst, dim=0)
338
- else:
339
- # all_to_all requires participation from all ranks
340
- # Even non-owner ranks must join the collective call
341
- send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
342
-
343
- # Compute receive sizes and allocate receiving buffers
344
- recv_counts = [0] * num_ranks
345
-
346
- for src in range(num_ranks):
347
- total = 0
348
- for p in params:
349
- state = param_to_state[id(p)]
350
- if state.worker_rank != src:
351
- continue
352
- total += numel_for_rank(p, rank, state)
353
- recv_counts[src] = total
354
-
355
- recv_total = sum(recv_counts)
356
- assert recv_total > 0
357
- recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
358
-
359
- #All2All
360
- dist.all_to_all_single(
361
- recv_buf,
362
- send_buf,
363
- output_split_sizes=recv_counts,
364
- input_split_sizes=send_counts,
365
- group=process_group,
366
- )
367
-
368
- # Copy to pre-allocated scattered_u buffer from the received buffer
369
- #
370
- # recv_buf (num ranks = 3, local_rank = 0)
371
- #
372
- # From rank 0 From rank 1 From rank 2
373
- # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
374
- #
375
- # Outer loop:
376
- # rank 0 -> rank 1 -> rank2
377
- #
378
- # Inner loop:
379
- # src(0) : p1_0 -> p2_0 -> p3_0
380
- # src(1) : p4_0
381
- # src(2) : p5_0 -> p6_0
382
-
383
- comm_stream.wait_event(alloc_event)
384
-
385
- off = 0
386
- for src in range(num_ranks):
387
- block = recv_counts[src]
388
- if block == 0:
389
- continue
390
-
391
- inner_off = 0
392
- for p in params:
393
- state = param_to_state[id(p)]
394
- if state.worker_rank != src:
395
- continue
396
- n = numel_for_rank(p, rank, state)
397
- assert n > 0
398
 
399
- flat_local = recv_buf.narrow(0, off + inner_off,
400
- n).view_as(p.to_local())
401
- state.scattered_u.copy_(flat_local)
402
 
403
- state.scatter_event = torch.cuda.Event()
404
- state.scatter_event.record(comm_stream)
405
- inner_off += n
 
 
406
 
407
- assert inner_off == block
408
- off += block
409
 
 
410
 
411
- def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
412
- compute_stream):
413
- """
414
- Update sharded parameter p with the scattered_u.
415
- Only worker_rank frees computed_u.
416
  """
417
- with torch.cuda.stream(compute_stream):
418
- if state.scatter_event is None:
419
- raise RuntimeError("Scatter event must be set before update")
420
- compute_stream.wait_event(state.scatter_event)
421
- u_dtensor = DTensor.from_local(
422
- state.scattered_u,
423
- placements=p.placements,
424
- device_mesh=p.device_mesh,
425
- )
426
-
427
- state.scattered_u = u_dtensor
428
-
429
- if rank == state.worker_rank:
430
- # Free computed_u
431
- state.computed_u = None
432
-
433
- Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
434
- state.scattered_u = None
435
- u_dtensor = None
436
-
437
- scales_full = Muon._compute_scales(
438
- p,
439
- state.qk_clip_state) if state.qk_clip_state is not None else None
440
- if scales_full is not None:
441
- # Have to slice scales_full among dim 0
442
- weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
- state.shard_placements)
444
- ratio = p.shape[0] // scales_full.shape[0]
445
- scales_slice = slice(
446
- None if weight_slices[0].start is None else
447
- weight_slices[0].start // ratio,
448
- None if weight_slices[0].stop is None else
449
- weight_slices[0].stop // ratio,
450
- None,
451
- )
452
-
453
- scales_local = scales_full[scales_slice]
454
- scales_local = DTensor.from_local(
455
- scales_local,
456
- placements=p.placements,
457
- device_mesh=p.device_mesh,
458
- )
459
- Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
460
-
461
-
462
- def default_is_muon(name, x):
463
- skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
464
- return x.ndim >= 2 and not any(key in name for key in skip_keys)
465
-
466
-
467
- def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
468
- muon_params, muon_names = [], []
469
- non_muon_params = []
470
-
471
- for n, p in model.named_parameters():
472
- if not p.requires_grad:
473
  continue
474
- if is_muon_func(n, p):
475
- muon_params.append(p)
476
- muon_names.append(n)
477
- else:
478
- non_muon_params.append(p)
479
-
480
- return [
481
- {
482
- "params": muon_params,
483
- "names": muon_names,
484
- "use_muon": True,
485
- },
486
- {
487
- "params": non_muon_params,
488
- "use_muon": False,
489
- },
490
- ]
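The grouping rule above (2-D and larger parameters go to Muon unless their name marks them as embedding/output layers) can be exercised with dummy parameter objects instead of a real `nn.Module` — the `P` mock below is hypothetical, only `ndim` and `requires_grad` matter:

```python
# Minimal stand-in check of the Muon/AdamW grouping rule, using dummy
# parameter objects (hypothetical mock) instead of a real nn.Module.
class P:
    def __init__(self, ndim, requires_grad=True):
        self.ndim, self.requires_grad = ndim, requires_grad

def default_is_muon(name, x):
    skip = ("embed_tokens", "lm_head", "tok_embeddings", "output")
    return x.ndim >= 2 and not any(k in name for k in skip)

named = [("model.layers.0.attn.wq.weight", P(2)),   # 2-D -> Muon
         ("model.embed_tokens.weight", P(2)),       # embedding -> AdamW
         ("model.norm.weight", P(1))]               # 1-D -> AdamW
muon = [n for n, p in named if p.requires_grad and default_is_muon(n, p)]
print(muon)  # -> ['model.layers.0.attn.wq.weight']
```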
491
-
492
-
493
- def parse_qk_layer(name: str) -> tuple[str | None, int]:
494
- """
495
- Parse a parameter name to check if it is a query/key projection layer
496
- ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
497
-
498
- Returns:
499
- (kind, layer_idx) or (None, -1) if not matched.
500
-
501
- Example:
502
- 'model.3.attn.wq.weight' -> ('wq', 3)
503
- 'model.5.attn.wk.weight' -> ('wk', 5)
504
- 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
505
- 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
506
- 'model.4.attn.v_proj.weight' -> (None, -1)
507
- """
508
- parts = name.split('.')
509
- if len(parts) < 3:
510
- return None, -1
511
-
512
- kind = parts[-2]
513
-
514
- layer_idx = -1
515
- for part in reversed(parts):
516
- if part.isdigit():
517
- layer_idx = int(part)
518
- break
519
 
520
- if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
521
- return kind, layer_idx
 
522
 
523
- return None, -1
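A condensed restatement of the parsing logic (kind is the second-to-last dotted component, layer index the last all-digit component) makes the docstring examples easy to check:

```python
# Condensed restatement of parse_qk_layer for checking the docstring
# examples; the generator-based index lookup is a compact variant of the
# reversed-parts loop above.
def parse_qk_layer(name):
    parts = name.split('.')
    if len(parts) < 3:
        return None, -1
    kind = parts[-2]
    layer_idx = next((int(p) for p in reversed(parts) if p.isdigit()), -1)
    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
        return kind, layer_idx
    return None, -1

assert parse_qk_layer('model.3.attn.wq.weight') == ('wq', 3)
assert parse_qk_layer('model.2.attn.q_proj.weight') == ('q_proj', 2)
assert parse_qk_layer('model.4.attn.v_proj.weight') == (None, -1)
```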
 
524
 
 
525
 
526
- @dataclass
527
- class QKClipInfo:
528
- """Per-parameter dynamic info computed from config + runtime logits."""
529
- kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
- indices: list[int] # which heads to consider for clipping
531
- head_dim: int # from config
532
- threshold: float # from config
533
- logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
@@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer):
554
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
555
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
556
  weight_decay: The weight decay for Muon and AdamW.
557
- {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
558
  adamw_lr: The learning rate for the internal AdamW.
559
  adamw_betas: The betas for the internal AdamW.
560
  adamw_eps: The epsilon for the internal AdamW.
@@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer):
564
  - "q_indices" (list[int]): Indices of query heads to consider.
565
  - "k_indices" (list[int]): Indices of key heads to consider.
566
  - "head_dim" (int): Dimensionality of each attention head.
567
- - "threshold" (float): Threshold value; heads whose QK logits exceed
568
  this value will be scaled down.
569
  Default is:
570
  {
@@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer):
584
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
  For testing purpose only.
586
  small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
587
  """
588
 
589
  def __init__(self,
@@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer):
597
  adamw_eps=1e-8,
598
  none_grad=True,
599
  debug=False,
600
- clip_config={
601
- "q_indices": [],
602
- "k_indices": [],
603
- "head_dim": 128,
604
- "threshold": 100
605
- },
606
  warmup_step=5,
607
  chunk_size=-1,
608
  use_distributed_muon=False,
609
- small_param_numel_threshold=65536):
 
610
  defaults = dict(
611
  lr=lr,
612
  weight_decay=weight_decay,
@@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer):
630
 
631
  super().__init__(params, defaults)
632
 
633
- self.rank = None
634
-
635
- self.comm_stream = torch.cuda.Stream()
636
- self.compute_stream = torch.cuda.Stream()
637
  self.debug = debug
638
- self.clip_config = clip_config
639
  self.warmup_step = warmup_step
640
  self.chunk_size = chunk_size
641
  self.use_distributed_muon = use_distributed_muon
642
  self.small_param_numel_threshold = small_param_numel_threshold
 
643
 
644
  def _calc_flops(self, G, steps):
645
  assert len(G.shape) == 2
@@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer):
649
 
650
  return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
651
 
652
- def adjust_lr_for_muon(self, lr, param_shape):
653
- A, B = param_shape[:2]
654
- # We adjust the learning rate and weight decay based on the size of the parameter matrix
655
- # as described in the paper
656
- adjusted_ratio = 0.2 * math.sqrt(max(A, B))
657
- adjusted_lr = lr * adjusted_ratio
658
- return adjusted_lr
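The adjustment above scales the base learning rate by 0.2·√max(A, B) of the parameter matrix shape, so larger matrices receive a larger effective step:

```python
import math

# Muon lr adjustment: scale by 0.2 * sqrt(max(A, B)) of the parameter
# matrix shape (same rule as adjust_lr_for_muon above).
def adjust_lr_for_muon(lr, param_shape):
    A, B = param_shape[:2]
    return lr * 0.2 * math.sqrt(max(A, B))

print(adjust_lr_for_muon(1e-3, (4096, 1024)))  # 0.2 * 64 * 1e-3 ≈ 0.0128
```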
659
-
660
- def set_rank_once(self, rank):
661
- if self.rank is None:
662
- self.rank = rank
663
- else:
664
- assert self.rank == rank
665
-
666
  def get_shard_mesh(self, p):
667
  """
668
  Get the shard mesh for a parameter p on the given rank.
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer):
673
  shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
674
  p.placements, p.device_mesh)
675
 
676
- # set rank with the local rank in the shard process group
677
- self.set_rank_once(dist.get_rank(group=shard_pg))
678
-
679
  return shard_mesh, shard_pg, shard_placements
680
 
681
  def init_state_and_assign_params(self, names, params, group, qk_logits):
@@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer):
694
  total_flops += flops
695
 
696
  if self.debug:
697
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
698
- flush=True)
699
 
700
  paired = list(zip(names, params))
701
 
@@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer):
724
 
725
  worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
726
  round_robin = (round_robin + 1) % len(shard_mesh_flattened)
727
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
728
 
729
  param_to_state[id(p)] = _muon_state(
730
  worker_rank=worker_rank,
731
  process_group=shard_pg,
732
- shard_mesh=shard_mesh,
733
- shard_placements=shard_placements,
734
  name=n,
735
  qk_clip_state=qk_clip_state,
736
  )
737
 
738
  return param_to_state, ordered_params
739
 
740
- def base(self, names, params, group, lr, weight_decay, momentum,
741
- qk_logits):
742
- # generate weight updates in distributed fashion
743
  for n, p in zip(names, params):
744
  g = p.grad
745
  if g is None:
746
  continue
747
- if g.ndim > 2:
748
- g = g.view(g.size(0), -1)
749
- assert g is not None
750
-
751
- g = self._update_g(p, g, group, momentum)
752
 
753
  u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
754
  steps=group["ns_steps"])
755
 
756
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
757
- Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
758
 
759
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
760
 
761
- scales_full = self._compute_scales(
762
  p, qk_clip_state) if qk_clip_state is not None else None
763
  if scales_full is not None:
764
- Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
765
 
766
  def distributed_muon(
767
  self,
@@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer):
770
  group: dict[str, Any],
771
  lr: float,
772
  weight_decay: float,
773
- momentum: float,
774
  qk_logits: list[torch.Tensor | DTensor] | None,
775
  ):
776
  """ Implementation of Distributed Muon by Liu et al. """
777
 
 
778
  for n, p in zip(names, params):
779
  g = p.grad
780
  if g is None:
781
  continue
782
- if g.ndim > 2:
783
- g = g.view(g.size(0), -1)
784
- assert g is not None
785
-
786
- g = self._update_g(p, g, group, momentum)
787
 
788
  # Gather G
789
  if isinstance(p.data, DTensor):
@@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer):
796
  u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
797
  steps=group["ns_steps"])
798
 
799
- adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
800
- Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
801
 
802
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
803
 
804
- scales_full = self._compute_scales(
805
  p_full, qk_clip_state) if qk_clip_state is not None else None
806
 
807
  if scales_full is not None:
808
- Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
809
 
810
  if isinstance(p.data, DTensor):
811
  ndims = len(p.device_mesh.mesh.shape)
@@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer):
822
 
823
  p.copy_(p_sharded)
824
 
825
- def _update_g(self, p, g, group, momentum):
826
- # calc update
827
- state = self.state[p]
828
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
829
- torch.add(g, buf, alpha=momentum, out=buf)
830
- if group["nesterov"]:
831
- g.add_(buf, alpha=momentum)
832
- return g
833
- return buf
834
-
835
- @staticmethod
836
- def _update_p(p, u, lr, adjusted_lr, weight_decay):
837
- if isinstance(p, torch.nn.Parameter):
838
- # apply weight decay
839
- p.data.mul_(1 - lr * weight_decay)
840
- # apply update
841
- p.data.add_(u, alpha=-adjusted_lr)
842
- else:
843
- p.mul_(1 - lr * weight_decay)
844
- p.add_(u, alpha=-adjusted_lr)
845
-
846
- def get_qk_clip_info(self, n, qk_logits):
847
- if self.clip_config is None:
848
- return None
849
-
850
- head_dim = self.clip_config.get('head_dim')
851
- threshold = self.clip_config.get('threshold')
852
- kind, layer_idx = parse_qk_layer(n)
853
-
854
- logit, indices = None, []
855
- if qk_logits is not None and kind is not None:
856
- logit = qk_logits[layer_idx]
857
- indices_key = 'q_indices' if 'q' in kind else 'k_indices'
858
- indices = self.clip_config.get(indices_key, []) or []
859
-
860
- if isinstance(logit, DTensor):
861
- # In TP settings, qk_logits may be DTensor
862
- # We convert it to full tensor here for simplicity
863
- logit = logit.full_tensor()
864
-
865
- return QKClipInfo(
866
- kind=kind,
867
- indices=indices,
868
- head_dim=head_dim,
869
- threshold=threshold,
870
- logit=logit,
871
- )
872
-
873
- @staticmethod
874
- def _compute_scales(p, qk_clip_state):
875
- kind = qk_clip_state.kind
876
- indices = qk_clip_state.indices
877
- head_dim = qk_clip_state.head_dim
878
- threshold = qk_clip_state.threshold
879
- logit = qk_clip_state.logit
880
-
881
- H_global = p.shape[0] // head_dim
882
- scales_full = torch.ones(H_global, device=p.data.device)
883
- scaling = 0
884
-
885
- for logit_idx, head_idx in enumerate(indices):
886
- v_ele = float(logit[logit_idx])
887
- if v_ele > threshold:
888
- new_scale = math.sqrt(threshold / v_ele)
889
- if new_scale < scales_full[head_idx]:
890
- scales_full[head_idx] = new_scale
891
- logger.info(
892
- f"[{kind}] Head {head_idx} exceeded threshold "
893
- f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
894
- )
895
- scaling += 1
896
-
897
- return scales_full if scaling > 0 else None
898
-
899
- @staticmethod
900
- def _qk_clip(p, scales, head_dim):
901
- if isinstance(p, torch.nn.Parameter):
902
- W = p.data.view(-1, head_dim, p.data.shape[1])
903
- W.mul_(scales.view(-1, 1, 1))
904
- else:
905
- W = p.view(-1, head_dim, p.shape[1])
906
- W.mul_(scales.view(-1, 1, 1))
907
-
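For reference, the per-head scale that the removed `_compute_scales` computes (now in `qk_clip.py`) reduces to `sqrt(threshold / logit)` for heads whose max QK logit exceeds the threshold; the square root appears because both the query and key projections are scaled, so their product restores the full ratio. A plain-Python sketch with an illustrative helper name:

```python
import math

def qk_scales(logits, threshold):
    """Per-head clipping factor: heads over the threshold get
    sqrt(threshold / logit); all other heads are left at 1.0."""
    return [math.sqrt(threshold / v) if v > threshold else 1.0
            for v in logits]

# Head 2's logit of 400 exceeds the threshold of 100, so its
# weights are scaled by sqrt(100 / 400) = 0.5.
print(qk_scales([50.0, 100.0, 400.0], threshold=100.0))
```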
908
- def parallel(self, names, params, group, lr, weight_decay, momentum,
909
- qk_logits):
910
  """
911
  Perform a parallel optimization step using Muon.
912
- """
913
 
914
- for p in params:
915
- g = p.grad
916
- if g is None:
917
- continue
918
- if g.ndim > 2:
919
- g = g.view(g.size(0), -1)
920
 
921
- # Update g in the local rank
922
- g = self._update_g(
923
- p,
924
- g,
925
- group,
926
- momentum=momentum,
927
- )
928
- p.grad = g
929
 
930
  param_to_state, ordered_params = self.init_state_and_assign_params(
931
  names, params, group, qk_logits)
932
 
933
- assert self.rank is not None
934
-
935
- def enqueue_all2all_gather(start_idx, chunk_size):
936
- target_params = ordered_params[start_idx:start_idx + chunk_size]
937
- if target_params:
938
- alloc_event = _alloc_gathered_grad(target_params,
939
- param_to_state, self.rank,
940
- self.compute_stream)
941
- _all2all_gather(target_params, param_to_state, self.rank,
942
- self.comm_stream, group["none_grad"],
943
- alloc_event)
944
-
945
- def enqueue_computes(start_idx, chunk_size):
946
- for p in ordered_params[start_idx:start_idx + chunk_size]:
947
- state = param_to_state[id(p)]
948
- _compute_u(p, state, group["ns_steps"], self.rank,
949
- self.compute_stream)
950
-
951
- def enqueue_all2all_scatter(start_idx, chunk_size):
952
- target_params = ordered_params[start_idx:start_idx + chunk_size]
953
- if target_params:
954
- alloc_event = _alloc_scattered_u(target_params, param_to_state,
955
- self.rank,
956
- self.compute_stream)
957
- _all2all_scatter(target_params, param_to_state, self.rank,
958
- self.comm_stream, alloc_event)
959
-
960
- def enqueue_update_param(start_idx, chunk_size):
961
- for p in ordered_params[start_idx:start_idx + chunk_size]:
962
- state = param_to_state[id(p)]
963
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
964
- _update_param(p, state, lr, adjusted_lr, weight_decay,
965
- self.rank, self.compute_stream)
966
 
967
  if self.chunk_size == -1:
968
  shard_ranks = dist.get_world_size(param_to_state[id(
969
- params[0])].process_group)
970
  chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
971
  elif self.chunk_size > 0:
972
  chunk_size = self.chunk_size
973
  else:
974
  raise ValueError("chunk_size must be -1 or a positive integer.")
975
 
976
- # Wait grad update
977
- self.comm_stream.wait_stream(torch.cuda.current_stream())
978
-
979
- warmup_step = self.warmup_step
980
- for i in range(0, warmup_step):
981
- enqueue_all2all_gather(i * chunk_size, chunk_size)
982
- enqueue_computes(i * chunk_size, chunk_size)
983
-
984
- for i in range(0, len(params) + chunk_size - 1, chunk_size):
985
- enqueue_all2all_scatter(i, chunk_size)
986
- enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
987
- enqueue_update_param(i, chunk_size)
988
- enqueue_computes(i + warmup_step * chunk_size, chunk_size)
989
-
990
- # Wait the last update_param to finish
991
- torch.cuda.current_stream().wait_stream(self.compute_stream)
992
-
993
- @staticmethod
994
- def _fused_adamw(
995
- params: list[torch.Tensor],
996
- grads: list[torch.Tensor],
997
- exp_avgs: list[torch.Tensor],
998
- exp_avg_sqs: list[torch.Tensor],
999
- max_exp_avg_sqs: list[torch.Tensor],
1000
- state_steps: list[torch.Tensor],
1001
- amsgrad: bool,
1002
- beta1: float,
1003
- beta2: float,
1004
- lr: float | torch.Tensor,
1005
- weight_decay: float,
1006
- eps: float,
1007
- maximize: bool,
1008
- ) -> None:
1009
- if not params:
1010
- return
1011
 
1012
- # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1013
- # treating it as a scalar.
1014
- lr_dict: DeviceDict | None = ({
1015
- lr.device: lr
1016
- } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1017
- None)
1018
- grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1019
- [
1020
- params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
1021
- state_steps
1022
- ] # type: ignore[list-item]
1023
- )
1024
- for (device, _), (
1025
- (
1026
- device_params_,
1027
- device_grads_,
1028
- device_exp_avgs_,
1029
- device_exp_avg_sqs_,
1030
- device_max_exp_avg_sqs,
1031
- device_state_steps_,
1032
- ),
1033
- _,
1034
- ) in grouped_tensors.items():
1035
- device_params = cast(list[torch.Tensor], device_params_)
1036
- device_grads = cast(list[torch.Tensor], device_grads_)
1037
- device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
1038
- device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
1039
- device_state_steps = cast(list[torch.Tensor], device_state_steps_)
1040
-
1041
- if lr_dict is not None and device not in lr_dict:
1042
- lr_dict[device] = lr.to(
1043
- device=device,
1044
- non_blocking=True) # type: ignore[union-attr]
1045
- lr = lr_dict[device]
1046
- torch._foreach_add_(device_state_steps, 1)
1047
- func = torch._fused_adamw_
1048
- func(
1049
- device_params,
1050
- device_grads,
1051
- device_exp_avgs,
1052
- device_exp_avg_sqs,
1053
- device_max_exp_avg_sqs, # type: ignore[arg-type]
1054
- device_state_steps,
1055
- amsgrad=amsgrad,
1056
- lr=lr, # type: ignore[arg-type]
1057
- beta1=beta1,
1058
- beta2=beta2,
1059
- weight_decay=weight_decay,
1060
- eps=eps,
1061
- maximize=maximize,
1062
- )
1063
 
1064
  def _step_muon(self, group, qk_logits=None):
1065
  params = group["params"]
@@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer):
1068
  momentum = group["momentum"]
1069
  names = group["names"]
1070
1071
  param_dtensors = []
1072
  name_dtensors = []
1073
 
@@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer):
1083
  group=group,
1084
  lr=lr,
1085
  weight_decay=weight_decay,
1086
- momentum=momentum,
1087
  qk_logits=qk_logits)
1088
  return
1089
 
@@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer):
1119
  # and run parallel Muon on each group.
1120
 
1121
  placement_to_params = defaultdict(lambda: ([], []))
1122
- # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1123
 
1124
  assert len(dtensors) == len(names)
1125
  for p, n in zip(dtensors, names):
@@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer):
1141
  group=group,
1142
  lr=lr,
1143
  weight_decay=weight_decay,
1144
- momentum=momentum,
1145
  qk_logits=qk_logits,
1146
  )
1147
 
@@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer):
1159
  group,
1160
  lr=lr,
1161
  weight_decay=weight_decay,
1162
- momentum=momentum,
1163
  qk_logits=qk_logits,
1164
  )
1165
 
@@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer):
1170
  group,
1171
  lr=lr,
1172
  weight_decay=weight_decay,
1173
- momentum=momentum,
1174
  qk_logits=qk_logits,
1175
  )
1176
 
1177
- def _step_adamw_params(self, params, group):
1178
- params_with_grads = []
1179
- grads = []
1180
- moment1 = []
1181
- moment2 = []
1182
- max_exp_avg_sqs = []
1183
- state_steps = []
1184
- lr = group["lr"]
1185
- beta1, beta2 = group["adamw_betas"]
1186
- eps = group["adamw_eps"]
1187
- weight_decay = group["weight_decay"]
1188
-
1189
- for p in params:
1190
- g = p.grad
1191
- if g is None:
1192
- continue
1193
- state = self.state[p]
1194
- params_with_grads.append(p)
1195
- grads.append(g)
1196
- if "step" not in state:
1197
- state["step"] = (torch.zeros((),
1198
- dtype=torch.float32,
1199
- device=p.device))
1200
- state["moment1"] = torch.zeros_like(g)
1201
- state["moment2"] = torch.zeros_like(g)
1202
- moment1.append(state["moment1"])
1203
- moment2.append(state["moment2"])
1204
- if not isinstance(state["step"], torch.Tensor):
1205
- step_tensor = torch.tensor(state["step"],
1206
- dtype=torch.float32,
1207
- device=p.device)
1208
- else:
1209
- step_tensor = state["step"]
1210
- state_steps.append(step_tensor)
1211
-
1212
- self._fused_adamw(
1213
- params_with_grads,
1214
- grads,
1215
- moment1,
1216
- moment2,
1217
- max_exp_avg_sqs,
1218
- state_steps,
1219
- amsgrad=False,
1220
- beta1=beta1,
1221
- beta2=beta2,
1222
- lr=lr,
1223
- weight_decay=weight_decay,
1224
- eps=eps,
1225
- maximize=False,
1226
- )
1227
-
1228
- def _step_adamw(self, group):
1229
- params = group["params"]
1230
-
1231
- # group params with it's type and placement
1232
- placement_to_params: dict[tuple[Placement | type,
1233
- DeviceMesh | None]] = defaultdict(list)
1234
- for p in params:
1235
- match p:
1236
- case DTensor():
1237
- placement_to_params[tuple([p.placements,
1238
- p.device_mesh])].append(p)
1239
- case torch.Tensor():
1240
- placement_to_params[tuple([torch.Tensor, None])].append(p)
1241
-
1242
- for params in placement_to_params.values():
1243
- self._step_adamw_params(params, group)
1244
-
1245
  @torch.no_grad
1246
  def step(self, closure=None, qk_logits=None):
1247
  """Perform a single optimization step.
@@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer):
1249
  Args:
1250
  closure (Callable, optional): A closure that reevaluates the model
1251
  and returns the loss.
1252
- qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
1253
- to 1D tensors of shape (num_heads,), representing the maximum
1254
- QK logits across all tokens, computed as
1255
  (1 / sqrt(head_dim)) * (Q @ K^T).
1256
  """
1257
  loss = None
@@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer):
1263
  if group["use_muon"]:
1264
  self._step_muon(group, qk_logits=qk_logits)
1265
  else:
1266
- self._step_adamw(group)
1267
 
1268
  return loss
 
1
  import logging
 
2
  import types
3
  from collections import defaultdict
4
+ from typing import Any
 
5
 
6
  import torch
7
  import torch.distributed as dist
8
+ from torch.distributed.tensor import DTensor, Replicate, Shard
9
+ from torch.profiler import record_function
10
+
11
+ from .adamw import step_adamw
12
+ from .async_utils import run_pipeline
13
+ from .core import (_muon_state, adjust_lr_for_muon,
14
+ get_default_muon_param_groups, update_g, update_p)
15
+ from .distributed.utils import (_is_shard, construct_shard_mesh,
16
+ get_slices_of_dtensor)
17
+ from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
18
+ _zeropower_via_newtonschulz5)
19
+ from .pipeline import muon_chunk_pipeline
20
+ from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
 
25
+ def _expand_expert_params(names, params, expert_keys):
26
+ """Expand expert params by splitting on dim 0 (expert dimension).
 
27
 
28
+ Params whose name matches any key in ``expert_keys`` are treated as
29
+ expert-parallel tensors. Their outermost dimension is the expert
30
+ dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D
31
+ ``nn.Parameter`` views so that in-place updates propagate back to
32
+ the original storage.
33
 
34
+ Non-expert params with ``ndim > 2`` trigger an ``AssertionError``;
35
+ if they are expert params, their key must be added to ``expert_keys``.
36
 
37
+ The grad must already be set on each expert param (e.g. after momentum).
38
 
39
+ For DTensor expert params, placements that shard on dim 0 (expert dim)
40
+ are consumed by the split. Non-dim-0 shard placements (e.g. TP) are
41
+ preserved: each 2D slice is wrapped as a DTensor on the corresponding
42
+ submesh so the parallel pipeline handles the TP communication.
 
43
  """
44
+ expanded_names = []
45
+ expanded_params = []
46
+
47
+ for n, p in zip(names, params):
48
+ is_expert = expert_keys and any(key in n for key in expert_keys)
49
+ is_dtensor = isinstance(p.data, DTensor)
50
+
51
+ if not is_expert:
52
+ assert p.data.ndim <= 2, (
53
+ f"Param {n} has ndim={p.data.ndim} but does not match "
54
+ f"expert_keys={expert_keys}. If this is an expert param, "
55
+ f"add its key to expert_keys.")
56
+ expanded_names.append(n)
57
+ expanded_params.append(p)
 
58
  continue
59
 
60
+ g = p.grad
61
+ assert g is not None, (
62
+ f"Expert param {n} must have grad set before expansion")
63
+
64
+ tp_mesh = None
65
+ tp_placements_2d = None
66
+
67
+ if is_dtensor:
68
+ local_data = p.to_local()
69
+ local_grad = g.to_local() if isinstance(g, DTensor) else g
70
+
71
+ # Find non-dim-0 shard placements (e.g. TP sharding).
72
+ # After splitting on dim 0, Shard(k) becomes Shard(k-1).
73
+ tp_dim_indices = []
74
+ tp_placements_2d = []
75
+ for i, pl in enumerate(p.placements):
76
+ if _is_shard(pl) and pl.dim != 0:
77
+ tp_dim_indices.append(i)
78
+ tp_placements_2d.append(Shard(pl.dim - 1))
79
+
80
+ if tp_dim_indices:
81
+ tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i]
82
+ for i in tp_dim_indices)
83
+ if len(tp_dim_names) == 1:
84
+ tp_mesh = p.device_mesh[tp_dim_names[0]]
85
+ else:
86
+ tp_mesh = p.device_mesh[tp_dim_names]
87
+ else:
88
+ local_data = p.data
89
+ local_grad = g
90
+
91
+ # Expand: split dim 0, reshape each slice to 2D.
92
+ num_local_experts = local_data.shape[0]
93
+ for i in range(num_local_experts):
94
+ slice_data = local_data[i]
95
+ slice_grad = local_grad[i]
96
+
97
+ if tp_mesh is not None:
98
+ # Wrap as DTensor on TP submesh so the pipeline handles
99
+ # TP communication (gather/scatter across TP ranks).
100
+ dt_data = DTensor.from_local(slice_data,
101
+ device_mesh=tp_mesh,
102
+ placements=tp_placements_2d)
103
+ dt_grad = DTensor.from_local(slice_grad,
104
+ device_mesh=tp_mesh,
105
+ placements=tp_placements_2d)
106
+ expert_param = torch.nn.Parameter(dt_data, requires_grad=False)
107
+ expert_param.grad = dt_grad
108
+ else:
109
+ expert_param = torch.nn.Parameter(slice_data,
110
+ requires_grad=False)
111
+ expert_param.grad = slice_grad
112
 
113
+ expanded_names.append(f"{n}[{i}]")
114
+ expanded_params.append(expert_param)
115
 
116
+ p.grad = None # allow expert grad storage to be freed after pipeline
117
 
118
+ return expanded_names, expanded_params
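The docstring's view invariant can be sketched outside torch: slicing the expert dimension yields views, so in-place updates on a 2D expert slice write through to the original 3D storage (NumPy stands in for torch here):

```python
import numpy as np

# Sketch of the expert-expansion invariant: indexing dim 0 of an
# (E, out, in) tensor returns views, so an in-place update on one
# per-expert 2D slice mutates the shared underlying buffer.
E, out_dim, in_dim = 4, 3, 2
w = np.zeros((E, out_dim, in_dim))
experts = [w[i] for i in range(E)]  # per-expert 2D views, no copies

experts[2] += 1.0  # in-place update on one expert
print(w[2].sum(), w[0].sum())
```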
119
 
120
 
121
  class Muon(torch.optim.Optimizer):
 
139
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
140
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
141
  weight_decay: The weight decay for Muon and AdamW.
142
+ Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead.
143
  adamw_lr: The learning rate for the internal AdamW.
144
  adamw_betas: The betas for the internal AdamW.
145
  adamw_eps: The epsilon for the internal AdamW.
 
149
  - "q_indices" (list[int]): Indices of query heads to consider.
150
  - "k_indices" (list[int]): Indices of key heads to consider.
151
  - "head_dim" (int): Dimensionality of each attention head.
152
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
153
  this value will be scaled down.
154
  Default is:
155
  {
 
169
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
170
  For testing purpose only.
171
  small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
172
+ expert_keys: List of strings to identify expert-parallel parameters.
173
+ If any key appears in a parameter's name, its outermost
174
+ dimension is treated as the expert dimension and expanded
175
+ into per-expert 2D params for Muon. For example,
176
+ ``expert_keys=["experts"]`` matches any param whose name
177
+ contains "experts". 3D+ params not matched by any key
178
+ will raise an error.
179
  """
180
 
181
  def __init__(self,
 
189
  adamw_eps=1e-8,
190
  none_grad=True,
191
  debug=False,
192
+ clip_config=None,
193
  warmup_step=5,
194
  chunk_size=-1,
195
  use_distributed_muon=False,
196
+ small_param_numel_threshold=65536,
197
+ expert_keys=None):
198
  defaults = dict(
199
  lr=lr,
200
  weight_decay=weight_decay,
 
218
 
219
  super().__init__(params, defaults)
220
 
221
  self.debug = debug
222
+ self.clip_config = clip_config if clip_config is not None else {
223
+ "q_indices": [],
224
+ "k_indices": [],
225
+ "head_dim": 128,
226
+ "threshold": 100,
227
+ }
228
  self.warmup_step = warmup_step
229
  self.chunk_size = chunk_size
230
  self.use_distributed_muon = use_distributed_muon
231
  self.small_param_numel_threshold = small_param_numel_threshold
232
+ self.expert_keys = expert_keys
233
 
234
  def _calc_flops(self, G, steps):
235
  assert len(G.shape) == 2
 
239
 
240
  return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
241
242
  def get_shard_mesh(self, p):
243
  """
244
  Get the shard mesh for a parameter p on the given rank.
 
249
  shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
250
  p.placements, p.device_mesh)
251
 
252
  return shard_mesh, shard_pg, shard_placements
253
 
254
  def init_state_and_assign_params(self, names, params, group, qk_logits):
 
267
  total_flops += flops
268
 
269
  if self.debug:
270
+ logger.debug("Total TFLOPs for Muon: %.2f TFLOPs",
271
+ total_flops / 1e12)
272
 
273
  paired = list(zip(names, params))
274
 
 
297
 
298
  worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
299
  round_robin = (round_robin + 1) % len(shard_mesh_flattened)
300
+ qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
301
+
302
+ # Precompute per-rank indices and numels for all-to-all.
303
+ rank_indices: dict[int, tuple] = {}
304
+ rank_numels: dict[int, int] = {}
305
+ for r in range(num_ranks):
306
+ indices = get_slices_of_dtensor(p, r, shard_mesh,
307
+ shard_placements)
308
+ rank_indices[r] = indices
309
+ numel = 1
310
+ for idx, dim_size in zip(indices, p.shape):
311
+ if isinstance(idx, slice):
312
+ start, stop, step = idx.indices(dim_size)
313
+ numel *= max(0, (stop - start + (step - 1)) // step)
314
+ else:
315
+ numel *= len(idx)
316
+ rank_numels[r] = numel
317
 
318
  param_to_state[id(p)] = _muon_state(
319
  worker_rank=worker_rank,
320
  process_group=shard_pg,
321
+ rank_indices=rank_indices,
322
+ rank_numels=rank_numels,
323
  name=n,
324
  qk_clip_state=qk_clip_state,
325
  )
326
 
327
  return param_to_state, ordered_params
328
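The per-rank numel precomputation added in this hunk counts elements via `slice.indices()` for slices and `len()` for explicit index lists; a standalone plain-Python sketch of that counting logic (the helper name is hypothetical):

```python
def slice_numel(indices, shape):
    """Count elements selected by per-dimension indices.

    Mirrors the rank_numels precomputation: a slice contributes
    ceil((stop - start) / step) elements via slice.indices(),
    an explicit index list contributes its length.
    """
    numel = 1
    for idx, dim_size in zip(indices, shape):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(dim_size)
            numel *= max(0, (stop - start + (step - 1)) // step)
        else:
            numel *= len(idx)
    return numel

# An (8, 6) matrix sharded row-wise across 2 ranks: each rank
# owns a 4x6 = 24-element shard.
print(slice_numel((slice(0, 4), slice(None)), (8, 6)))
```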
 
329
+ def base(self, names, params, group, lr, weight_decay, qk_logits):
330
+ # Momentum is already applied by _step_muon before this method.
 
331
  for n, p in zip(names, params):
332
  g = p.grad
333
  if g is None:
334
  continue
 
335
 
336
  u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
337
  steps=group["ns_steps"])
338
 
339
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
340
+ update_p(p, u, lr, adjusted_lr, weight_decay)
341
 
342
+ qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
343
 
344
+ scales_full = compute_scales(
345
  p, qk_clip_state) if qk_clip_state is not None else None
346
  if scales_full is not None:
347
+ qk_clip(p, scales_full, qk_clip_state.head_dim)
348
 
349
  def distributed_muon(
350
  self,
 
353
  group: dict[str, Any],
354
  lr: float,
355
  weight_decay: float,
 
356
  qk_logits: list[torch.Tensor | DTensor] | None,
357
  ):
358
  """ Implementation of Distributed Muon by Liu et al. """
359
 
360
+ # Momentum is already applied by _step_muon before this method.
361
  for n, p in zip(names, params):
362
  g = p.grad
363
  if g is None:
364
  continue
 
365
 
366
  # Gather G
367
  if isinstance(p.data, DTensor):
 
374
  u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
375
  steps=group["ns_steps"])
376
 
377
+ adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
378
+ update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
379
 
380
+ qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
381
 
382
+ scales_full = compute_scales(
383
  p_full, qk_clip_state) if qk_clip_state is not None else None
384
 
385
  if scales_full is not None:
386
+ qk_clip(p_full, scales_full, qk_clip_state.head_dim)
387
 
388
  if isinstance(p.data, DTensor):
389
  ndims = len(p.device_mesh.mesh.shape)
 
400
 
401
  p.copy_(p_sharded)
402
 
403
+ def parallel(self, names, params, group, lr, weight_decay, qk_logits):
404
  """
405
  Perform a parallel optimization step using Muon.
 
406
 
407
+ Parameters are chunked and each chunk is processed by a
408
+ :func:`muon_chunk_pipeline` generator. :func:`run_pipeline`
409
+ interleaves multiple chunks so that communication and computation
410
+ overlap across chunks (the same overlap previously achieved by the
411
+ warmup + main-loop index scheduling).
412
+ """
413
 
414
+ # Momentum is already applied by _step_muon before this method.
415
 
416
  param_to_state, ordered_params = self.init_state_and_assign_params(
417
  names, params, group, qk_logits)
418
 
419
+ # Compute local rank for this group's shard process group.
420
+ shard_pg = param_to_state[id(ordered_params[0])].process_group
421
+ rank = dist.get_rank(group=shard_pg)
422
 
423
  if self.chunk_size == -1:
424
  shard_ranks = dist.get_world_size(param_to_state[id(
425
+ ordered_params[0])].process_group)
426
  chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
427
  elif self.chunk_size > 0:
428
  chunk_size = self.chunk_size
429
  else:
430
  raise ValueError("chunk_size must be -1 or a positive integer.")
431
 
432
+ def pipelines():
433
+ for start in range(0, len(ordered_params), chunk_size):
434
+ chunk = ordered_params[start:start + chunk_size]
435
+ if chunk:
436
+ yield muon_chunk_pipeline(
437
+ params=chunk,
438
+ param_to_state=param_to_state,
439
+ rank=rank,
440
+ ns_steps=group["ns_steps"],
441
+ lr=lr,
442
+ weight_decay=weight_decay,
443
+ none_grad=group["none_grad"],
444
+ )
445
 
446
+ with record_function("muon::barrier"):
447
+ dist.barrier()
448
+ with record_function("muon::pipeline"):
449
+ run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
450
 
451
  def _step_muon(self, group, qk_logits=None):
452
  params = group["params"]
 
455
  momentum = group["momentum"]
456
  names = group["names"]
457
 
458
+ # Apply momentum to all params before routing/expansion.
459
+ with record_function("muon::momentum"):
460
+ for n, p in zip(names, params):
461
+ g = p.grad
462
+ if g is None:
463
+ continue
464
+ g = update_g(self.state, p, g, group, momentum)
465
+ p.grad = g
466
+
467
+ # Expand expert params by splitting on dim 0.
468
+ names, params = _expand_expert_params(names, params, self.expert_keys)
469
+
470
  param_dtensors = []
471
  name_dtensors = []
472
 
 
482
  group=group,
483
  lr=lr,
484
  weight_decay=weight_decay,
 
485
  qk_logits=qk_logits)
486
  return
487
 
 
517
  # and run parallel Muon on each group.
518
 
519
  placement_to_params = defaultdict(lambda: ([], []))
 
520
 
521
  assert len(dtensors) == len(names)
522
  for p, n in zip(dtensors, names):
 
538
  group=group,
539
  lr=lr,
540
  weight_decay=weight_decay,
 
541
  qk_logits=qk_logits,
542
  )
543
 
 
555
  group,
556
  lr=lr,
557
  weight_decay=weight_decay,
 
558
  qk_logits=qk_logits,
559
  )
560
 
 
565
  group,
566
  lr=lr,
567
  weight_decay=weight_decay,
 
568
  qk_logits=qk_logits,
569
  )
570
571
  @torch.no_grad
572
  def step(self, closure=None, qk_logits=None):
573
  """Perform a single optimization step.
 
575
  Args:
576
  closure (Callable, optional): A closure that reevaluates the model
577
  and returns the loss.
578
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
579
+ to 1D tensors of shape (num_heads,), representing the maximum
580
+ QK logits across all tokens, computed as
581
  (1 / sqrt(head_dim)) * (Q @ K^T).
582
  """
583
  loss = None
 
589
  if group["use_muon"]:
590
  self._step_muon(group, qk_logits=qk_logits)
591
  else:
592
+ step_adamw(self.state, group)
593
 
594
  return loss
build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py ADDED
@@ -0,0 +1,50 @@
+import torch
+
+from .matmul_transpose_triton import matmul_transpose_assign
+
+COMM_DTYPE = torch.bfloat16
+DEFAULT_CHUNK_SIZE_RATIO = 4
+
+
+# This code snippet is a modified version adapted from the following GitHub repositories:
+# https://github.com/KellerJordan/Muon/blob/master/muon.py
+# Muon's Newton-Schulz iteration causes high variance in singular values.
+# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
+@torch.no_grad()
+# matmul_transpose_assign from: https://github.com/nil0x9/flash-muon
+def _zeropower_via_newtonschulz5(G, steps):
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert len(G.shape) == 2
+    assert G.dtype == COMM_DTYPE
+    X = G  # no manual typecast
+
+    if G.size(0) > G.size(1):
+        X = X.T
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm() + 1e-7)
+    buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
+    buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
+    # Perform the NS iterations
+    for a, b, c in [
+        (4.0848, -6.8946, 2.9270),
+        (3.9505, -6.3029, 2.6377),
+        (3.7418, -5.5913, 2.3037),
+        (2.8769, -3.1427, 1.2046),
+        (2.8366, -3.0525, 1.2012),
+    ]:
+        matmul_transpose_assign(X, buf1)
+        matmul_transpose_assign(buf1, buf2)
+        buf1.mul_(b).add_(buf2, alpha=c)
+        X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
+
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X
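The iteration above can be sketched without the Triton `matmul_transpose_assign` kernel by substituting plain matmuls. This is a minimal NumPy sketch under that assumption (the coefficients and the Frobenius-norm pre-scaling are taken from the file above); it is an illustration, not the shipped implementation:

```python
import numpy as np

# Per-iteration quintic coefficients from newton_schulz.py above.
NS_COEFFS = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]


def newton_schulz5(G: np.ndarray) -> np.ndarray:
    """Approximately orthogonalize G: push its singular values toward 1."""
    X = G.T if G.shape[0] > G.shape[1] else G  # work with a wide matrix
    X = X / (np.linalg.norm(X) + 1e-7)  # spectral norm <= Frobenius norm <= 1
    for a, b, c in NS_COEFFS:
        A = X @ X.T          # plays the role of matmul_transpose_assign(X, buf1)
        B = A @ A.T          # plays the role of matmul_transpose_assign(buf1, buf2)
        X = a * X + (b * A + c * B) @ X  # quintic polynomial in the singular values
    return X.T if G.shape[0] > G.shape[1] else X
```

Running this on a random matrix shows the advertised behavior: the output's singular values cluster loosely around 1 rather than matching the exact SVD factor UV^T.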
build/torch210-cxx11-cu128-x86_64-linux/pipeline.py ADDED
@@ -0,0 +1,390 @@
+import logging
+from typing import Generator
+
+import torch
+import torch.distributed as dist
+from torch.distributed.tensor import DTensor
+from torch.profiler import record_function
+
+from .core import _muon_state, adjust_lr_for_muon, update_p
+from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5
+from .qk_clip import compute_scales
+
+logger = logging.getLogger(__name__)
+
+# ======================================================================
+# Stage helpers
+# ======================================================================
+
+
+def _launch_gather(
+    params: list[DTensor],
+    owned_params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    num_ranks: int,
+    process_group: dist.ProcessGroup,
+) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
+    """Allocate gather buffers, build send/recv, and launch async all-to-all.
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_gather``).
+        gathered_grads: ``{id(p): empty_tensor}`` for owned params,
+            ``None`` for non-owned.
+        recv_counts: Per-source-rank element counts.
+    """
+    # Allocate gathered-grad buffers
+    gathered_grads: dict[int, torch.Tensor | None] = {}
+    for p in params:
+        state = param_to_state[id(p)]
+        if rank == state.worker_rank:
+            gathered_grads[id(p)] = torch.empty(p.shape,
+                                                dtype=COMM_DTYPE,
+                                                device="cuda")
+        else:
+            gathered_grads[id(p)] = None
+
+    # Build send buffer
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    for p in params:
+        state = param_to_state[id(p)]
+        dst = state.worker_rank
+        assert dst < num_ranks
+        shard_elems = state.rank_numels[rank]
+        g = p.grad
+        g = g.to_local().to(COMM_DTYPE).contiguous()
+        assert g.numel() == shard_elems
+        per_dst[dst].append(g.view(-1))
+        send_counts[dst] += shard_elems
+
+    assert any(
+        len(v) > 0 for v in
+        per_dst), "At least one destination rank must receive a sharded tensor"
+    per_dst_flat = [t for dst in per_dst for t in dst]
+    send_buf = torch.cat(per_dst_flat, dim=0)
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+            total += state.rank_numels[src]
+        recv_counts[src] = total
+
+    recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    logger.debug(f"send_buf size: {send_buf.numel()}, "
+                 f"recv_buf size: {recv_buf.numel()}, "
+                 f"recv_counts: {recv_counts}, "
+                 f"send_counts: {send_counts}, "
+                 f"process_group: {str(process_group)}")
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, gathered_grads, recv_counts
+
+
+def _complete_gather(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+) -> None:
+    """Reconstruct gathered grads from the recv buffer (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        if recv_counts[src] == 0:
+            continue
+
+        block = recv_counts[src]
+        inner_off = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+
+            indices = state.rank_indices[src]
+
+            shard_view = gathered_grads[id(p)][indices]
+            n = shard_view.numel()
+            assert n > 0
+
+            sg = recv_buf.narrow(0, off + inner_off, n)
+            sg = sg.reshape(shard_view.shape)
+            gathered_grads[id(p)][indices] = sg
+
+            inner_off += n
+        assert inner_off == block
+        off += block
+
+
+def _compute_ns(
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    ns_steps: int,
+) -> dict[int, torch.Tensor | None]:
+    """Run Newton-Schulz orthogonalization on owned parameters.
+
+    Returns:
+        computed_us: ``{id(p): orthogonalized_update}`` for owned params.
+    """
+    computed_us: dict[int, torch.Tensor | None] = {}
+    for p in owned_params:
+        u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
+        gathered_grads[id(p)] = None  # free gathered grad
+        computed_us[id(p)] = u
+    return computed_us
+
+
+def _launch_scatter(
+    params: list[DTensor],
+    owned_params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    num_ranks: int,
+    process_group: dist.ProcessGroup,
+    computed_us: dict[int, torch.Tensor | None],
+) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
+    """Allocate scatter buffers, build send/recv, and launch async all-to-all.
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
+        scattered_us: ``{id(p): empty_local_tensor}`` for all params.
+        recv_counts: Per-source-rank element counts.
+    """
+    # Allocate scattered-u buffers
+    scattered_us: dict[int, torch.Tensor] = {}
+    for p in params:
+        scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
+
+    # Build send buffer (from computed_us on owner ranks)
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    if owned_params:
+        for p in owned_params:
+            state = param_to_state[id(p)]
+
+            assert computed_us[id(p)] is not None
+            u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
+
+            total_sent = 0
+            for dst_rank in range(num_ranks):
+                indices = state.rank_indices[dst_rank]
+                su = u_full[indices].flatten()
+
+                n = su.numel()
+                assert n > 0
+
+                per_dst[dst_rank].append(su)
+                send_counts[dst_rank] += n
+                total_sent += n
+
+            assert total_sent == u_full.numel()
+
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensors"
+        per_dst_flat = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst_flat, dim=0)
+    else:
+        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            total += state.rank_numels[rank]
+        recv_counts[src] = total
+
+    recv_total = sum(recv_counts)
+    assert recv_total > 0
+    recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, scattered_us, recv_counts
+
+
+def _complete_scatter(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+) -> None:
+    """Copy recv buffer into scattered_us (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        block = recv_counts[src]
+        if block == 0:
+            continue
+
+        inner_off = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            n = state.rank_numels[rank]
+            assert n > 0
+
+            flat_local = recv_buf.narrow(0, off + inner_off,
+                                         n).view_as(p.to_local())
+            scattered_us[id(p)].copy_(flat_local)
+
+            inner_off += n
+
+        assert inner_off == block
+        off += block
+
+
+def _update_params(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+    lr: float,
+    weight_decay: float,
+) -> None:
+    """Apply weight decay, Muon update, and optional QK clipping."""
+    for p in params:
+        state = param_to_state[id(p)]
+        u_dtensor = DTensor.from_local(
+            scattered_us[id(p)],
+            placements=p.placements,
+            device_mesh=p.device_mesh,
+        )
+
+        adjusted_lr = adjust_lr_for_muon(lr, p.shape)
+        update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
+
+        # QK clipping - applied directly on the local tensor to
+        # avoid DTensor sharding-propagation issues with _StridedShard.
+        scales_full = compute_scales(
+            p,
+            state.qk_clip_state) if state.qk_clip_state is not None else None
+        if scales_full is not None:
+            ratio = p.shape[0] // scales_full.shape[0]
+            idx0 = state.rank_indices[rank][0]
+            if isinstance(idx0, slice):
+                start = idx0.start or 0
+                idx0 = torch.arange(start,
+                                    idx0.stop,
+                                    device=scales_full.device)
+            row_scales = scales_full[idx0 // ratio]
+            p._local_tensor.mul_(row_scales.view(-1, 1))
+
+
+# ======================================================================
+# Main generator - thin orchestrator that wires stages together.
+# ======================================================================
+
+
+@torch.no_grad()
+def muon_chunk_pipeline(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    ns_steps: int,
+    lr: float,
+    weight_decay: float,
+    none_grad: bool,
+) -> Generator[None, None, None]:
+    """Process one chunk of parameters through the full Muon pipeline.
+
+    Stages: gather -> compute (Newton-Schulz) -> scatter -> update.
+
+    Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
+    that communication and computation overlap across chunks. Async
+    communication is launched via ``async_op=True`` and completed after
+    the yield with ``work.wait()``.
+
+    Overlap happens because :func:`run_pipeline` admits one new chunk
+    per iteration (staggered admission). While chunk *N* does NS
+    compute on the default CUDA stream, chunk *N+1*'s async all-to-all
+    runs concurrently on the NCCL stream; no separate ``comm_stream``
+    is required.
+
+    Yields exactly **2** times:
+
+    1. After launching the async all-to-all gather.
+    2. After launching the async all-to-all scatter.
+    """
+    process_group = param_to_state[id(params[0])].process_group
+    num_ranks = dist.get_world_size(group=process_group)
+    owned_params = [
+        p for p in params if param_to_state[id(p)].worker_rank == rank
+    ]
+
+    # Stages 1-2: launch async gather.
+    with record_function("muon::launch_gather"):
+        work, recv_buf, gathered_grads, recv_counts = _launch_gather(
+            params, owned_params, param_to_state, rank, num_ranks,
+            process_group)
+
+    if none_grad:
+        for p in params:
+            p.grad = None
+
+    yield  # --- YIELD 1: other chunks can launch their gather ---
+
+    with record_function("muon::wait_gather"):
+        work.wait()
+        _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
+                         param_to_state, rank)
+        del recv_buf
+
+    # Stage 3: Newton-Schulz orthogonalization.
+    with record_function("muon::newton_schulz"):
+        computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
+        gathered_grads.clear()
+
+    # Stages 4-5: launch async scatter.
+    with record_function("muon::launch_scatter"):
+        work, recv_buf, scattered_us, recv_counts = _launch_scatter(
+            params, owned_params, param_to_state, rank, num_ranks,
+            process_group, computed_us)
+        computed_us.clear()
+
+    yield  # --- YIELD 2: other chunks can launch their scatter ---
+
+    with record_function("muon::wait_scatter"):
+        work.wait()
+        _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
+                          scattered_us)
+        del recv_buf
+
+    # Stage 6: apply parameter updates.
+    with record_function("muon::update_params"):
+        _update_params(params, param_to_state, rank, scattered_us, lr,
+                       weight_decay)
+        scattered_us.clear()
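The flat-buffer bookkeeping behind `_launch_gather`/`_complete_gather` can be shown without torch or NCCL. In this hypothetical single-process sketch, plain Python lists stand in for CUDA tensors: `pack` concatenates per-destination shards into one send buffer plus `send_counts` (the `input_split_sizes` of the all-to-all), and `unpack` walks `recv_counts` offsets to recover one contiguous block per source rank:

```python
def pack(per_dst: list[list[list[float]]]) -> tuple[list[float], list[int]]:
    """Flatten per-destination shard lists into (send_buf, send_counts)."""
    send_buf: list[float] = []
    send_counts: list[int] = []
    for shards in per_dst:
        n = 0
        for shard in shards:
            send_buf.extend(shard)  # mirrors torch.cat(per_dst_flat)
            n += len(shard)
        send_counts.append(n)
    return send_buf, send_counts


def unpack(recv_buf: list[float], recv_counts: list[int]) -> list[list[float]]:
    """Recover one contiguous block per source rank from the flat buffer."""
    blocks, off = [], 0
    for n in recv_counts:
        blocks.append(recv_buf[off:off + n])  # mirrors recv_buf.narrow(0, off, n)
        off += n
    assert off == len(recv_buf)
    return blocks
```

The real pipeline does the same arithmetic, except the "exchange" in between is `dist.all_to_all_single` with these counts as its split sizes.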
build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py ADDED
@@ -0,0 +1,129 @@
+import logging
+import math
+from dataclasses import dataclass
+
+import torch
+from torch.distributed.tensor import DTensor
+
+logger = logging.getLogger(__name__)
+
+
+def parse_qk_layer(name: str) -> tuple[str | None, int]:
+    """
+    Parse a parameter name to check if it is a query/key projection layer
+    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+
+    Returns:
+        (kind, layer_idx) or (None, -1) if not matched.
+
+    Example:
+        'model.3.attn.wq.weight'     -> ('wq', 3)
+        'model.5.attn.wk.weight'     -> ('wk', 5)
+        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
+        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.4.attn.v_proj.weight' -> (None, -1)
+    """
+    parts = name.split('.')
+    if len(parts) < 3:
+        return None, -1
+
+    kind = parts[-2]
+
+    layer_idx = -1
+    for part in reversed(parts):
+        if part.isdigit():
+            layer_idx = int(part)
+            break
+
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+        return kind, layer_idx
+
+    return None, -1
+
+
+@dataclass
+class QKClipInfo:
+    """Per-parameter dynamic info computed from config + runtime logits."""
+    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    indices: list[int]  # which heads to consider for clipping
+    head_dim: int  # from config
+    threshold: float  # from config
+    logit: torch.Tensor | None
+
+
+def get_qk_clip_info(clip_config, n, qk_logits):
+    """Extract QK clipping info for a named parameter.
+
+    Args:
+        clip_config: QK clipping configuration dict (or None).
+        n: Parameter name string.
+        qk_logits: Dict mapping layer indices to logit tensors (or None).
+
+    Returns:
+        QKClipInfo instance with clipping configuration for this parameter.
+    """
+    if clip_config is None:
+        return None
+
+    head_dim = clip_config.get('head_dim')
+    threshold = clip_config.get('threshold')
+    kind, layer_idx = parse_qk_layer(n)
+
+    logit, indices = None, []
+    if qk_logits is not None and kind is not None:
+        logit = qk_logits[layer_idx]
+        indices_key = 'q_indices' if 'q' in kind else 'k_indices'
+        indices = clip_config.get(indices_key, []) or []
+
+    if isinstance(logit, DTensor):
+        # In TP settings, qk_logits may be DTensor.
+        # We convert it to a full tensor here for simplicity.
+        logit = logit.full_tensor()
+
+    return QKClipInfo(
+        kind=kind,
+        indices=indices,
+        head_dim=head_dim,
+        threshold=threshold,
+        logit=logit,
+    )
+
+
+def compute_scales(p, qk_clip_state):
+    """Compute per-head scaling factors for QK clipping.
+
+    Returns scales tensor if any head exceeds threshold, else None.
+    """
+    kind = qk_clip_state.kind
+    indices = qk_clip_state.indices
+    head_dim = qk_clip_state.head_dim
+    threshold = qk_clip_state.threshold
+    logit = qk_clip_state.logit
+
+    H_global = p.shape[0] // head_dim
+    scales_full = torch.ones(H_global, device=p.data.device)
+    scaling = 0
+
+    for logit_idx, head_idx in enumerate(indices):
+        v_ele = float(logit[logit_idx])
+        if v_ele > threshold:
+            new_scale = math.sqrt(threshold / v_ele)
+            if new_scale < scales_full[head_idx]:
+                scales_full[head_idx] = new_scale
+                logger.info(
+                    f"[{kind}] Head {head_idx} exceeded threshold "
+                    f"(value={v_ele:.4f}, threshold={threshold:.4f}) "
+                    f"-> applying scale={new_scale:.4f}")
+                scaling += 1
+
+    return scales_full if scaling > 0 else None
+
+
+def qk_clip(p, scales, head_dim):
+    """Apply per-head scaling to a Q/K projection weight matrix."""
+    if isinstance(p, torch.nn.Parameter):
+        W = p.data.view(-1, head_dim, p.data.shape[1])
+    else:
+        W = p.view(-1, head_dim, p.shape[1])
+    W.mul_(scales.view(-1, 1, 1))
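The name parsing and scale arithmetic in this file are pure Python and easy to check in isolation. Below is a sketch that reimplements `parse_qk_layer` as defined above, plus a small helper (`head_scales`, a name introduced here for illustration) applying the sqrt(threshold / logit) rule from `compute_scales` to a plain list of per-head logits:

```python
import math


def parse_qk_layer(name: str):
    """Return (kind, layer_idx) for Q/K projection weights, else (None, -1)."""
    parts = name.split('.')
    if len(parts) < 3:
        return None, -1
    kind = parts[-2]
    layer_idx = -1
    for part in reversed(parts):  # last numeric component is the layer index
        if part.isdigit():
            layer_idx = int(part)
            break
    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
        return kind, layer_idx
    return None, -1


def head_scales(logits: list[float], threshold: float) -> list[float]:
    """1.0 for well-behaved heads, sqrt(threshold / logit) for loud ones."""
    return [1.0 if v <= threshold else math.sqrt(threshold / v) for v in logits]
```

A head whose max QK logit is twice the threshold is scaled by sqrt(1/2), so after the weight is multiplied in, the logit (quadratic in the scaled factor when applied to both Q and K, linear when applied to one side) is pulled back toward the threshold.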
build/torch210-cxx11-cu130-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_06a260a_dirty
-ops = torch.ops._optimizer_06a260a_dirty
+from . import _optimizer_7aef62f_dirty
+ops = torch.ops._optimizer_7aef62f_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_06a260a_dirty::{op_name}"
+    return f"_optimizer_7aef62f_dirty::{op_name}"
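The change above only swaps the build hash in the op namespace. A tiny sketch of what the helper produces (the op name `"fused_step"` below is a hypothetical example, not an op registered by this package):

```python
# Build-specific namespace from the rebuilt shared object above.
NAMESPACE = "_optimizer_7aef62f_dirty"


def add_op_namespace_prefix(op_name: str) -> str:
    """Prefix op by namespace, as _ops.py does for torch.ops lookups."""
    return f"{NAMESPACE}::{op_name}"
```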
build/torch210-cxx11-cu130-x86_64-linux/{_optimizer_06a260a_dirty.abi3.so → _optimizer_7aef62f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:330aaa6cb247ba3b5df7a13ced6ef7eff3e5d7a72a0b88f674f948aeaed66ee2
+oid sha256:b9c7bb12bc030d4959e880a959b39ea07eb03e16175d7cf03829f9860f52525d
 size 2004728
build/torch210-cxx11-cu130-x86_64-linux/adamw.py ADDED
@@ -0,0 +1,154 @@
+from collections import defaultdict
+from typing import cast
+
+import torch
+from torch.distributed.tensor import DTensor
+
+
+def fused_adamw(
+    params: list[torch.Tensor],
+    grads: list[torch.Tensor],
+    exp_avgs: list[torch.Tensor],
+    exp_avg_sqs: list[torch.Tensor],
+    max_exp_avg_sqs: list[torch.Tensor],
+    state_steps: list[torch.Tensor],
+    amsgrad: bool,
+    beta1: float,
+    beta2: float,
+    lr: float | torch.Tensor,
+    weight_decay: float,
+    eps: float,
+    maximize: bool,
+) -> None:
+    if not params:
+        return
+
+    # We only shuffle around the lr when it is a Tensor and on CUDA; otherwise
+    # we prefer treating it as a scalar.
+    lr_dict: dict | None = ({
+        lr.device: lr
+    } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None)
+    grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
+         state_steps]  # type: ignore[list-item]
+    )
+    for (device, _), (
+        (
+            device_params_,
+            device_grads_,
+            device_exp_avgs_,
+            device_exp_avg_sqs_,
+            device_max_exp_avg_sqs,
+            device_state_steps_,
+        ),
+        _,
+    ) in grouped_tensors.items():
+        device_params = cast(list[torch.Tensor], device_params_)
+        device_grads = cast(list[torch.Tensor], device_grads_)
+        device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
+        device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
+        device_state_steps = cast(list[torch.Tensor], device_state_steps_)
+
+        if lr_dict is not None and device not in lr_dict:
+            lr_dict[device] = lr.to(
+                device=device, non_blocking=True)  # type: ignore[union-attr]
+            lr = lr_dict[device]
+        torch._foreach_add_(device_state_steps, 1)
+        func = torch._fused_adamw_
+        func(
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_max_exp_avg_sqs,  # type: ignore[arg-type]
+            device_state_steps,
+            amsgrad=amsgrad,
+            lr=lr,  # type: ignore[arg-type]
+            beta1=beta1,
+            beta2=beta2,
+            weight_decay=weight_decay,
+            eps=eps,
+            maximize=maximize,
+        )
+
+
+def step_adamw_params(optimizer_state, params, group):
+    """Run fused AdamW on a list of parameters sharing the same placement.
+
+    Args:
+        optimizer_state: The optimizer's state dict (self.state in Muon).
+        params: List of parameters to update.
+        group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
+    """
+    params_with_grads = []
+    grads = []
+    moment1 = []
+    moment2 = []
+    max_exp_avg_sqs = []
+    state_steps = []
+    lr = group["lr"]
+    beta1, beta2 = group["adamw_betas"]
+    eps = group["adamw_eps"]
+    weight_decay = group["weight_decay"]
+
+    for p in params:
+        g = p.grad
+        if g is None:
+            continue
+        state = optimizer_state[p]
+        params_with_grads.append(p)
+        grads.append(g)
+        if "step" not in state:
+            state["step"] = (torch.zeros((),
+                                         dtype=torch.float32,
+                                         device=p.device))
+            state["moment1"] = torch.zeros_like(g)
+            state["moment2"] = torch.zeros_like(g)
+        moment1.append(state["moment1"])
+        moment2.append(state["moment2"])
+        if not isinstance(state["step"], torch.Tensor):
+            step_tensor = torch.tensor(state["step"],
+                                       dtype=torch.float32,
+                                       device=p.device)
+        else:
+            step_tensor = state["step"]
+        state_steps.append(step_tensor)
+
+    fused_adamw(
+        params_with_grads,
+        grads,
+        moment1,
+        moment2,
+        max_exp_avg_sqs,
+        state_steps,
+        amsgrad=False,
+        beta1=beta1,
+        beta2=beta2,
+        lr=lr,
+        weight_decay=weight_decay,
+        eps=eps,
+        maximize=False,
+    )
+
+
+def step_adamw(optimizer_state, group):
+    """Dispatch AdamW step, grouping parameters by type and placement.
+
+    Args:
+        optimizer_state: The optimizer's state dict (self.state in Muon).
+        group: Parameter group dict.
+    """
+    params = group["params"]
+
+    # group params with its type and placement
+    placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list)
+    for p in params:
+        match p:
+            case DTensor():
+                placement_to_params[tuple([p.placements,
+                                           p.device_mesh])].append(p)
+            case torch.Tensor():
+                placement_to_params[tuple([torch.Tensor, None])].append(p)
+
+    for group_params in placement_to_params.values():
+        step_adamw_params(optimizer_state, group_params, group)
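The per-element math that `torch._fused_adamw_` applies can be sketched for a single scalar parameter. This is a hypothetical pure-Python rendering of one decoupled-AdamW step (bias-corrected Adam plus weight decay applied directly to the parameter), not the fused kernel itself:

```python
import math


def adamw_step(p, g, m, v, step, lr, beta1, beta2, eps, weight_decay):
    """One decoupled AdamW step on scalars; returns updated (p, m, v, step)."""
    step += 1
    m = beta1 * m + (1 - beta1) * g        # first moment (exp_avg)
    v = beta2 * v + (1 - beta2) * g * g    # second moment (exp_avg_sq)
    m_hat = m / (1 - beta1 ** step)        # bias correction
    v_hat = v / (1 - beta2 ** step)
    p = p * (1 - lr * weight_decay)        # decoupled weight decay
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v, step
```

With zero decay and eps, the very first step moves the parameter by exactly `lr` in the gradient's direction, since both bias-corrected moments normalize to the gradient magnitude.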
build/torch210-cxx11-cu130-x86_64-linux/async_utils.py ADDED
@@ -0,0 +1,77 @@
+import logging
+from typing import Generator
+
+logger = logging.getLogger(__name__)
+
+
+class _Task:
+    """Internal: wraps a generator, advances one yield at a time."""
+
+    def __init__(self, generator: Generator[None, None, None], index: int):
+        self._generator = generator
+        self._index = index
+        self._steps_completed = 0
+        self.step()  # run to first yield
+
+    def step(self) -> bool:
+        try:
+            next(self._generator)
+            self._steps_completed += 1
+            logger.debug("pipeline[%d] completed stage %d", self._index,
+                         self._steps_completed)
+            return True
+        except StopIteration:
+            logger.debug("pipeline[%d] finished after %d stages", self._index,
+                         self._steps_completed)
+            return False
+
+    def close(self):
+        self._generator.close()
+
+
+def run_pipeline(
+    pipelines: Generator[Generator[None, None, None], None, None],
+    max_concurrent: int,
+) -> None:
+    """Run generator-based pipelines with bounded concurrency.
+
+    Each pipeline is a generator that yields at stage boundaries.
+    The runtime interleaves pipelines so communication and computation
+    overlap across chunks.
+    """
+    if max_concurrent <= 0:
+        raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}")
+
+    have_new = True
+    task_index = 0
+    previous_tasks: list[_Task] = []
+
+    try:
+        while have_new or previous_tasks:
+            running_tasks: list[_Task] = []
+
+            # Admit one new pipeline per iteration (staggered admission).
+            # Admitting one at a time ensures that while chunk N does NS
+            # compute on the default stream, chunk N+1's NCCL all-to-all
+            # runs concurrently on the NCCL stream, creating real
+            # communication/computation overlap on the GPU.
+            if have_new and len(previous_tasks) < max_concurrent:
+                try:
+                    gen = next(pipelines)
+                    task = _Task(gen, task_index)
+                    task_index += 1
+                    running_tasks.append(task)
+                except StopIteration:
+                    have_new = False
+
+            # Advance every previously-yielded task by one step.
+            for task in previous_tasks:
+                if task.step():
+                    running_tasks.append(task)
+
+            previous_tasks = running_tasks
+    except BaseException:
+        # Clean up all in-flight generators to release GPU resources.
+        for task in previous_tasks:
+            task.close()
+        raise
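The staggered interleaving this runner produces is easiest to see with toy generators. The sketch below is a simplified re-implementation of the scheduling loop (no logging or cleanup, names like `run_pipeline_toy` and `chunk` are illustrative only): each chunk yields twice, matching the two yields of `muon_chunk_pipeline`, and the runner admits one new chunk per outer iteration while advancing all older ones:

```python
from typing import Generator, Iterator


def run_pipeline_toy(pipelines: Iterator[Generator[None, None, None]],
                     max_concurrent: int) -> None:
    """Simplified scheduling loop: staggered admission, one stage per pass."""
    tasks: list[Generator[None, None, None]] = []
    have_new = True
    while have_new or tasks:
        alive: list[Generator[None, None, None]] = []
        if have_new and len(tasks) < max_concurrent:
            try:
                gen = next(pipelines)
            except StopIteration:
                have_new = False
            else:
                next(gen)  # run the new task to its first yield
                alive.append(gen)
        for gen in tasks:  # advance every older task one stage
            try:
                next(gen)
                alive.append(gen)
            except StopIteration:
                pass
        tasks = alive


log: list[str] = []


def chunk(i: int) -> Generator[None, None, None]:
    log.append(f"c{i}:gather")
    yield
    log.append(f"c{i}:compute+scatter")
    yield
    log.append(f"c{i}:update")


run_pipeline_toy(iter(chunk(i) for i in range(2)), max_concurrent=2)
```

The resulting log alternates between chunks stage by stage, which is exactly what lets chunk 1's communication run while chunk 0 computes.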
build/torch210-cxx11-cu130-x86_64-linux/core.py ADDED
@@ -0,0 +1,116 @@
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from torch.distributed.tensor import DTensor
+
+
+@dataclass
+class _muon_state:
+    worker_rank: int
+    process_group: ProcessGroup
+    rank_indices: dict[int, tuple]  # local_rank -> per-dim indices
+    rank_numels: dict[int, int]  # local_rank -> numel
+    name: str
+    qk_clip_state: torch.Tensor | None = None
+
+
+def update_g(optimizer_state, p, g, group, momentum):
+    """Apply momentum update to gradient.
+
+    Args:
+        optimizer_state: The optimizer's state dict (self.state in Muon).
+        p: Parameter tensor.
+        g: Gradient tensor.
+        group: Parameter group dict.
+        momentum: Momentum coefficient.
+
+    Returns:
+        Momentum-updated gradient tensor.
+    """
+    state = optimizer_state[p]
+    buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
+    torch.add(g, buf, alpha=momentum, out=buf)
+    if group["nesterov"]:
+        g.add_(buf, alpha=momentum)
+        return g
+    return buf
+
+
+def update_p(p, u, lr, adjusted_lr, weight_decay):
+    """Apply weight decay and orthogonalized update to parameter.
+
+    Args:
+        p: Parameter (torch.nn.Parameter or DTensor).
+        u: Orthogonalized update tensor.
+        lr: Base learning rate.
+        adjusted_lr: Size-adjusted learning rate.
+        weight_decay: Weight decay coefficient.
+    """
+    if isinstance(p, torch.nn.Parameter):
+        # apply weight decay
+        p.data.mul_(1 - lr * weight_decay)
+        # apply update
+        p.data.add_(u, alpha=-adjusted_lr)
+    else:
+        p.mul_(1 - lr * weight_decay)
+        p.add_(u, alpha=-adjusted_lr)
+
+
+def adjust_lr_for_muon(lr, param_shape):
+    """Scale learning rate based on parameter matrix dimensions.
+
+    Args:
+        lr: Base learning rate.
+        param_shape: Shape of the parameter tensor.
+
+    Returns:
+        Adjusted learning rate.
+    """
+    A, B = param_shape[:2]
+    # We adjust the learning rate and weight decay based on the size of the
+    # parameter matrix as described in the paper.
+    adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+    adjusted_lr = lr * adjusted_ratio
+    return adjusted_lr
+
+
+def default_is_muon(name, x, expert_keys=None):
+    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
+    if any(key in name for key in skip_keys):
+        return False
+    effective_ndim = x.ndim
+    if expert_keys and any(key in name for key in expert_keys):
+        effective_ndim -= 1
+    return effective_ndim >= 2
+
+
+def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
+    if is_muon_func is None:
+        is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
+
+    muon_params, muon_names = [], []
+    non_muon_params = []
+
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+        if is_muon_func(n, p):
+            muon_params.append(p)
+            muon_names.append(n)
+        else:
+            non_muon_params.append(p)
+
+    return [
+        {
+            "params": muon_params,
+            "names": muon_names,
+            "use_muon": True,
+        },
+        {
+            "params": non_muon_params,
+            "use_muon": False,
+        },
+    ]
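The learning-rate adjustment above depends only on the larger of the first two matrix dimensions. A quick standalone check of the arithmetic (same formula as `adjust_lr_for_muon` in core.py):

```python
import math


def adjust_lr_for_muon(lr: float, param_shape: tuple[int, ...]) -> float:
    """lr * 0.2 * sqrt(max(A, B)) for a parameter of shape (A, B, ...)."""
    A, B = param_shape[:2]
    return lr * 0.2 * math.sqrt(max(A, B))
```

For a 1024x512 weight with base lr 0.02, the adjusted lr is 0.02 * 0.2 * 32 = 0.128; note a 512x1024 weight gets the same value, since only max(A, B) enters.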
build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py CHANGED
@@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard,
                                                       _StridedShard)
 
 
 def get_slices_of_dtensor(
         target: DTensor | torch.Tensor,
         local_rank: int,
         shard_mesh: DeviceMesh,
         shard_placements: tuple[Placement],
-) -> tuple[slice]:
     """
-    Get the slice of local tensor for a given rank from a tensor.
     Args:
-        target (DTensor | torch.Tensor): The target tensor.
-        rank (int): The local rank of the shard group.
-        shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
         shard_placements (tuple[Placement]): The shard placements.
-    """
 
-    slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
 
     # find the global rank of the local rank in the shard mesh
     rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
@@ -34,34 +52,75 @@ def get_slices_of_dtensor(
 
     assert len(rank_coords) == len(shard_placements)
 
     # Caution: Assuming replicate-to-shard of the shard mesh goes with
     # left-to-right sharding. This is ensured by the sorting logic of
     # construct_shard_mesh function.
-    for i, (rank_coord,
-            placement) in enumerate(zip(rank_coords, shard_placements)):
-        assert isinstance(placement, Shard)
 
-        num_ranks = shard_mesh.mesh.shape[i]
 
-        dim = placement.dim
-        dim_size = (slices[dim].stop - slices[dim].start)
 
-        if dim_size % num_ranks != 0:
             raise NotImplementedError(
-                f"Dimension size {dim_size} is not divisible "
-                f"by number of ranks {num_ranks} for shard "
-                f"placement on dim {dim}. (shape: {target.shape})")
-
-        shard_size = dim_size // num_ranks
-
-        start = slices[dim].start + rank_coord * shard_size
-        end = start + shard_size
-
-        assert start < end <= slices[dim].stop
-
-        slices[dim] = slice(start, end)
 
-    return tuple(slices)
 
 
 _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
@@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
 def construct_shard_mesh(
     placements: tuple[Placement],
     mesh: DeviceMesh,
-) -> (DeviceMesh, ProcessGroup, tuple[Placement]):
-    """
-    Construct Shard Mesh and Placements for unsharding.
-    It removes Replicate placements and constructs a new Mesh and ProcessGroup.
-    """
-    my_rank = dist.get_rank()
 
-    assert mesh.mesh.device.type == 'cpu'
 
-    # Copy mesh to avoid modifying the original mesh
-    mesh = mesh.mesh.clone()
-
-    # 1. Sort placements. Replicate first, then Shard by dim ascending.
-
-    # For Shard, strided shard comes after regular shard on the same dim
-    # to preserve left-to-right order of replicate-to-shard.
-    # This is because that strided shard is using stride to represent
-    # more fine-grained sharding on the same dim.
-    # Please check the URL below for _StridedShard.
-    # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
-
-    def placement_sort_key(
-        placement_with_index: tuple[float, Placement]
-    ) -> tuple[int, float, int]:  # (dim, split factor, original index)
-        index, placement = placement_with_index
-        is_replicate = placement.is_replicate()
100
- is_shard = placement.is_shard()
101
- is_partial = placement.is_partial()
102
-
103
- assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
104
- assert not is_partial, "Partial placement is not supported."
105
-
106
- if is_replicate:
107
- return (-1.0, 0, index)
108
- elif is_shard:
109
- if isinstance(placement, _StridedShard):
110
- return (placement.dim, 1 / placement.split_factor, index)
111
- return (placement.dim, 0, index)
112
- else:
113
- raise TypeError(f"Unknown placement type: {type(placement)}")
114
 
115
- placements_with_index: list[tuple[int,
116
- Placement]] = list(enumerate(placements))
117
- placements_with_index = sorted(placements_with_index,
118
- key=placement_sort_key)
119
 
120
- sorted_indices, sorted_placements = zip(*placements_with_index)
 
121
 
122
- # 2. Permute mesh according to sorted placements.
123
- sorted_mesh = mesh.permute(sorted_indices)
 
 
124
 
125
- # 3. Collect list of shard meshes by removing replicate dims
126
- # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
127
- # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
128
- num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
129
 
130
- # merge replicate dims
131
- # shard_meshes became a list of shard meshes with a length of replicate degree
132
- if num_replicates > 0:
133
- sorted_mesh = sorted_mesh.flatten(
134
- 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
136
  else:
137
  shard_meshes = [sorted_mesh]
138
- shard_placements = sorted_placements[num_replicates:]
139
-
140
- # assume all shard placements are different
141
  assert len(shard_placements) == len(set(shard_placements))
142
 
143
- # 4. Construct ProcessGroups
144
- # Caution: all groups should be created in the same order in all processes,
145
- # even though each process only needs its own group.
146
-
147
- # To use tensor as dict key, convert it to tuple
148
- def tensor_to_tuple(t):
149
- if isinstance(t, torch.Tensor):
150
- t = t.tolist()
151
- if isinstance(t, list):
152
- return tuple(tensor_to_tuple(x) for x in t)
153
- return t
154
-
155
- my_shard_mesh_as_tuple = None
156
- for shard_mesh in shard_meshes:
157
- assert isinstance(shard_mesh, torch.Tensor)
158
- shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
159
-
160
- if (my_rank == shard_mesh).any().item():
161
- assert my_shard_mesh_as_tuple is None
162
- my_shard_mesh_as_tuple = shard_mesh_as_tuple
163
-
164
- # update global cache
165
- if shard_mesh_as_tuple not in _ranks_to_dist_cache:
166
- shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
167
- _ranks_to_dist_cache[shard_mesh_as_tuple] = (
168
- DeviceMesh(device_type="cuda", mesh=shard_mesh),
169
- shard_process_group,
170
  )
171
 
172
- my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
173
- my_shard_mesh_as_tuple]
174
-
175
- return my_shard_mesh, my_shard_process_group, shard_placements
 
7
                                                       _StridedShard)
 
 
+def _is_shard(placement: Placement) -> bool:
+    """Check if a placement is a shard type (Shard or _StridedShard).
+
+    In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so
+    ``placement.is_shard()`` returns False for _StridedShard. This helper
+    handles both old and new hierarchies.
+    """
+    return isinstance(placement, (Shard, _StridedShard))
+
+
 def get_slices_of_dtensor(
         target: DTensor | torch.Tensor,
         local_rank: int,
         shard_mesh: DeviceMesh,
         shard_placements: tuple[Placement],
+) -> tuple[slice | torch.Tensor, ...]:
     """
+    Get per-dimension indices for a given rank's shard of the target tensor.
+
+    Uses ``Shard.local_shard_size_and_offset`` and
+    ``_StridedShard.local_shard_size_and_offset`` for correct handling of
+    both contiguous and strided (non-contiguous) sharding.
+
     Args:
+        target (DTensor | torch.Tensor): The target tensor (for its shape).
+        local_rank (int): The local rank within the shard group.
+        shard_mesh (DeviceMesh): The shard mesh (only shard dimensions).
         shard_placements (tuple[Placement]): The shard placements.
+
+    Returns:
+        A tuple of indices (one per tensor dim). Each element is either:
+        - A ``slice`` (for contiguous or unsharded dims)
+        - A 1-D ``torch.LongTensor`` of indices (for strided sharding)
+    """
 
     # find the global rank of the local rank in the shard mesh
     rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
 
     assert len(rank_coords) == len(shard_placements)
 
+    # Track per-shard-dim indices.
+    # Absence of a dim means "not yet sharded on this dim".
+    dim_indices: dict[int, torch.Tensor] = {}
+
     # Caution: Assuming replicate-to-shard of the shard mesh goes with
     # left-to-right sharding. This is ensured by the sorting logic of
     # construct_shard_mesh function.
+    for mesh_dim_idx, (rank_coord, placement) in enumerate(
+            zip(rank_coords, shard_placements)):
+        assert _is_shard(placement)
 
+        num_chunks = shard_mesh.mesh.shape[mesh_dim_idx]
+        shard_dim = placement.dim
 
+        # Current effective size on this dim (may already be sub-sharded)
+        if shard_dim in dim_indices:
+            curr_size = len(dim_indices[shard_dim])
+        else:
+            curr_size = target.size()[shard_dim]
 
+        if curr_size % num_chunks != 0:
             raise NotImplementedError(
+                f"Dimension size {curr_size} is not divisible "
+                f"by number of ranks {num_chunks} for shard "
+                f"placement on dim {shard_dim}. (shape: {target.shape})")
+
+        # Compute indices for this level of sharding
+        if isinstance(placement, _StridedShard):
+            _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
+                placement,
+                curr_size,
+                num_chunks,
+                rank_coord,
+                return_first_offset=False)
+            new_indices = torch.tensor(offsets, dtype=torch.long)
+        else:
+            shard_size, offset = Shard.local_shard_size_and_offset(
+                curr_size, num_chunks, rank_coord)
+            new_indices = torch.arange(offset,
+                                       offset + shard_size,
+                                       dtype=torch.long)
+
+        # Compose with previous indices on this dim
+        if shard_dim in dim_indices:
+            dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]
+        else:
+            dim_indices[shard_dim] = new_indices
 
+    # Build result tuple
+    result: list[slice | torch.Tensor] = []
+    for d in range(len(target.size())):
+        if d not in dim_indices:
+            result.append(slice(None))
+        else:
+            indices = dim_indices[d]
+            # Convert contiguous indices to slice for efficiency
+            if len(indices) > 0:
+                start = indices[0].item()
+                expected = torch.arange(start,
+                                        start + len(indices),
+                                        dtype=torch.long)
+                if torch.equal(indices, expected):
+                    result.append(slice(start, start + len(indices)))
+                else:
+                    result.append(indices)
+            else:
+                result.append(slice(0, 0))
+
+    return tuple(result)
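The index-composition step in the new `get_slices_of_dtensor` (`dim_indices[shard_dim][new_indices]`) can be illustrated without torch or DTensor. The list-based helpers here (`contiguous_shard`, `compose`) are hypothetical stand-ins for the tensor indexing, covering only the evenly divisible contiguous case:

```python
def contiguous_shard(curr_size, num_chunks, coord):
    # Plain Shard: equal contiguous blocks (divisible case only)
    assert curr_size % num_chunks == 0
    size = curr_size // num_chunks
    return list(range(coord * size, coord * size + size))

def compose(outer, inner):
    # Second-level sharding selects positions *within* the first-level shard
    return [outer[i] for i in inner]

# Dim of size 8, sharded twice over 2 ranks each (e.g. [Shard(0), Shard(0)])
outer = contiguous_shard(8, 2, 1)           # rank coord 1 -> [4, 5, 6, 7]
inner = contiguous_shard(len(outer), 2, 0)  # rank coord 0 -> [0, 1]
print(compose(outer, inner))                # global rows [4, 5]
```

The second level of sharding never sees global coordinates; it only re-indexes the first level's output, which is exactly what the composed tensor lookup does.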
 
125
 
 
 _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
 
 def construct_shard_mesh(
     placements: tuple[Placement],
     mesh: DeviceMesh,
+) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
+    """Construct shard sub-mesh and ProcessGroup for all-to-all communication.
+
+    Given a DTensor's placements and device mesh, extracts the "shard group"
+    — the set of ranks that together hold all shards of the same replica —
+    and creates a ProcessGroup for all-to-all among them.
+
+    Steps:
+    1. Sort placements: Replicate first, then Shard by (dim, granularity).
+    2. Permute the mesh tensor to match the sorted order.
+    3. Collapse Replicate dims into a list of shard sub-meshes (one per
+       replica).
+    4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh.
+
+    Example — 8 GPUs, mesh shape (2, 2, 2),
+    placements ``[Shard(0), Replicate, _StridedShard(0)]``::
+
+        Step 1  Sort: [Replicate, _StridedShard(0), Shard(0)]
+                Permutation: [1, 2, 0]
+
+        Step 2  Permute mesh dims by [1, 2, 0]:
+                Original:            Permuted:
+                [[[0,1],[2,3]],      [[[0,4],[1,5]],
+                 [[4,5],[6,7]]]       [[2,6],[3,7]]]
+
+        Step 3  Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
+                sub-mesh 0 = [[0,4],[1,5]]  (replica group 0)
+                sub-mesh 1 = [[2,6],[3,7]]  (replica group 1)
+                shard_placements = (_StridedShard(0), Shard(0))
+
+        Step 4  Rank 0 → ProcessGroup([0,1,4,5])
+                Rank 2 → ProcessGroup([2,3,6,7])
+
+    Returns:
+        ``(shard_mesh, process_group, shard_placements)``
+    """
+    my_rank = dist.get_rank()
+    assert mesh.mesh.device.type == 'cpu'
+
+    # -- Fast path: 1D all-shard mesh → reuse existing PG. ------------------
+    # This avoids a non-collective dist.new_group() call, which would
+    # deadlock when only a subset of ranks call this function (e.g. expert
+    # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
+    if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
+        key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
+        if key not in _ranks_to_dist_cache:
+            _ranks_to_dist_cache[key] = (mesh, mesh.get_group())
+        return (*_ranks_to_dist_cache[key], tuple(placements))
+
+    mesh_tensor = mesh.mesh.clone()
+
+    # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------
+    # _StridedShard comes BEFORE regular Shard on the same dim so that
+    # get_slices_of_dtensor applies the outer sharding first, matching
+    # DTensor's left-to-right (outer-to-inner) composition order.
+    def _sort_key(item):
+        index, placement = item
+        assert not placement.is_partial(), "Partial placement not supported"
+        if placement.is_replicate():
+            return (-1, 0, index)
+        assert _is_shard(placement), f"Unsupported: {type(placement)}"
+        split = (-1 / placement.split_factor if isinstance(
+            placement, _StridedShard) else 0)
+        return (placement.dim, split, index)
+
+    indexed = sorted(enumerate(placements), key=_sort_key)
+    perm, sorted_placements = zip(*indexed)
+
+    # -- Step 2: Permute mesh to match sorted placement order. --------------
+    sorted_mesh = mesh_tensor.permute(perm)
+
+    # -- Step 3: Collapse replicate dims → list of shard sub-meshes. --------
+    # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4)
+    num_rep = sum(1 for p in sorted_placements if p.is_replicate())
+    if num_rep > 0:
+        if num_rep > 1:
+            sorted_mesh = sorted_mesh.flatten(0, num_rep - 1)
         shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
     else:
         shard_meshes = [sorted_mesh]
+    shard_placements = sorted_placements[num_rep:]
 
     assert len(shard_placements) == len(set(shard_placements))
 
+    # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
+    # All ranks must call dist.new_group in the same order, even though each
+    # rank only joins one group.
+    def _cache_key(t: torch.Tensor) -> tuple:
+        return (*t.shape, *t.flatten().tolist())
+
+    my_key = None
+    for sm in shard_meshes:
+        key = _cache_key(sm)
+        if (my_rank == sm).any().item():
+            assert my_key is None, "Rank appears in multiple shard groups"
+            my_key = key
+        if key not in _ranks_to_dist_cache:
+            pg = dist.new_group(sm.flatten().tolist())
+            _ranks_to_dist_cache[key] = (
+                DeviceMesh(device_type="cuda", mesh=sm),
+                pg,
             )
 
+    return (*_ranks_to_dist_cache[my_key], shard_placements)
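The Step 1–4 walk-through in `construct_shard_mesh`'s docstring can be checked with plain-Python index arithmetic (no torch needed). `permute3` below is a hypothetical nested-list reimplementation of `Tensor.permute` for the 3-D case:

```python
import itertools

def permute3(mesh, dims):
    # out[a0][a1][a2] = mesh[b0][b1][b2] where b[dims[k]] = a[k],
    # matching torch.Tensor.permute semantics for a 3-D tensor
    shape = (len(mesh), len(mesh[0]), len(mesh[0][0]))
    out_shape = tuple(shape[d] for d in dims)
    out = [[[None] * out_shape[2] for _ in range(out_shape[1])]
           for _ in range(out_shape[0])]
    for a in itertools.product(*(range(s) for s in out_shape)):
        b = [0, 0, 0]
        for k, d in enumerate(dims):
            b[d] = a[k]
        out[a[0]][a[1]][a[2]] = mesh[b[0]][b[1]][b[2]]
    return out

mesh = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]  # shape (2, 2, 2)
perm = permute3(mesh, (1, 2, 0))             # Replicate dim moves to front
sub_meshes = perm                            # "unbind" along dim 0
groups = [sorted(x for row in sm for x in row) for sm in sub_meshes]
print(groups)  # [[0, 1, 4, 5], [2, 3, 6, 7]]
```

The flattened sub-meshes reproduce the docstring's Step 4 process groups: rank 0 lands in [0, 1, 4, 5] and rank 2 in [2, 3, 6, 7].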
build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py CHANGED
@@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out):
     with torch.cuda.device(d_in.device.index):
         mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
                          d_out.stride(0), d_out.stride(1))
-
-
-def matmul_transpose(d_in):
-    M, _ = d_in.shape
-    d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
-    matmul_transpose_assign(d_in, d_out)
-    return d_out
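The removed `matmul_transpose` wrapper computed `d_in @ d_in.T` into a fresh `(M, M)` buffer via the Triton kernel. The same contract in pure Python (illustrative only; the real kernel runs on GPU):

```python
def matmul_transpose(d_in):
    # d_out[i][j] = sum_k d_in[i][k] * d_in[j][k], i.e. X @ X.T
    M = len(d_in)
    return [[sum(a * b for a, b in zip(d_in[i], d_in[j])) for j in range(M)]
            for i in range(M)]

X = [[1.0, 2.0], [3.0, 4.0]]
print(matmul_transpose(X))  # [[5.0, 11.0], [11.0, 25.0]]
```

The result is always symmetric positive semi-definite, which is what the Newton-Schulz iteration in muon.py relies on.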
build/torch210-cxx11-cu130-x86_64-linux/metadata.json CHANGED
@@ -1 +1,3 @@
-{"python-depends":[]}
+{
+  "python-depends": []
+}
build/torch210-cxx11-cu130-x86_64-linux/muon.py CHANGED
@@ -1,536 +1,121 @@
 import logging
-import math
 import types
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Any, cast
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.tensor import DTensor, Replicate
-from torch.distributed.tensor.placement_types import Placement
-
-from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
-from .matmul_transpose_triton import matmul_transpose_assign
 
 logger = logging.getLogger(__name__)
 
-COMM_DTYPE = torch.bfloat16
-DEFAULT_CHUNK_SIZE_RATIO = 4
-
-
-# This code snippet is a modified version adapted from the following GitHub repositories:
-# https://github.com/KellerJordan/Muon/blob/master/muon.py
-# Muon's Newton–Schulz iteration causes high variance in singular values
-# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
-@torch.no_grad()
-# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
-def _zeropower_via_newtonschulz5(G, steps):
-    """
-    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-    zero even beyond the point where the iteration no longer converges all the way to one everywhere
-    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-    performance at all relative to UV^T, where USV^T = G is the SVD.
-    """
-    assert len(G.shape) == 2
-    assert G.dtype == COMM_DTYPE
-    X = G  # no manual typecast
-
-    if G.size(0) > G.size(1):
-        X = X.T
-    # Ensure spectral norm is at most 1
-    X = X / (X.norm() + 1e-7)
-    buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-    buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-    # Perform the NS iterations
-    for a, b, c in [
-        (4.0848, -6.8946, 2.9270),
-        (3.9505, -6.3029, 2.6377),
-        (3.7418, -5.5913, 2.3037),
-        (2.8769, -3.1427, 1.2046),
-        (2.8366, -3.0525, 1.2012),
-    ]:
-        matmul_transpose_assign(X, buf1)
-        matmul_transpose_assign(buf1, buf2)
-        buf1.mul_(b).add_(buf2, alpha=c)
-        X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
-
-    if G.size(0) > G.size(1):
-        X = X.T
-    return X
-
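For X with SVD USVᵀ, each quintic update `X ← aX + (bA + cA²)X` with `A = XXᵀ` acts on every singular value independently as `s ← a·s + b·s³ + c·s⁵`. Under that (standard) observation, the removed docstring's convergence claim can be checked on a scalar in pure Python, no GPU required:

```python
# The five coefficient triples from _zeropower_via_newtonschulz5
coeffs = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]

def ns_singular_value(s):
    # Apply the five quintic steps to one singular value in (0, 1]
    for a, b, c in coeffs:
        s = a * s + b * s**3 + c * s**5
    return s

for s0 in (0.1, 0.5, 1.0):
    s = ns_singular_value(s0)
    # The docstring promises roughly Uniform(0.5, 1.5), not exactly 1
    assert abs(s - 1.0) < 0.5, (s0, s)
    print(s0, "->", s)
```

Even a tiny starting singular value of 0.1 is pulled close to 1 after only five steps, which is why the iteration orthogonalizes the gradient so cheaply.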
-
-@dataclass
-class _muon_state:
-    # TODO: use Optional
-    worker_rank: int
-    process_group: ProcessGroup
-    shard_mesh: DeviceMesh
-    shard_placements: tuple[Placement, ...]
-    name: str
-    qk_clip_state: torch.Tensor | None = None
-    gathered_grad: torch.Tensor | None = None
-    scattered_u: DTensor | None = None
-    computed_u: torch.Tensor | None = None
-    gather_event: torch.cuda.Event | None = None
-    compute_event: torch.cuda.Event | None = None
-    scatter_event: torch.cuda.Event | None = None
-
-
-def numel_for_rank(
-    param: DTensor,
-    local_rank: int,
-    state: _muon_state,
-) -> int:
-    slices = get_slices_of_dtensor(
-        param,
-        local_rank,
-        state.shard_mesh,
-        state.shard_placements,
-    )
-
-    numel = 1
-    for s, dim in zip(slices, param.shape):
-        start, stop, step = s.indices(dim)
-        length = max(0, (stop - start + (step - 1)) // step)
-        numel *= length
-
-    return numel
-
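The length formula in the removed `numel_for_rank` (`max(0, (stop - start + (step - 1)) // step)`) is the standard ceiling-division slice length. A pure-Python check of the per-dimension product (`numel_from_slices` is a hypothetical name for the sketch):

```python
def numel_from_slices(slices, shape):
    # Product of per-dimension slice lengths, as in the removed helper
    numel = 1
    for s, dim in zip(slices, shape):
        start, stop, step = s.indices(dim)
        numel *= max(0, (stop - start + (step - 1)) // step)
    return numel

# An (8, 8) parameter sharded to rows 0:4 and columns 2:6 -> 4 * 4 elements
print(numel_from_slices((slice(0, 4), slice(2, 6)), (8, 8)))  # 16
```

`slice.indices(dim)` clamps out-of-range bounds and resolves negative indices, so the formula stays correct for empty or partial slices as well.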
-
-@torch.no_grad()
-def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
-    """
-    Pre-allocate gathered_grad buffer on compute_stream
-    before launching all2all gather
-    """
-    with torch.cuda.stream(compute_stream):
-        for p in params:
-            state = param_to_state[id(p)]
-            if rank == state.worker_rank:
-                state.gathered_grad = torch.empty(p.shape,
-                                                  dtype=COMM_DTYPE,
-                                                  device="cuda")
-            else:
-                state.gathered_grad = None
-
-    alloc_event = torch.cuda.Event()
-    alloc_event.record(compute_stream)
-    return alloc_event
-
-
-@torch.no_grad()
-def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
-                    alloc_event):
-    """
-    All2all gathers shards so each owner rank reconstructs its full gradient
-    """
-    with torch.cuda.stream(comm_stream):
-        process_group = param_to_state[id(params[0])].process_group
-        num_ranks = dist.get_world_size(group=process_group)
-
-        # Construct sending buffers
-        per_dst = [[] for _ in range(num_ranks)]
-        send_counts = [0] * num_ranks
-
-        for p in params:
-            state = param_to_state[id(p)]
-            dst = state.worker_rank
-            assert dst < num_ranks
-            shard_elems = numel_for_rank(p, rank, state)
-            g = p.grad
-            g = g.to_local().to(COMM_DTYPE).contiguous()
-            assert g.numel() == shard_elems
-            per_dst[dst].append(g.view(-1))
-            send_counts[dst] += shard_elems
-
-        assert any(
-            len(v) > 0 for v in per_dst
-        ), "At least one destination rank must receive a sharded tensor"
-        # list[list[Tensor]] -> list[Tensor]
-        per_dst = [t for dst in per_dst for t in dst]
-
-        send_buf = torch.cat(per_dst, dim=0)
-
-        owned_params = [
-            p for p in params if param_to_state[id(p)].worker_rank == rank
-        ]
-
-        # Compute receive sizes and allocate receiving buffers
-        recv_counts = [0] * num_ranks
-
-        for src in range(num_ranks):
-            total = 0
-            for p in owned_params:
-                state = param_to_state[id(p)]
-                assert state.worker_rank == rank
-                total += numel_for_rank(p, src, state)
-            recv_counts[src] = total
-
-        recv_total = sum(recv_counts)
-        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-        # All2All
-        logger.debug(f"send_buf size: {send_buf.numel()}, "
-                     f"recv_buf size: {recv_buf.numel()}, "
-                     f"recv_counts: {recv_counts}, "
-                     f"send_counts: {send_counts}, "
-                     f"process_group: {str(process_group)}")
-        dist.all_to_all_single(
-            recv_buf,
-            send_buf,
-            output_split_sizes=recv_counts,
-            input_split_sizes=send_counts,
-            group=process_group,
-        )
-
-        # Reconstructs gathered grad from the received buffer
-        #
-        # recv_buf (num ranks = 3)
-        #
-        #   From rank 0        From rank 1        From rank 2
-        # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
-        #
-        # Outer loop:
-        #   rank 0 -> rank 1 -> rank 2
-        #
-        # Inner loop:
-        #   p1_n -> p2_n -> p3_n
-
-        comm_stream.wait_event(alloc_event)
-
-        off = 0
-        for src in range(num_ranks):
-            if recv_counts[src] == 0:
-                continue
-
-            block = recv_counts[src]
-            inner_off = 0
-            for p in owned_params:
-                state = param_to_state[id(p)]
-                assert state.worker_rank == rank
-
-                # get the slice of the full dtensor corresponding to rank src.
-                slices = get_slices_of_dtensor(state.gathered_grad, src,
-                                               state.shard_mesh,
-                                               state.shard_placements)
-
-                dst = state.gathered_grad[slices]
-                assert dst._base is state.gathered_grad
-
-                n = dst.numel()
-                assert n > 0
-
-                sg = recv_buf.narrow(0, off + inner_off, n)
-                sg = sg.reshape_as(dst)
-                dst.copy_(sg)
-
-                inner_off += n
-            off += block
-
-        for p in params:
-            state = param_to_state[id(p)]
-            if state.worker_rank == rank:
-                state.gather_event = torch.cuda.Event()
-                state.gather_event.record(comm_stream)
-            else:
-                state.gathered_grad = None
-                state.gather_event = None
-            if none_grad:
-                p.grad = None
-
-
-@torch.no_grad()
-def _compute_u(p, state, steps, rank, compute_stream):
-    """
-    On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
-    """
-    with torch.cuda.stream(compute_stream):
-        if rank == state.worker_rank:
-            if state.gather_event is None:
-                raise RuntimeError("Gather event must be set before compute.")
-            compute_stream.wait_event(state.gather_event)
-            u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
-            state.gathered_grad = None
-            state.computed_u = u
-            state.compute_event = torch.cuda.Event()
-            state.compute_event.record()
-        else:
-            state.computed_u = None
-            state.compute_event = None
-
-
-@torch.no_grad()
-def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
-    """
-    Pre-allocate scattered_u buffer on compute_stream
-    before launching all2all scatter
-    """
-    with torch.cuda.stream(compute_stream):
-        for p in params:
-            state = param_to_state[id(p)]
-            state.scattered_u = torch.empty_like(p.to_local(),
-                                                 dtype=COMM_DTYPE)
-
-    alloc_event = torch.cuda.Event()
-    alloc_event.record(compute_stream)
-    return alloc_event
-
-
-def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
-    """
-    All2all scatters full gradients to all ranks
-    """
-    with torch.cuda.stream(comm_stream):
-        process_group = param_to_state[id(params[0])].process_group
-        num_ranks = dist.get_world_size(group=process_group)
-        owned_params = [
-            p for p in params if param_to_state[id(p)].worker_rank == rank
-        ]
-
-        # Construct sending buffer
-        per_dst = [[] for _ in range(num_ranks)]
-        send_counts = [0] * num_ranks
-
-        if owned_params:
-            for p in owned_params:
-                state = param_to_state[id(p)]
-                if state.compute_event is None:
-                    raise RuntimeError(
-                        "Compute event must be set before scatter.")
-                comm_stream.wait_event(state.compute_event)
-                state.gathered_grad = None
-
-                assert state.computed_u is not None
-
-                u_full = state.computed_u.to(COMM_DTYPE).contiguous()
-
-                offset = 0
-                for dst in range(num_ranks):
-                    # get the slice of the full tensor corresponding to rank dst.
-                    slices = get_slices_of_dtensor(u_full, dst,
-                                                   state.shard_mesh,
-                                                   state.shard_placements)
-                    su = u_full[slices].flatten()
-
-                    n = su.numel()
-                    assert n > 0
-
-                    per_dst[dst].append(su)
-                    send_counts[dst] += n
-                    offset += n
-
-                assert offset == u_full.numel()
-
-        lengths = [len(v) for v in per_dst]
-        if all(l > 0 for l in lengths):
-            assert all(
-                l == lengths[0] for l in lengths
-            ), "All destination ranks must have the same number of sharded tensors"
-            # list[list[Tensor]] -> list[Tensor]
-            per_dst = [t for dst in per_dst for t in dst]
-            send_buf = torch.cat(per_dst, dim=0)
-        else:
-            # all_to_all requires participation from all ranks
-            # Even non-owner ranks must join the collective call
-            send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
-
-        # Compute receive sizes and allocate receiving buffers
-        recv_counts = [0] * num_ranks
-
-        for src in range(num_ranks):
-            total = 0
-            for p in params:
-                state = param_to_state[id(p)]
-                if state.worker_rank != src:
-                    continue
-                total += numel_for_rank(p, rank, state)
-            recv_counts[src] = total
-
-        recv_total = sum(recv_counts)
-        assert recv_total > 0
-        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-        # All2All
-        dist.all_to_all_single(
-            recv_buf,
-            send_buf,
-            output_split_sizes=recv_counts,
-            input_split_sizes=send_counts,
-            group=process_group,
-        )
-
-        # Copy to pre-allocated scattered_u buffer from the received buffer
-        #
-        # recv_buf (num ranks = 3, local_rank = 0)
-        #
-        #   From rank 0        From rank 1   From rank 2
-        # | p1_0, p2_0, p3_0 | p4_0        | p5_0, p6_0 |
-        #
-        # Outer loop:
-        #   rank 0 -> rank 1 -> rank 2
-        #
-        # Inner loop:
-        #   src(0) : p1_0 -> p2_0 -> p3_0
-        #   src(1) : p4_0
-        #   src(2) : p5_0 -> p6_0
-
-        comm_stream.wait_event(alloc_event)
-
-        off = 0
-        for src in range(num_ranks):
-            block = recv_counts[src]
-            if block == 0:
-                continue
-
-            inner_off = 0
-            for p in params:
-                state = param_to_state[id(p)]
-                if state.worker_rank != src:
-                    continue
-                n = numel_for_rank(p, rank, state)
-                assert n > 0
-
-                flat_local = recv_buf.narrow(0, off + inner_off,
-                                             n).view_as(p.to_local())
-                state.scattered_u.copy_(flat_local)
-
-                state.scatter_event = torch.cuda.Event()
-                state.scatter_event.record(comm_stream)
-                inner_off += n
-
-            assert inner_off == block
-            off += block
-
-
-def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
-                  compute_stream):
-    """
-    Update sharded parameter p with the scattered_u.
-    Only worker_rank frees computed_u.
-    """
-    with torch.cuda.stream(compute_stream):
-        if state.scatter_event is None:
-            raise RuntimeError("Scatter event must be set before update")
-        compute_stream.wait_event(state.scatter_event)
-        u_dtensor = DTensor.from_local(
-            state.scattered_u,
-            placements=p.placements,
-            device_mesh=p.device_mesh,
-        )
-
-        state.scattered_u = u_dtensor
-
-        if rank == state.worker_rank:
-            # Free computed_u
-            state.computed_u = None
-
-        Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
-        state.scattered_u = None
-        u_dtensor = None
-
-        scales_full = Muon._compute_scales(
-            p,
-            state.qk_clip_state) if state.qk_clip_state is not None else None
-        if scales_full is not None:
-            # Have to slice scales_full along dim 0
-            weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
-                                                  state.shard_placements)
-            ratio = p.shape[0] // scales_full.shape[0]
-            scales_slice = slice(
-                None if weight_slices[0].start is None else
-                weight_slices[0].start // ratio,
-                None if weight_slices[0].stop is None else
-                weight_slices[0].stop // ratio,
-                None,
-            )
-
-            scales_local = scales_full[scales_slice]
-            scales_local = DTensor.from_local(
-                scales_local,
-                placements=p.placements,
-                device_mesh=p.device_mesh,
-            )
-            Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
-
-
-def default_is_muon(name, x):
-    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
-    return x.ndim >= 2 and not any(key in name for key in skip_keys)
-
-
-def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
-    muon_params, muon_names = [], []
-    non_muon_params = []
-
-    for n, p in model.named_parameters():
-        if not p.requires_grad:
             continue
-        if is_muon_func(n, p):
-            muon_params.append(p)
-            muon_names.append(n)
-        else:
-            non_muon_params.append(p)
-
-    return [
-        {
-            "params": muon_params,
-            "names": muon_names,
-            "use_muon": True,
-        },
-        {
-            "params": non_muon_params,
-            "use_muon": False,
-        },
-    ]
-
-def parse_qk_layer(name: str) -> tuple[str | None, int]:
-    """
-    Parse a parameter name to check if it is a query/key projection layer
-    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
-
-    Returns:
-        (kind, layer_idx) or (None, -1) if not matched.
-
-    Example:
-        'model.3.attn.wq.weight' -> ('wq', 3)
-        'model.5.attn.wk.weight' -> ('wk', 5)
-        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
-        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
-        'model.4.attn.v_proj.weight' -> (None, -1)
-    """
-    parts = name.split('.')
-    if len(parts) < 3:
-        return None, -1
-
-    kind = parts[-2]
-
-    layer_idx = -1
-    for part in reversed(parts):
-        if part.isdigit():
-            layer_idx = int(part)
-            break
-
-    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
-        return kind, layer_idx
-
-    return None, -1
-
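The removed `parse_qk_layer` is self-contained Python, so it can be exercised directly against its docstring's examples:

```python
def parse_qk_layer(name):
    # Identify q/k projection weights and recover their layer index
    parts = name.split('.')
    if len(parts) < 3:
        return None, -1
    kind = parts[-2]
    layer_idx = -1
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break
    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
        return kind, layer_idx
    return None, -1

print(parse_qk_layer('model.3.attn.wq.weight'))      # ('wq', 3)
print(parse_qk_layer('model.4.attn.v_proj.weight'))  # (None, -1)
```

Note that the layer index is taken from the last numeric path component, so nested module paths still resolve correctly.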
526
- @dataclass
527
- class QKClipInfo:
528
- """Per-parameter dynamic info computed from config + runtime logits."""
529
- kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
- indices: list[int] # which heads to consider for clipping
531
- head_dim: int # from config
532
- threshold: float # from config
533
- logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
@@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer):
554
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
555
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
556
  weight_decay: The weight decay for Muon and AdamW.
557
- {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
558
  adamw_lr: The learning rate for the internal AdamW.
559
  adamw_betas: The betas for the internal AdamW.
560
  adamw_eps: The epsilon for the internal AdamW.
@@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer):
564
  - "q_indices" (list[int]): Indices of query heads to consider.
565
  - "k_indices" (list[int]): Indices of key heads to consider.
566
  - "head_dim" (int): Dimensionality of each attention head.
567
- - "threshold" (float): Threshold value; heads whose QK logits exceed
568
  this value will be scaled down.
569
  Default is:
570
  {
@@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer):
584
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
  For testing purposes only.
586
  small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
587
  """
588
 
589
  def __init__(self,
@@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer):
597
  adamw_eps=1e-8,
598
  none_grad=True,
599
  debug=False,
600
- clip_config={
601
- "q_indices": [],
602
- "k_indices": [],
603
- "head_dim": 128,
604
- "threshold": 100
605
- },
606
  warmup_step=5,
607
  chunk_size=-1,
608
  use_distributed_muon=False,
609
- small_param_numel_threshold=65536):
 
610
  defaults = dict(
611
  lr=lr,
612
  weight_decay=weight_decay,
@@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer):
630
 
631
  super().__init__(params, defaults)
632
 
633
- self.rank = None
634
-
635
- self.comm_stream = torch.cuda.Stream()
636
- self.compute_stream = torch.cuda.Stream()
637
  self.debug = debug
638
- self.clip_config = clip_config
 
 
 
 
 
639
  self.warmup_step = warmup_step
640
  self.chunk_size = chunk_size
641
  self.use_distributed_muon = use_distributed_muon
642
  self.small_param_numel_threshold = small_param_numel_threshold
 
643
 
644
  def _calc_flops(self, G, steps):
645
  assert len(G.shape) == 2
@@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer):
649
 
650
  return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
651
 
652
- def adjust_lr_for_muon(self, lr, param_shape):
653
- A, B = param_shape[:2]
654
- # We adjust the learning rate and weight decay based on the size of the parameter matrix
655
- # as described in the paper
656
- adjusted_ratio = 0.2 * math.sqrt(max(A, B))
657
- adjusted_lr = lr * adjusted_ratio
658
- return adjusted_lr
659
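The helper removed above computes a shape-dependent learning-rate scale. As a standalone sketch (a hypothetical free function with no optimizer state; the formula matches the removed method):

```python
# Sketch of the removed adjust_lr_for_muon: the Muon update for an
# (A, B) weight matrix scales the base lr by 0.2 * sqrt(max(A, B)).
import math


def adjust_lr_for_muon(lr, param_shape):
    A, B = param_shape[:2]
    return lr * 0.2 * math.sqrt(max(A, B))


adjusted = adjust_lr_for_muon(0.02, (1024, 4096))
# 0.02 * 0.2 * sqrt(4096) = 0.02 * 0.2 * 64 ≈ 0.256
```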
-
660
- def set_rank_once(self, rank):
661
- if self.rank is None:
662
- self.rank = rank
663
- else:
664
- assert self.rank == rank
665
-
666
  def get_shard_mesh(self, p):
667
  """
668
  Get the shard mesh for a parameter p on the given rank.
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer):
673
  shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
674
  p.placements, p.device_mesh)
675
 
676
- # set rank with the local rank in the shard process group
677
- self.set_rank_once(dist.get_rank(group=shard_pg))
678
-
679
  return shard_mesh, shard_pg, shard_placements
680
 
681
  def init_state_and_assign_params(self, names, params, group, qk_logits):
@@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer):
694
  total_flops += flops
695
 
696
  if self.debug:
697
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
698
- flush=True)
699
 
700
  paired = list(zip(names, params))
701
 
@@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer):
724
 
725
  worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
726
  round_robin = (round_robin + 1) % len(shard_mesh_flattened)
727
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
728
 
729
  param_to_state[id(p)] = _muon_state(
730
  worker_rank=worker_rank,
731
  process_group=shard_pg,
732
- shard_mesh=shard_mesh,
733
- shard_placements=shard_placements,
734
  name=n,
735
  qk_clip_state=qk_clip_state,
736
  )
737
 
738
  return param_to_state, ordered_params
739
 
740
- def base(self, names, params, group, lr, weight_decay, momentum,
741
- qk_logits):
742
- # generate weight updates in distributed fashion
743
  for n, p in zip(names, params):
744
  g = p.grad
745
  if g is None:
746
  continue
747
- if g.ndim > 2:
748
- g = g.view(g.size(0), -1)
749
- assert g is not None
750
-
751
- g = self._update_g(p, g, group, momentum)
752
 
753
  u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
754
  steps=group["ns_steps"])
755
 
756
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
757
- Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
758
 
759
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
760
 
761
- scales_full = self._compute_scales(
762
  p, qk_clip_state) if qk_clip_state is not None else None
763
  if scales_full is not None:
764
- Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
765
 
766
  def distributed_muon(
767
  self,
@@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer):
770
  group: dict[str, Any],
771
  lr: float,
772
  weight_decay: float,
773
- momentum: float,
774
  qk_logits: list[torch.Tensor | DTensor] | None,
775
  ):
776
  """ Implementation of Distributed Muon by Liu et al. """
777
 
 
778
  for n, p in zip(names, params):
779
  g = p.grad
780
  if g is None:
781
  continue
782
- if g.ndim > 2:
783
- g = g.view(g.size(0), -1)
784
- assert g is not None
785
-
786
- g = self._update_g(p, g, group, momentum)
787
 
788
  # Gather G
789
  if isinstance(p.data, DTensor):
@@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer):
796
  u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
797
  steps=group["ns_steps"])
798
 
799
- adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
800
- Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
801
 
802
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
803
 
804
- scales_full = self._compute_scales(
805
  p_full, qk_clip_state) if qk_clip_state is not None else None
806
 
807
  if scales_full is not None:
808
- Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
809
 
810
  if isinstance(p.data, DTensor):
811
  ndims = len(p.device_mesh.mesh.shape)
@@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer):
822
 
823
  p.copy_(p_sharded)
824
 
825
- def _update_g(self, p, g, group, momentum):
826
- # calc update
827
- state = self.state[p]
828
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
829
- torch.add(g, buf, alpha=momentum, out=buf)
830
- if group["nesterov"]:
831
- g.add_(buf, alpha=momentum)
832
- return g
833
- return buf
834
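The momentum rule in the removed `_update_g` can be restated as scalar arithmetic (a hedged standalone form; the real helper stores the buffer in optimizer state and operates in-place on tensors):

```python
# Scalar restatement of the removed _update_g momentum rule:
#   buf <- g + momentum * buf
# With Nesterov the returned update is g + momentum * buf,
# otherwise the buffer itself is returned.
def update_g(g, buf, momentum, nesterov):
    buf = g + momentum * buf                       # momentum buffer update
    update = g + momentum * buf if nesterov else buf
    return update, buf


update, buf = update_g(g=1.0, buf=0.0, momentum=0.9, nesterov=True)
# buf = 1.0; Nesterov update = 1 + 0.9 * 1 = 1.9
```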
-
835
- @staticmethod
836
- def _update_p(p, u, lr, adjusted_lr, weight_decay):
837
- if isinstance(p, torch.nn.Parameter):
838
- # apply weight decay
839
- p.data.mul_(1 - lr * weight_decay)
840
- # apply update
841
- p.data.add_(u, alpha=-adjusted_lr)
842
- else:
843
- p.mul_(1 - lr * weight_decay)
844
- p.add_(u, alpha=-adjusted_lr)
845
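The removed `_update_p` applies decoupled weight decay with the base lr while the orthogonalized update uses the shape-adjusted lr. A scalar restatement (hedged sketch, not the in-place tensor code):

```python
# Scalar restatement of the removed _update_p:
# weight decay uses the base lr, the Muon update uses adjusted_lr.
def update_p(p, u, lr, adjusted_lr, weight_decay):
    p = p * (1 - lr * weight_decay)  # decoupled weight decay
    p = p - adjusted_lr * u          # orthogonalized Muon update
    return p


p = update_p(p=1.0, u=1.0, lr=0.1, adjusted_lr=0.5, weight_decay=0.1)
# 1 * (1 - 0.01) - 0.5 = 0.49
```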
-
846
- def get_qk_clip_info(self, n, qk_logits):
847
- if self.clip_config is None:
848
- return None
849
-
850
- head_dim = self.clip_config.get('head_dim')
851
- threshold = self.clip_config.get('threshold')
852
- kind, layer_idx = parse_qk_layer(n)
853
-
854
- logit, indices = None, []
855
- if qk_logits is not None and kind is not None:
856
- logit = qk_logits[layer_idx]
857
- indices_key = 'q_indices' if 'q' in kind else 'k_indices'
858
- indices = self.clip_config.get(indices_key, []) or []
859
-
860
- if isinstance(logit, DTensor):
861
- # In TP settings, qk_logits may be DTensor
862
- # We convert it to full tensor here for simplicity
863
- logit = logit.full_tensor()
864
-
865
- return QKClipInfo(
866
- kind=kind,
867
- indices=indices,
868
- head_dim=head_dim,
869
- threshold=threshold,
870
- logit=logit,
871
- )
872
-
873
- @staticmethod
874
- def _compute_scales(p, qk_clip_state):
875
- kind = qk_clip_state.kind
876
- indices = qk_clip_state.indices
877
- head_dim = qk_clip_state.head_dim
878
- threshold = qk_clip_state.threshold
879
- logit = qk_clip_state.logit
880
-
881
- H_global = p.shape[0] // head_dim
882
- scales_full = torch.ones(H_global, device=p.data.device)
883
- scaling = 0
884
-
885
- for logit_idx, head_idx in enumerate(indices):
886
- v_ele = float(logit[logit_idx])
887
- if v_ele > threshold:
888
- new_scale = math.sqrt(threshold / v_ele)
889
- if new_scale < scales_full[head_idx]:
890
- scales_full[head_idx] = new_scale
891
- logger.info(
892
- f"[{kind}] Head {head_idx} exceeded threshold "
893
- f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
894
- )
895
- scaling += 1
896
-
897
- return scales_full if scaling > 0 else None
898
-
899
- @staticmethod
900
- def _qk_clip(p, scales, head_dim):
901
- if isinstance(p, torch.nn.Parameter):
902
- W = p.data.view(-1, head_dim, p.data.shape[1])
903
- W.mul_(scales.view(-1, 1, 1))
904
- else:
905
- W = p.view(-1, head_dim, p.shape[1])
906
- W.mul_(scales.view(-1, 1, 1))
907
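The clipping rule in the removed `_compute_scales`/`_qk_clip` pair (now moved to `qk_clip.py`) reduces to a one-line scale per head. A pure-Python sketch with hypothetical per-head logits:

```python
# Sketch of the QK-clip scale rule: a head whose max QK logit v exceeds
# the threshold gets scale sqrt(threshold / v). If both the query and
# key projections of that head are scaled (the head appears in both
# q_indices and k_indices), the post-clip logit is v * scale**2, which
# lands exactly at the threshold.
import math

head_logits = [50.0, 400.0, 100.0]  # hypothetical per-head max QK logits
threshold = 100.0

scales = [
    math.sqrt(threshold / v) if v > threshold else 1.0
    for v in head_logits
]
# head 1: sqrt(100 / 400) = 0.5, and 400 * 0.5**2 == 100
```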
-
908
- def parallel(self, names, params, group, lr, weight_decay, momentum,
909
- qk_logits):
910
  """
911
  Perform a parallel optimization step using Muon.
912
- """
913
 
914
- for p in params:
915
- g = p.grad
916
- if g is None:
917
- continue
918
- if g.ndim > 2:
919
- g = g.view(g.size(0), -1)
920
 
921
- # Update g in the local rank
922
- g = self._update_g(
923
- p,
924
- g,
925
- group,
926
- momentum=momentum,
927
- )
928
- p.grad = g
929
 
930
  param_to_state, ordered_params = self.init_state_and_assign_params(
931
  names, params, group, qk_logits)
932
 
933
- assert self.rank is not None
934
-
935
- def enqueue_all2all_gather(start_idx, chunk_size):
936
- target_params = ordered_params[start_idx:start_idx + chunk_size]
937
- if target_params:
938
- alloc_event = _alloc_gathered_grad(target_params,
939
- param_to_state, self.rank,
940
- self.compute_stream)
941
- _all2all_gather(target_params, param_to_state, self.rank,
942
- self.comm_stream, group["none_grad"],
943
- alloc_event)
944
-
945
- def enqueue_computes(start_idx, chunk_size):
946
- for p in ordered_params[start_idx:start_idx + chunk_size]:
947
- state = param_to_state[id(p)]
948
- _compute_u(p, state, group["ns_steps"], self.rank,
949
- self.compute_stream)
950
-
951
- def enqueue_all2all_scatter(start_idx, chunk_size):
952
- target_params = ordered_params[start_idx:start_idx + chunk_size]
953
- if target_params:
954
- alloc_event = _alloc_scattered_u(target_params, param_to_state,
955
- self.rank,
956
- self.compute_stream)
957
- _all2all_scatter(target_params, param_to_state, self.rank,
958
- self.comm_stream, alloc_event)
959
-
960
- def enqueue_update_param(start_idx, chunk_size):
961
- for p in ordered_params[start_idx:start_idx + chunk_size]:
962
- state = param_to_state[id(p)]
963
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
964
- _update_param(p, state, lr, adjusted_lr, weight_decay,
965
- self.rank, self.compute_stream)
966
 
967
  if self.chunk_size == -1:
968
  shard_ranks = dist.get_world_size(param_to_state[id(
969
- params[0])].process_group)
970
  chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
971
  elif self.chunk_size > 0:
972
  chunk_size = self.chunk_size
973
  else:
974
  raise ValueError("chunk_size must be -1 or a positive integer.")
975
 
976
- # Wait grad update
977
- self.comm_stream.wait_stream(torch.cuda.current_stream())
978
-
979
- warmup_step = self.warmup_step
980
- for i in range(0, warmup_step):
981
- enqueue_all2all_gather(i * chunk_size, chunk_size)
982
- enqueue_computes(i * chunk_size, chunk_size)
983
-
984
- for i in range(0, len(params) + chunk_size - 1, chunk_size):
985
- enqueue_all2all_scatter(i, chunk_size)
986
- enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
987
- enqueue_update_param(i, chunk_size)
988
- enqueue_computes(i + warmup_step * chunk_size, chunk_size)
989
-
990
- # Wait the last update_param to finish
991
- torch.cuda.current_stream().wait_stream(self.compute_stream)
992
-
993
- @staticmethod
994
- def _fused_adamw(
995
- params: list[torch.Tensor],
996
- grads: list[torch.Tensor],
997
- exp_avgs: list[torch.Tensor],
998
- exp_avg_sqs: list[torch.Tensor],
999
- max_exp_avg_sqs: list[torch.Tensor],
1000
- state_steps: list[torch.Tensor],
1001
- amsgrad: bool,
1002
- beta1: float,
1003
- beta2: float,
1004
- lr: float | torch.Tensor,
1005
- weight_decay: float,
1006
- eps: float,
1007
- maximize: bool,
1008
- ) -> None:
1009
- if not params:
1010
- return
1011
 
1012
- # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1013
- # treating it as a scalar.
1014
- lr_dict: DeviceDict | None = ({
1015
- lr.device: lr
1016
- } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1017
- None)
1018
- grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1019
- [
1020
- params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
1021
- state_steps
1022
- ] # type: ignore[list-item]
1023
- )
1024
- for (device, _), (
1025
- (
1026
- device_params_,
1027
- device_grads_,
1028
- device_exp_avgs_,
1029
- device_exp_avg_sqs_,
1030
- device_max_exp_avg_sqs,
1031
- device_state_steps_,
1032
- ),
1033
- _,
1034
- ) in grouped_tensors.items():
1035
- device_params = cast(list[torch.Tensor], device_params_)
1036
- device_grads = cast(list[torch.Tensor], device_grads_)
1037
- device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
1038
- device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
1039
- device_state_steps = cast(list[torch.Tensor], device_state_steps_)
1040
-
1041
- if lr_dict is not None and device not in lr_dict:
1042
- lr_dict[device] = lr.to(
1043
- device=device,
1044
- non_blocking=True) # type: ignore[union-attr]
1045
- lr = lr_dict[device]
1046
- torch._foreach_add_(device_state_steps, 1)
1047
- func = torch._fused_adamw_
1048
- func(
1049
- device_params,
1050
- device_grads,
1051
- device_exp_avgs,
1052
- device_exp_avg_sqs,
1053
- device_max_exp_avg_sqs, # type: ignore[arg-type]
1054
- device_state_steps,
1055
- amsgrad=amsgrad,
1056
- lr=lr, # type: ignore[arg-type]
1057
- beta1=beta1,
1058
- beta2=beta2,
1059
- weight_decay=weight_decay,
1060
- eps=eps,
1061
- maximize=maximize,
1062
- )
1063
 
1064
  def _step_muon(self, group, qk_logits=None):
1065
  params = group["params"]
@@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer):
1068
  momentum = group["momentum"]
1069
  names = group["names"]
1070
 
1071
  param_dtensors = []
1072
  name_dtensors = []
1073
 
@@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer):
1083
  group=group,
1084
  lr=lr,
1085
  weight_decay=weight_decay,
1086
- momentum=momentum,
1087
  qk_logits=qk_logits)
1088
  return
1089
 
@@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer):
1119
  # and run parallel Muon on each group.
1120
 
1121
  placement_to_params = defaultdict(lambda: ([], []))
1122
- # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1123
 
1124
  assert len(dtensors) == len(names)
1125
  for p, n in zip(dtensors, names):
@@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer):
1141
  group=group,
1142
  lr=lr,
1143
  weight_decay=weight_decay,
1144
- momentum=momentum,
1145
  qk_logits=qk_logits,
1146
  )
1147
 
@@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer):
1159
  group,
1160
  lr=lr,
1161
  weight_decay=weight_decay,
1162
- momentum=momentum,
1163
  qk_logits=qk_logits,
1164
  )
1165
 
@@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer):
1170
  group,
1171
  lr=lr,
1172
  weight_decay=weight_decay,
1173
- momentum=momentum,
1174
  qk_logits=qk_logits,
1175
  )
1176
 
1177
- def _step_adamw_params(self, params, group):
1178
- params_with_grads = []
1179
- grads = []
1180
- moment1 = []
1181
- moment2 = []
1182
- max_exp_avg_sqs = []
1183
- state_steps = []
1184
- lr = group["lr"]
1185
- beta1, beta2 = group["adamw_betas"]
1186
- eps = group["adamw_eps"]
1187
- weight_decay = group["weight_decay"]
1188
-
1189
- for p in params:
1190
- g = p.grad
1191
- if g is None:
1192
- continue
1193
- state = self.state[p]
1194
- params_with_grads.append(p)
1195
- grads.append(g)
1196
- if "step" not in state:
1197
- state["step"] = (torch.zeros((),
1198
- dtype=torch.float32,
1199
- device=p.device))
1200
- state["moment1"] = torch.zeros_like(g)
1201
- state["moment2"] = torch.zeros_like(g)
1202
- moment1.append(state["moment1"])
1203
- moment2.append(state["moment2"])
1204
- if not isinstance(state["step"], torch.Tensor):
1205
- step_tensor = torch.tensor(state["step"],
1206
- dtype=torch.float32,
1207
- device=p.device)
1208
- else:
1209
- step_tensor = state["step"]
1210
- state_steps.append(step_tensor)
1211
-
1212
- self._fused_adamw(
1213
- params_with_grads,
1214
- grads,
1215
- moment1,
1216
- moment2,
1217
- max_exp_avg_sqs,
1218
- state_steps,
1219
- amsgrad=False,
1220
- beta1=beta1,
1221
- beta2=beta2,
1222
- lr=lr,
1223
- weight_decay=weight_decay,
1224
- eps=eps,
1225
- maximize=False,
1226
- )
1227
-
1228
- def _step_adamw(self, group):
1229
- params = group["params"]
1230
-
1231
- # group params with it's type and placement
1232
- placement_to_params: dict[tuple[Placement | type,
1233
- DeviceMesh | None]] = defaultdict(list)
1234
- for p in params:
1235
- match p:
1236
- case DTensor():
1237
- placement_to_params[tuple([p.placements,
1238
- p.device_mesh])].append(p)
1239
- case torch.Tensor():
1240
- placement_to_params[tuple([torch.Tensor, None])].append(p)
1241
-
1242
- for params in placement_to_params.values():
1243
- self._step_adamw_params(params, group)
1244
-
1245
  @torch.no_grad
1246
  def step(self, closure=None, qk_logits=None):
1247
  """Perform a single optimization step.
@@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer):
1249
  Args:
1250
  closure (Callable, optional): A closure that reevaluates the model
1251
  and returns the loss.
1252
- qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
1253
- to 1D tensors of shape (num_heads,), representing the maximum
1254
- QK logits across all tokens, computed as
1255
  (1 / sqrt(head_dim)) * (Q @ K^T).
1256
  """
1257
  loss = None
@@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer):
1263
  if group["use_muon"]:
1264
  self._step_muon(group, qk_logits=qk_logits)
1265
  else:
1266
- self._step_adamw(group)
1267
 
1268
  return loss
 
1
  import logging
 
2
  import types
3
  from collections import defaultdict
4
+ from typing import Any
 
5
 
6
  import torch
7
  import torch.distributed as dist
8
+ from torch.distributed.tensor import DTensor, Replicate, Shard
9
+ from torch.profiler import record_function
10
+
11
+ from .adamw import step_adamw
12
+ from .async_utils import run_pipeline
13
+ from .core import (_muon_state, adjust_lr_for_muon,
14
+ get_default_muon_param_groups, update_g, update_p)
15
+ from .distributed.utils import (_is_shard, construct_shard_mesh,
16
+ get_slices_of_dtensor)
17
+ from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
18
+ _zeropower_via_newtonschulz5)
19
+ from .pipeline import muon_chunk_pipeline
20
+ from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
 
25
+ def _expand_expert_params(names, params, expert_keys):
26
+ """Expand expert params by splitting on dim 0 (expert dimension).
 
27
 
28
+ Params whose name matches any key in ``expert_keys`` are treated as
29
+ expert-parallel tensors. Their outermost dimension is the expert
30
+ dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D
31
+ ``nn.Parameter`` views so that in-place updates propagate back to
32
+ the original storage.
33
 
34
+ Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` —
35
+ if they are expert params, their key must be added to ``expert_keys``.
36
 
37
+ The grad must already be set on each expert param (e.g. after momentum).
38
 
39
+ For DTensor expert params, placements that shard on dim 0 (expert dim)
40
+ are consumed by the split. Non-dim-0 shard placements (e.g. TP) are
41
+ preserved: each 2D slice is wrapped as a DTensor on the corresponding
42
+ submesh so the parallel pipeline handles the TP communication.
 
43
  """
44
+ expanded_names = []
45
+ expanded_params = []
46
+
47
+ for n, p in zip(names, params):
48
+ is_expert = expert_keys and any(key in n for key in expert_keys)
49
+ is_dtensor = isinstance(p.data, DTensor)
50
+
51
+ if not is_expert:
52
+ assert p.data.ndim <= 2, (
53
+ f"Param {n} has ndim={p.data.ndim} but does not match "
54
+ f"expert_keys={expert_keys}. If this is an expert param, "
55
+ f"add its key to expert_keys.")
56
+ expanded_names.append(n)
57
+ expanded_params.append(p)
 
58
  continue
 
59
 
60
+ g = p.grad
61
+ assert g is not None, (
62
+ f"Expert param {n} must have grad set before expansion")
63
+
64
+ tp_mesh = None
65
+ tp_placements_2d = None
66
+
67
+ if is_dtensor:
68
+ local_data = p.to_local()
69
+ local_grad = g.to_local() if isinstance(g, DTensor) else g
70
+
71
+ # Find non-dim-0 shard placements (e.g. TP sharding).
72
+ # After splitting on dim 0, Shard(k) becomes Shard(k-1).
73
+ tp_dim_indices = []
74
+ tp_placements_2d = []
75
+ for i, pl in enumerate(p.placements):
76
+ if _is_shard(pl) and pl.dim != 0:
77
+ tp_dim_indices.append(i)
78
+ tp_placements_2d.append(Shard(pl.dim - 1))
79
+
80
+ if tp_dim_indices:
81
+ tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i]
82
+ for i in tp_dim_indices)
83
+ if len(tp_dim_names) == 1:
84
+ tp_mesh = p.device_mesh[tp_dim_names[0]]
85
+ else:
86
+ tp_mesh = p.device_mesh[tp_dim_names]
87
+ else:
88
+ local_data = p.data
89
+ local_grad = g
90
+
91
+ # Expand: split dim 0, reshape each slice to 2D.
92
+ num_local_experts = local_data.shape[0]
93
+ for i in range(num_local_experts):
94
+ slice_data = local_data[i]
95
+ slice_grad = local_grad[i]
96
+
97
+ if tp_mesh is not None:
98
+ # Wrap as DTensor on TP submesh so the pipeline handles
99
+ # TP communication (gather/scatter across TP ranks).
100
+ dt_data = DTensor.from_local(slice_data,
101
+ device_mesh=tp_mesh,
102
+ placements=tp_placements_2d)
103
+ dt_grad = DTensor.from_local(slice_grad,
104
+ device_mesh=tp_mesh,
105
+ placements=tp_placements_2d)
106
+ expert_param = torch.nn.Parameter(dt_data, requires_grad=False)
107
+ expert_param.grad = dt_grad
108
+ else:
109
+ expert_param = torch.nn.Parameter(slice_data,
110
+ requires_grad=False)
111
+ expert_param.grad = slice_grad
112
 
113
+ expanded_names.append(f"{n}[{i}]")
114
+ expanded_params.append(expert_param)
115
 
116
+ p.grad = None # allow expert grad storage to be freed after pipeline
117
 
118
+ return expanded_names, expanded_params
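The key property the docstring relies on is that each per-expert 2D `nn.Parameter` is a *view* of the original 3D storage. A minimal sketch of the non-DTensor path (hypothetical shapes, no EP/TP sharding):

```python
# Minimal sketch of the expansion above: an (E, out, in) expert weight
# is split into E per-expert 2D Parameter views, so an in-place update
# on a view writes back into the original storage.
import torch

E, out_f, in_f = 4, 8, 16
w = torch.nn.Parameter(torch.randn(E, out_f, in_f), requires_grad=False)
w.grad = torch.randn_like(w)

expanded = []
for i in range(E):
    p_i = torch.nn.Parameter(w.data[i], requires_grad=False)  # view, not copy
    p_i.grad = w.grad[i]
    expanded.append(p_i)

expanded[0].data.mul_(0.0)  # update one expert in place ...
# ... and the parent 3D storage reflects it: w.data[0] is now all zeros.
```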
 
119
 
120
 
121
  class Muon(torch.optim.Optimizer):
 
139
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
140
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
141
  weight_decay: The weight decay for Muon and AdamW.
142
+ Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead.
143
  adamw_lr: The learning rate for the internal AdamW.
144
  adamw_betas: The betas for the internal AdamW.
145
  adamw_eps: The epsilon for the internal AdamW.
 
149
  - "q_indices" (list[int]): Indices of query heads to consider.
150
  - "k_indices" (list[int]): Indices of key heads to consider.
151
  - "head_dim" (int): Dimensionality of each attention head.
152
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
153
  this value will be scaled down.
154
  Default is:
155
  {
 
169
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
170
  For testing purposes only.
171
  small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
172
+ expert_keys: List of strings to identify expert-parallel parameters.
173
+ If any key appears in a parameter's name, its outermost
174
+ dimension is treated as the expert dimension and expanded
175
+ into per-expert 2D params for Muon. For example,
176
+ ``expert_keys=["experts"]`` matches any param whose name
177
+ contains "experts". 3D+ params not matched by any key
178
+ will raise an error.
179
  """
180
 
181
  def __init__(self,
 
189
  adamw_eps=1e-8,
190
  none_grad=True,
191
  debug=False,
192
+ clip_config=None,
 
 
 
 
 
193
  warmup_step=5,
194
  chunk_size=-1,
195
  use_distributed_muon=False,
196
+ small_param_numel_threshold=65536,
197
+ expert_keys=None):
198
  defaults = dict(
199
  lr=lr,
200
  weight_decay=weight_decay,
 
218
 
219
  super().__init__(params, defaults)
220
 
 
 
 
 
221
  self.debug = debug
222
+ self.clip_config = clip_config if clip_config is not None else {
223
+ "q_indices": [],
224
+ "k_indices": [],
225
+ "head_dim": 128,
226
+ "threshold": 100,
227
+ }
228
  self.warmup_step = warmup_step
229
  self.chunk_size = chunk_size
230
  self.use_distributed_muon = use_distributed_muon
231
  self.small_param_numel_threshold = small_param_numel_threshold
232
+ self.expert_keys = expert_keys
233
 
234
  def _calc_flops(self, G, steps):
235
  assert len(G.shape) == 2
 
239
 
240
  return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
241
 
242
  def get_shard_mesh(self, p):
243
  """
244
  Get the shard mesh for a parameter p on the given rank.
 
249
  shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
250
  p.placements, p.device_mesh)
251
 
 
 
 
252
  return shard_mesh, shard_pg, shard_placements
253
 
254
  def init_state_and_assign_params(self, names, params, group, qk_logits):
 
267
  total_flops += flops
268
 
269
  if self.debug:
270
+ logger.debug("Total compute for Muon: %.2f TFLOPs",
271
+ total_flops / 1e12)
272
 
273
  paired = list(zip(names, params))
274
 
 
297
 
298
  worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
299
  round_robin = (round_robin + 1) % len(shard_mesh_flattened)
300
+ qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
301
+
302
+ # Precompute per-rank indices and numels for all-to-all.
303
+ rank_indices: dict[int, tuple] = {}
304
+ rank_numels: dict[int, int] = {}
305
+ for r in range(num_ranks):
306
+ indices = get_slices_of_dtensor(p, r, shard_mesh,
307
+ shard_placements)
308
+ rank_indices[r] = indices
309
+ numel = 1
310
+ for idx, dim_size in zip(indices, p.shape):
311
+ if isinstance(idx, slice):
312
+ start, stop, step = idx.indices(dim_size)
313
+ numel *= max(0, (stop - start + (step - 1)) // step)
314
+ else:
315
+ numel *= len(idx)
316
+ rank_numels[r] = numel
317
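The per-dimension element count above follows from `slice.indices`, which clamps start/stop/step to the dimension size. A standalone illustration with hypothetical per-rank slices:

```python
# A slice over a dimension contributes ceil((stop - start) / step)
# elements; slice.indices(dim_size) yields the clamped start/stop/step.
shape = (10, 8)                               # hypothetical tensor shape
indices = (slice(2, 9, 2), slice(0, 8, 1))    # hypothetical per-rank slices

numel = 1
for idx, dim_size in zip(indices, shape):
    if isinstance(idx, slice):
        start, stop, step = idx.indices(dim_size)
        numel *= max(0, (stop - start + (step - 1)) // step)
    else:
        numel *= len(idx)

# slice(2, 9, 2) selects rows 2, 4, 6, 8 -> 4 rows; 4 * 8 = 32 elements
```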
 
318
  param_to_state[id(p)] = _muon_state(
319
  worker_rank=worker_rank,
320
  process_group=shard_pg,
321
+ rank_indices=rank_indices,
322
+ rank_numels=rank_numels,
323
  name=n,
324
  qk_clip_state=qk_clip_state,
325
  )
326
 
327
  return param_to_state, ordered_params
328
 
329
+ def base(self, names, params, group, lr, weight_decay, qk_logits):
330
+ # Momentum is already applied by _step_muon before this method.
 
331
  for n, p in zip(names, params):
332
  g = p.grad
333
  if g is None:
334
  continue
 
 
 
 
 
335
 
336
  u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
337
  steps=group["ns_steps"])
338
 
339
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
340
+ update_p(p, u, lr, adjusted_lr, weight_decay)
341
 
342
+ qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
343
 
344
+ scales_full = compute_scales(
345
  p, qk_clip_state) if qk_clip_state is not None else None
346
  if scales_full is not None:
347
+ qk_clip(p, scales_full, qk_clip_state.head_dim)
348
 
349
  def distributed_muon(
350
  self,
 
353
  group: dict[str, Any],
354
  lr: float,
355
  weight_decay: float,
 
356
  qk_logits: list[torch.Tensor | DTensor] | None,
357
  ):
358
  """ Implementation of Distributed Muon by Liu et al. """
359
 
360
+ # Momentum is already applied by _step_muon before this method.
361
  for n, p in zip(names, params):
362
  g = p.grad
363
  if g is None:
364
  continue
 
 
 
 
 
365
 
366
  # Gather G
367
  if isinstance(p.data, DTensor):
 
374
  u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
375
  steps=group["ns_steps"])
376
 
377
+ adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
378
+ update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
379
 
380
+ qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
381
 
382
+ scales_full = compute_scales(
383
  p_full, qk_clip_state) if qk_clip_state is not None else None
384
 
385
  if scales_full is not None:
386
+ qk_clip(p_full, scales_full, qk_clip_state.head_dim)
387
 
388
  if isinstance(p.data, DTensor):
389
  ndims = len(p.device_mesh.mesh.shape)
 
400
 
401
  p.copy_(p_sharded)
402
 
403
+ def parallel(self, names, params, group, lr, weight_decay, qk_logits):
 
404
  """
405
  Perform a parallel optimization step using Muon.
 
406
 
407
+ Parameters are chunked and each chunk is processed by a
408
+ :func:`muon_chunk_pipeline` generator. :func:`run_pipeline`
409
+ interleaves multiple chunks so that communication and computation
410
+ overlap across chunks (the same overlap previously achieved by the
411
+ warmup + main-loop index scheduling).
412
+ """
413
 
414
+ # Momentum is already applied by _step_muon before this method.
 
415
 
416
  param_to_state, ordered_params = self.init_state_and_assign_params(
417
  names, params, group, qk_logits)
418
 
419
+ # Compute local rank for this group's shard process group.
420
+ shard_pg = param_to_state[id(ordered_params[0])].process_group
421
+ rank = dist.get_rank(group=shard_pg)
 
 
422
 
423
  if self.chunk_size == -1:
424
  shard_ranks = dist.get_world_size(param_to_state[id(
425
+ ordered_params[0])].process_group)
426
  chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
427
  elif self.chunk_size > 0:
428
  chunk_size = self.chunk_size
429
  else:
430
  raise ValueError("chunk_size must be -1 or a positive integer.")
431
 
432
+ def pipelines():
433
+ for start in range(0, len(ordered_params), chunk_size):
434
+ chunk = ordered_params[start:start + chunk_size]
435
+ if chunk:
436
+ yield muon_chunk_pipeline(
437
+ params=chunk,
438
+ param_to_state=param_to_state,
439
+ rank=rank,
440
+ ns_steps=group["ns_steps"],
441
+ lr=lr,
442
+ weight_decay=weight_decay,
443
+ none_grad=group["none_grad"],
444
+ )
 
445
 
446
+ with record_function("muon::barrier"):
447
+ dist.barrier()
448
+ with record_function("muon::pipeline"):
449
+ run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
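The real `run_pipeline` lives in `async_utils.py` and is not shown in this diff; the following is only a toy model of the scheduling idea it implements. Each chunk is a generator that yields at its communication boundaries, and a driver keeps up to `max_concurrent` generators in flight so one chunk's compute overlaps another chunk's communication:

```python
# Toy round-robin driver for generator pipelines (hypothetical name
# run_pipeline_sketch to avoid confusion with the real async_utils API).
from collections import deque


def run_pipeline_sketch(pipelines, max_concurrent):
    pipelines = iter(pipelines)
    active = deque()
    while True:
        while len(active) < max_concurrent:     # keep the window full
            nxt = next(pipelines, None)
            if nxt is None:
                break
            active.append(nxt)
        if not active:
            break
        gen = active.popleft()
        try:
            next(gen)                           # advance one stage
            active.append(gen)                  # not finished: requeue
        except StopIteration:
            pass                                # chunk complete


log = []


def chunk(i):
    log.append(f"gather {i}"); yield            # comm stage
    log.append(f"compute {i}"); yield           # compute stage
    log.append(f"scatter {i}")                  # comm stage


run_pipeline_sketch((chunk(i) for i in range(3)), max_concurrent=2)
# Stages of chunks 0 and 1 interleave before chunk 2 starts.
```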
 
450
 
451
  def _step_muon(self, group, qk_logits=None):
452
  params = group["params"]
 
455
  momentum = group["momentum"]
456
  names = group["names"]
457
 
458
+ # Apply momentum to all params before routing/expansion.
459
+ with record_function("muon::momentum"):
460
+ for n, p in zip(names, params):
461
+ g = p.grad
462
+ if g is None:
463
+ continue
464
+ g = update_g(self.state, p, g, group, momentum)
465
+ p.grad = g
466
+
467
+ # Expand expert params by splitting on dim 0.
468
+ names, params = _expand_expert_params(names, params, self.expert_keys)
469
+
470
  param_dtensors = []
471
  name_dtensors = []
472
 
 
482
  group=group,
483
  lr=lr,
484
  weight_decay=weight_decay,
 
485
  qk_logits=qk_logits)
486
  return
487
 
 
517
  # and run parallel Muon on each group.
518
 
519
  placement_to_params = defaultdict(lambda: ([], []))
 
520
 
521
  assert len(dtensors) == len(names)
522
  for p, n in zip(dtensors, names):
 
538
  group=group,
539
  lr=lr,
540
  weight_decay=weight_decay,
 
541
  qk_logits=qk_logits,
542
  )
543
 
 
555
  group,
556
  lr=lr,
557
  weight_decay=weight_decay,
 
558
  qk_logits=qk_logits,
559
  )
560
 
 
565
  group,
566
  lr=lr,
567
  weight_decay=weight_decay,
 
568
  qk_logits=qk_logits,
569
  )
570
 
571
  @torch.no_grad
572
  def step(self, closure=None, qk_logits=None):
573
  """Perform a single optimization step.
 
575
  Args:
576
  closure (Callable, optional): A closure that reevaluates the model
577
  and returns the loss.
578
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
579
+ to 1D tensors of shape (num_heads,), representing the maximum
580
+ QK logits across all tokens, computed as
581
  (1 / sqrt(head_dim)) * (Q @ K^T).
582
  """
583
  loss = None
 
589
  if group["use_muon"]:
590
  self._step_muon(group, qk_logits=qk_logits)
591
  else:
592
+ step_adamw(self.state, group)
593
 
594
  return loss
build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py ADDED
@@ -0,0 +1,50 @@
 
1
+ import torch
2
+
3
+ from .matmul_transpose_triton import matmul_transpose_assign
4
+
5
+ COMM_DTYPE = torch.bfloat16
6
+ DEFAULT_CHUNK_SIZE_RATIO = 4
7
+
8
+
9
+ # This code snippet is a modified version adapted from the following GitHub repositories:
10
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
11
+ # Muon's Newton–Schulz iteration causes high variance in singular values
12
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
13
+ @torch.no_grad()
14
+ # matmul_transpose_assign from: https://github.com/nil0x9/flash-muon
15
+ def _zeropower_via_newtonschulz5(G, steps):
16
+ """
17
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
18
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
19
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
20
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
21
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
22
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
23
+ performance at all relative to UV^T, where USV^T = G is the SVD.
24
+ """
25
+ assert len(G.shape) == 2
26
+ assert G.dtype == COMM_DTYPE
27
+ X = G # no manual typecast
28
+
29
+ if G.size(0) > G.size(1):
30
+ X = X.T
31
+ # Ensure spectral norm is at most 1
32
+ X = X / (X.norm() + 1e-7)
33
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
34
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
35
+ # Perform the NS iterations
36
+ for a, b, c in [
37
+ (4.0848, -6.8946, 2.9270),
38
+ (3.9505, -6.3029, 2.6377),
39
+ (3.7418, -5.5913, 2.3037),
40
+ (2.8769, -3.1427, 1.2046),
41
+ (2.8366, -3.0525, 1.2012),
42
+ ]:
43
+ matmul_transpose_assign(X, buf1)
44
+ matmul_transpose_assign(buf1, buf2)
45
+ buf1.mul_(b).add_(buf2, alpha=c)
46
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
47
+
48
+ if G.size(0) > G.size(1):
49
+ X = X.T
50
+ return X
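For readers without the Triton kernel, the two buffered `matmul_transpose_assign` calls above reduce to plain matrix products: each iteration computes `X <- a*X + (b*A + c*A^2) @ X` with `A = X @ X^T`. A minimal pure-PyTorch sketch of the same quintic iteration (the helper name `newton_schulz5_reference` is ours; it skips the bf16/CUDA requirements of the kernel version):

```python
import torch

# Per-iteration quintic coefficients, copied from the module above.
_NS_COEFFS = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]


def newton_schulz5_reference(G: torch.Tensor) -> torch.Tensor:
    """Quintic Newton-Schulz orthogonalization without the fused kernel."""
    assert G.ndim == 2
    X = G.T if G.size(0) > G.size(1) else G   # work on the short side
    X = X / (X.norm() + 1e-7)                 # Frobenius norm bounds the spectral norm
    for a, b, c in _NS_COEFFS:
        A = X @ X.T                           # matmul_transpose_assign(X, buf1)
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if G.size(0) > G.size(1) else X
```

After the five iterations the singular values of the output cluster near 1, which is all Muon needs from the orthogonalized update.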
build/torch210-cxx11-cu130-x86_64-linux/pipeline.py ADDED
@@ -0,0 +1,390 @@
 
1
+ import logging
2
+ from typing import Generator
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed.tensor import DTensor
7
+ from torch.profiler import record_function
8
+
9
+ from .core import _muon_state, adjust_lr_for_muon, update_p
10
+ from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5
11
+ from .qk_clip import compute_scales
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # ======================================================================
16
+ # Stage helpers
17
+ # ======================================================================
18
+
19
+
20
+ def _launch_gather(
21
+ params: list[DTensor],
22
+ owned_params: list[DTensor],
23
+ param_to_state: dict[int, _muon_state],
24
+ rank: int,
25
+ num_ranks: int,
26
+ process_group: dist.ProcessGroup,
27
+ ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
28
+ """Allocate gather buffers, build send/recv, and launch async all-to-all.
29
+
30
+ Returns:
31
+ work: Async operation handle.
32
+ recv_buf: Flat receive buffer (needed by ``_complete_gather``).
33
+ gathered_grads: ``{id(p): empty_tensor}`` for owned params,
34
+ ``None`` for non-owned.
35
+ recv_counts: Per-source-rank element counts.
36
+ """
37
+ # Allocate gathered-grad buffers
38
+ gathered_grads: dict[int, torch.Tensor | None] = {}
39
+ for p in params:
40
+ state = param_to_state[id(p)]
41
+ if rank == state.worker_rank:
42
+ gathered_grads[id(p)] = torch.empty(p.shape,
43
+ dtype=COMM_DTYPE,
44
+ device="cuda")
45
+ else:
46
+ gathered_grads[id(p)] = None
47
+
48
+ # Build send buffer
49
+ per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
50
+ send_counts = [0] * num_ranks
51
+
52
+ for p in params:
53
+ state = param_to_state[id(p)]
54
+ dst = state.worker_rank
55
+ assert dst < num_ranks
56
+ shard_elems = state.rank_numels[rank]
57
+ g = p.grad
58
+ g = g.to_local().to(COMM_DTYPE).contiguous()
59
+ assert g.numel() == shard_elems
60
+ per_dst[dst].append(g.view(-1))
61
+ send_counts[dst] += shard_elems
62
+
63
+ assert any(
64
+ len(v) > 0 for v in
65
+ per_dst), "At least one destination rank must receive a sharded tensor"
66
+ per_dst_flat = [t for dst in per_dst for t in dst]
67
+ send_buf = torch.cat(per_dst_flat, dim=0)
68
+
69
+ # Build recv buffer
70
+ recv_counts = [0] * num_ranks
71
+ for src in range(num_ranks):
72
+ total = 0
73
+ for p in owned_params:
74
+ state = param_to_state[id(p)]
75
+ assert state.worker_rank == rank
76
+ total += state.rank_numels[src]
77
+ recv_counts[src] = total
78
+
79
+ recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
80
+
81
+ # Launch async all-to-all
82
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
83
+ f"recv_buf size: {recv_buf.numel()}, "
84
+ f"recv_counts: {recv_counts}, "
85
+ f"send_counts: {send_counts}, "
86
+ f"process_group: {str(process_group)}")
87
+ work = dist.all_to_all_single(
88
+ recv_buf,
89
+ send_buf,
90
+ output_split_sizes=recv_counts,
91
+ input_split_sizes=send_counts,
92
+ group=process_group,
93
+ async_op=True,
94
+ )
95
+
96
+ return work, recv_buf, gathered_grads, recv_counts
97
+
98
+
99
+ def _complete_gather(
100
+ recv_buf: torch.Tensor,
101
+ recv_counts: list[int],
102
+ owned_params: list[DTensor],
103
+ gathered_grads: dict[int, torch.Tensor | None],
104
+ param_to_state: dict[int, _muon_state],
105
+ rank: int,
106
+ ) -> None:
107
+ """Reconstruct gathered grads from the recv buffer (in-place)."""
108
+ off = 0
109
+ for src in range(len(recv_counts)):
110
+ if recv_counts[src] == 0:
111
+ continue
112
+
113
+ block = recv_counts[src]
114
+ inner_off = 0
115
+ for p in owned_params:
116
+ state = param_to_state[id(p)]
117
+ assert state.worker_rank == rank
118
+
119
+ indices = state.rank_indices[src]
120
+
121
+ shard_view = gathered_grads[id(p)][indices]
122
+ n = shard_view.numel()
123
+ assert n > 0
124
+
125
+ sg = recv_buf.narrow(0, off + inner_off, n)
126
+ sg = sg.reshape(shard_view.shape)
127
+ gathered_grads[id(p)][indices] = sg
128
+
129
+ inner_off += n
130
+ assert inner_off == block
131
+ off += block
132
+
133
+
134
+ def _compute_ns(
135
+ owned_params: list[DTensor],
136
+ gathered_grads: dict[int, torch.Tensor | None],
137
+ ns_steps: int,
138
+ ) -> dict[int, torch.Tensor | None]:
139
+ """Run Newton-Schulz orthogonalization on owned parameters.
140
+
141
+ Returns:
142
+ computed_us: ``{id(p): orthogonalized_update}`` for owned params.
143
+ """
144
+ computed_us: dict[int, torch.Tensor | None] = {}
145
+ for p in owned_params:
146
+ u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
147
+ gathered_grads[id(p)] = None # free gathered grad
148
+ computed_us[id(p)] = u
149
+ return computed_us
150
+
151
+
152
+ def _launch_scatter(
153
+ params: list[DTensor],
154
+ owned_params: list[DTensor],
155
+ param_to_state: dict[int, _muon_state],
156
+ rank: int,
157
+ num_ranks: int,
158
+ process_group: dist.ProcessGroup,
159
+ computed_us: dict[int, torch.Tensor | None],
160
+ ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
161
+ """Allocate scatter buffers, build send/recv, and launch async all-to-all.
162
+
163
+ Returns:
164
+ work: Async operation handle.
165
+ recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
166
+ scattered_us: ``{id(p): empty_local_tensor}`` for all params.
167
+ recv_counts: Per-source-rank element counts.
168
+ """
169
+ # Allocate scattered-u buffers
170
+ scattered_us: dict[int, torch.Tensor] = {}
171
+ for p in params:
172
+ scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
173
+
174
+ # Build send buffer (from computed_us on owner ranks)
175
+ per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
176
+ send_counts = [0] * num_ranks
177
+
178
+ if owned_params:
179
+ for p in owned_params:
180
+ state = param_to_state[id(p)]
181
+
182
+ assert computed_us[id(p)] is not None
183
+ u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
184
+
185
+ total_sent = 0
186
+ for dst_rank in range(num_ranks):
187
+ indices = state.rank_indices[dst_rank]
188
+ su = u_full[indices].flatten()
189
+
190
+ n = su.numel()
191
+ assert n > 0
192
+
193
+ per_dst[dst_rank].append(su)
194
+ send_counts[dst_rank] += n
195
+ total_sent += n
196
+
197
+ assert total_sent == u_full.numel()
198
+
199
+ lengths = [len(v) for v in per_dst]
200
+ if all(l > 0 for l in lengths):
201
+ assert all(
202
+ l == lengths[0] for l in lengths
203
+ ), "All destination ranks must have the same number of sharded tensor"
204
+ per_dst_flat = [t for dst in per_dst for t in dst]
205
+ send_buf = torch.cat(per_dst_flat, dim=0)
206
+ else:
207
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
208
+
209
+ # Build recv buffer
210
+ recv_counts = [0] * num_ranks
211
+ for src in range(num_ranks):
212
+ total = 0
213
+ for p in params:
214
+ state = param_to_state[id(p)]
215
+ if state.worker_rank != src:
216
+ continue
217
+ total += state.rank_numels[rank]
218
+ recv_counts[src] = total
219
+
220
+ recv_total = sum(recv_counts)
221
+ assert recv_total > 0
222
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
223
+
224
+ # Launch async all-to-all
225
+ work = dist.all_to_all_single(
226
+ recv_buf,
227
+ send_buf,
228
+ output_split_sizes=recv_counts,
229
+ input_split_sizes=send_counts,
230
+ group=process_group,
231
+ async_op=True,
232
+ )
233
+
234
+ return work, recv_buf, scattered_us, recv_counts
235
+
236
+
237
+ def _complete_scatter(
238
+ recv_buf: torch.Tensor,
239
+ recv_counts: list[int],
240
+ params: list[DTensor],
241
+ param_to_state: dict[int, _muon_state],
242
+ rank: int,
243
+ scattered_us: dict[int, torch.Tensor],
244
+ ) -> None:
245
+ """Copy recv buffer into scattered_us (in-place)."""
246
+ off = 0
247
+ for src in range(len(recv_counts)):
248
+ block = recv_counts[src]
249
+ if block == 0:
250
+ continue
251
+
252
+ inner_off = 0
253
+ for p in params:
254
+ state = param_to_state[id(p)]
255
+ if state.worker_rank != src:
256
+ continue
257
+ n = state.rank_numels[rank]
258
+ assert n > 0
259
+
260
+ flat_local = recv_buf.narrow(0, off + inner_off,
261
+ n).view_as(p.to_local())
262
+ scattered_us[id(p)].copy_(flat_local)
263
+
264
+ inner_off += n
265
+
266
+ assert inner_off == block
267
+ off += block
268
+
269
+
270
+ def _update_params(
271
+ params: list[DTensor],
272
+ param_to_state: dict[int, _muon_state],
273
+ rank: int,
274
+ scattered_us: dict[int, torch.Tensor],
275
+ lr: float,
276
+ weight_decay: float,
277
+ ) -> None:
278
+ """Apply weight decay, Muon update, and optional QK clipping."""
279
+ for p in params:
280
+ state = param_to_state[id(p)]
281
+ u_dtensor = DTensor.from_local(
282
+ scattered_us[id(p)],
283
+ placements=p.placements,
284
+ device_mesh=p.device_mesh,
285
+ )
286
+
287
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
288
+ update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
289
+
290
+ # QK clipping – applied directly on the local tensor to
291
+ # avoid DTensor sharding-propagation issues with _StridedShard.
292
+ scales_full = compute_scales(
293
+ p,
294
+ state.qk_clip_state) if state.qk_clip_state is not None else None
295
+ if scales_full is not None:
296
+ ratio = p.shape[0] // scales_full.shape[0]
297
+ idx0 = state.rank_indices[rank][0]
298
+ if isinstance(idx0, slice):
299
+ start = idx0.start or 0
300
+ idx0 = torch.arange(start,
301
+ idx0.stop,
302
+ device=scales_full.device)
303
+ row_scales = scales_full[idx0 // ratio]
304
+ p._local_tensor.mul_(row_scales.view(-1, 1))
305
+
306
+
307
+ # ======================================================================
308
+ # Main generator – thin orchestrator that wires stages together.
309
+ # ======================================================================
310
+
311
+
312
+ @torch.no_grad()
313
+ def muon_chunk_pipeline(
314
+ params: list[DTensor],
315
+ param_to_state: dict[int, _muon_state],
316
+ rank: int,
317
+ ns_steps: int,
318
+ lr: float,
319
+ weight_decay: float,
320
+ none_grad: bool,
321
+ ) -> Generator[None, None, None]:
322
+ """Process one chunk of parameters through the full Muon pipeline.
323
+
324
+ Stages: gather -> compute (Newton-Schulz) -> scatter -> update.
325
+
326
+ Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
327
+ that communication and computation overlap across chunks. Async
328
+ communication is launched via ``async_op=True`` and completed after
329
+ the yield with ``work.wait()``.
330
+
331
+ Overlap happens because :func:`run_pipeline` admits one new chunk
332
+ per iteration (staggered admission). While chunk *N* does NS
333
+ compute on the default CUDA stream, chunk *N+1*'s async all-to-all
334
+ runs concurrently on the NCCL stream — no separate ``comm_stream``
335
+ is required.
336
+
337
+ Yields exactly **2** times:
338
+
339
+ 1. After launching async all-to-all gather.
340
+ 2. After launching async all-to-all scatter.
341
+ """
342
+ process_group = param_to_state[id(params[0])].process_group
343
+ num_ranks = dist.get_world_size(group=process_group)
344
+ owned_params = [
345
+ p for p in params if param_to_state[id(p)].worker_rank == rank
346
+ ]
347
+
348
+ # Stages 1-2: launch async gather.
349
+ with record_function("muon::launch_gather"):
350
+ work, recv_buf, gathered_grads, recv_counts = _launch_gather(
351
+ params, owned_params, param_to_state, rank, num_ranks,
352
+ process_group)
353
+
354
+ if none_grad:
355
+ for p in params:
356
+ p.grad = None
357
+
358
+ yield # --- YIELD 1: other chunks can launch their gather ---
359
+
360
+ with record_function("muon::wait_gather"):
361
+ work.wait()
362
+ _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
363
+ param_to_state, rank)
364
+ del recv_buf
365
+
366
+ # Stage 3: Newton-Schulz orthogonalization.
367
+ with record_function("muon::newton_schulz"):
368
+ computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
369
+ gathered_grads.clear()
370
+
371
+ # Stages 4-5: launch async scatter.
372
+ with record_function("muon::launch_scatter"):
373
+ work, recv_buf, scattered_us, recv_counts = _launch_scatter(
374
+ params, owned_params, param_to_state, rank, num_ranks,
375
+ process_group, computed_us)
376
+ computed_us.clear()
377
+
378
+ yield # --- YIELD 2: other chunks can launch their scatter ---
379
+
380
+ with record_function("muon::wait_scatter"):
381
+ work.wait()
382
+ _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
383
+ scattered_us)
384
+ del recv_buf
385
+
386
+ # Stage 6: apply parameter updates.
387
+ with record_function("muon::update_params"):
388
+ _update_params(params, param_to_state, rank, scattered_us, lr,
389
+ weight_decay)
390
+ scattered_us.clear()
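The trickiest part of the gather/scatter stages is the flat-buffer bookkeeping around `all_to_all_single`. A single-process sketch of that packing (the helper `pack_by_destination` is hypothetical; no process group is involved here):

```python
import torch


def pack_by_destination(shards, dest, world_size):
    """Flatten shards into one buffer ordered by destination rank, plus
    per-rank element counts -- the same layout _launch_gather builds for
    all_to_all_single's input_split_sizes."""
    per_dst = [[] for _ in range(world_size)]
    counts = [0] * world_size
    for s, d in zip(shards, dest):
        per_dst[d].append(s.reshape(-1))
        counts[d] += s.numel()
    flat = torch.cat([t for bucket in per_dst for t in bucket])
    return flat, counts
```

On the receiving side, `_complete_gather` walks the buffer in the same rank order, so `narrow(0, offset, n)` recovers each shard.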
build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py ADDED
@@ -0,0 +1,129 @@
 
1
+ import logging
2
+ import math
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch.distributed.tensor import DTensor
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
12
+ """
13
+ Parse a parameter name to check if it is a query/key projection layer
14
+ ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
15
+
16
+ Returns:
17
+ (kind, layer_idx) or (None, -1) if not matched.
18
+
19
+ Example:
20
+ 'model.3.attn.wq.weight' -> ('wq', 3)
21
+ 'model.5.attn.wk.weight' -> ('wk', 5)
22
+ 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
23
+ 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
24
+ 'model.4.attn.v_proj.weight' -> (None, -1)
25
+ """
26
+ parts = name.split('.')
27
+ if len(parts) < 3:
28
+ return None, -1
29
+
30
+ kind = parts[-2]
31
+
32
+ layer_idx = -1
33
+ for part in reversed(parts):
34
+ if part.isdigit():
35
+ layer_idx = int(part)
36
+ break
37
+
38
+ if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
39
+ return kind, layer_idx
40
+
41
+ return None, -1
42
+
43
+
44
+ @dataclass
45
+ class QKClipInfo:
46
+ """Per-parameter dynamic info computed from config + runtime logits."""
47
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
48
+ indices: list[int] # which heads to consider for clipping
49
+ head_dim: int # from config
50
+ threshold: float # from config
51
+ logit: torch.Tensor | None
52
+
53
+
54
+ def get_qk_clip_info(clip_config, n, qk_logits):
55
+ """Extract QK clipping info for a named parameter.
56
+
57
+ Args:
58
+ clip_config: QK clipping configuration dict (or None).
59
+ n: Parameter name string.
60
+ qk_logits: Dict mapping layer indices to logit tensors (or None).
61
+
62
+ Returns:
63
+ QKClipInfo instance with clipping configuration for this parameter.
64
+ """
65
+ if clip_config is None:
66
+ return None
67
+
68
+ head_dim = clip_config.get('head_dim')
69
+ threshold = clip_config.get('threshold')
70
+ kind, layer_idx = parse_qk_layer(n)
71
+
72
+ logit, indices = None, []
73
+ if qk_logits is not None and kind is not None:
74
+ logit = qk_logits[layer_idx]
75
+ indices_key = 'q_indices' if 'q' in kind else 'k_indices'
76
+ indices = clip_config.get(indices_key, []) or []
77
+
78
+ if isinstance(logit, DTensor):
79
+ # In TP settings, qk_logits may be DTensor
80
+ # We convert it to full tensor here for simplicity
81
+ logit = logit.full_tensor()
82
+
83
+ return QKClipInfo(
84
+ kind=kind,
85
+ indices=indices,
86
+ head_dim=head_dim,
87
+ threshold=threshold,
88
+ logit=logit,
89
+ )
90
+
91
+
92
+ def compute_scales(p, qk_clip_state):
93
+ """Compute per-head scaling factors for QK clipping.
94
+
95
+ Returns scales tensor if any head exceeds threshold, else None.
96
+ """
97
+ kind = qk_clip_state.kind
98
+ indices = qk_clip_state.indices
99
+ head_dim = qk_clip_state.head_dim
100
+ threshold = qk_clip_state.threshold
101
+ logit = qk_clip_state.logit
102
+
103
+ H_global = p.shape[0] // head_dim
104
+ scales_full = torch.ones(H_global, device=p.data.device)
105
+ scaling = 0
106
+
107
+ for logit_idx, head_idx in enumerate(indices):
108
+ v_ele = float(logit[logit_idx])
109
+ if v_ele > threshold:
110
+ new_scale = math.sqrt(threshold / v_ele)
111
+ if new_scale < scales_full[head_idx]:
112
+ scales_full[head_idx] = new_scale
113
+ logger.info(
114
+ f"[{kind}] Head {head_idx} exceeded threshold "
115
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
116
+ )
117
+ scaling += 1
118
+
119
+ return scales_full if scaling > 0 else None
120
+
121
+
122
+ def qk_clip(p, scales, head_dim):
123
+ """Apply per-head scaling to a Q/K projection weight matrix."""
124
+ if isinstance(p, torch.nn.Parameter):
125
+ W = p.data.view(-1, head_dim, p.data.shape[1])
126
+ W.mul_(scales.view(-1, 1, 1))
127
+ else:
128
+ W = p.view(-1, head_dim, p.shape[1])
129
+ W.mul_(scales.view(-1, 1, 1))
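Numerically, the clip rescales any head whose max QK logit exceeds the threshold by `sqrt(threshold / logit)`. A small self-contained sketch of that per-head scaling (the helper `per_head_scales` is ours; the real `compute_scales` also handles `indices` and DTensor logits):

```python
import math

import torch


def per_head_scales(logits: torch.Tensor, threshold: float) -> torch.Tensor:
    """scale_h = sqrt(threshold / logit_h) when logit_h > threshold, else 1."""
    return torch.tensor([
        math.sqrt(threshold / v) if v > threshold else 1.0
        for v in logits.tolist()
    ])


head_dim = 4
W = torch.randn(3 * head_dim, 8)             # 3 heads stacked along dim 0
before = W.clone()
logits = torch.tensor([50.0, 120.0, 80.0])   # max QK logit per head
scales = per_head_scales(logits, threshold=100.0)
# Same in-place pattern as qk_clip(): view as (heads, head_dim, in) and scale.
W.view(-1, head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
```

Only head 1 is touched here; since both the Q and K projections receive the square-root factor, the product of the two scales brings that head's logit back to the threshold.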
build/torch210-cxx11-rocm70-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_06a260a_dirty
3
- ops = torch.ops._optimizer_06a260a_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_06a260a_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_7aef62f_dirty
3
+ ops = torch.ops._optimizer_7aef62f_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_7aef62f_dirty::{op_name}"
build/torch210-cxx11-rocm70-x86_64-linux/{_optimizer_06a260a_dirty.abi3.so → _optimizer_7aef62f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3562c68e8ee85fc5b268e079150ffff69d52860092d59e44fb9b3c4526c5d497
3
  size 1866400
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00e9d9e1c2306badb97c3b8f2454a47d6335a302101a38c804ad3c7b075168cc
3
  size 1866400
build/torch210-cxx11-rocm70-x86_64-linux/adamw.py ADDED
@@ -0,0 +1,154 @@
 
1
+ from collections import defaultdict
2
+ from typing import cast
3
+
4
+ import torch
5
+ from torch.distributed.tensor import DTensor
6
+
7
+
8
+ def fused_adamw(
9
+ params: list[torch.Tensor],
10
+ grads: list[torch.Tensor],
11
+ exp_avgs: list[torch.Tensor],
12
+ exp_avg_sqs: list[torch.Tensor],
13
+ max_exp_avg_sqs: list[torch.Tensor],
14
+ state_steps: list[torch.Tensor],
15
+ amsgrad: bool,
16
+ beta1: float,
17
+ beta2: float,
18
+ lr: float | torch.Tensor,
19
+ weight_decay: float,
20
+ eps: float,
21
+ maximize: bool,
22
+ ) -> None:
23
+ if not params:
24
+ return
25
+
26
+ # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
27
+ # treating it as a scalar.
28
+ lr_dict: dict | None = ({
29
+ lr.device: lr
30
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None)
31
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
32
+ [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
33
+ state_steps] # type: ignore[list-item]
34
+ )
35
+ for (device, _), (
36
+ (
37
+ device_params_,
38
+ device_grads_,
39
+ device_exp_avgs_,
40
+ device_exp_avg_sqs_,
41
+ device_max_exp_avg_sqs,
42
+ device_state_steps_,
43
+ ),
44
+ _,
45
+ ) in grouped_tensors.items():
46
+ device_params = cast(list[torch.Tensor], device_params_)
47
+ device_grads = cast(list[torch.Tensor], device_grads_)
48
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
49
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
50
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
51
+
52
+ if lr_dict is not None and device not in lr_dict:
53
+ lr_dict[device] = lr.to(
54
+ device=device, non_blocking=True) # type: ignore[union-attr]
55
+ lr = lr_dict[device]
56
+ torch._foreach_add_(device_state_steps, 1)
57
+ func = torch._fused_adamw_
58
+ func(
59
+ device_params,
60
+ device_grads,
61
+ device_exp_avgs,
62
+ device_exp_avg_sqs,
63
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
64
+ device_state_steps,
65
+ amsgrad=amsgrad,
66
+ lr=lr, # type: ignore[arg-type]
67
+ beta1=beta1,
68
+ beta2=beta2,
69
+ weight_decay=weight_decay,
70
+ eps=eps,
71
+ maximize=maximize,
72
+ )
73
+
74
+
75
+ def step_adamw_params(optimizer_state, params, group):
76
+ """Run fused AdamW on a list of parameters sharing the same placement.
77
+
78
+ Args:
79
+ optimizer_state: The optimizer's state dict (self.state in Muon).
80
+ params: List of parameters to update.
81
+ group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
82
+ """
83
+ params_with_grads = []
84
+ grads = []
85
+ moment1 = []
86
+ moment2 = []
87
+ max_exp_avg_sqs = []
88
+ state_steps = []
89
+ lr = group["lr"]
90
+ beta1, beta2 = group["adamw_betas"]
91
+ eps = group["adamw_eps"]
92
+ weight_decay = group["weight_decay"]
93
+
94
+ for p in params:
95
+ g = p.grad
96
+ if g is None:
97
+ continue
98
+ state = optimizer_state[p]
99
+ params_with_grads.append(p)
100
+ grads.append(g)
101
+ if "step" not in state:
102
+ state["step"] = (torch.zeros((),
103
+ dtype=torch.float32,
104
+ device=p.device))
105
+ state["moment1"] = torch.zeros_like(g)
106
+ state["moment2"] = torch.zeros_like(g)
107
+ moment1.append(state["moment1"])
108
+ moment2.append(state["moment2"])
109
+ if not isinstance(state["step"], torch.Tensor):
110
+ step_tensor = torch.tensor(state["step"],
111
+ dtype=torch.float32,
112
+ device=p.device)
113
+ else:
114
+ step_tensor = state["step"]
115
+ state_steps.append(step_tensor)
116
+
117
+ fused_adamw(
118
+ params_with_grads,
119
+ grads,
120
+ moment1,
121
+ moment2,
122
+ max_exp_avg_sqs,
123
+ state_steps,
124
+ amsgrad=False,
125
+ beta1=beta1,
126
+ beta2=beta2,
127
+ lr=lr,
128
+ weight_decay=weight_decay,
129
+ eps=eps,
130
+ maximize=False,
131
+ )
132
+
133
+
134
+ def step_adamw(optimizer_state, group):
135
+ """Dispatch AdamW step, grouping parameters by type and placement.
136
+
137
+ Args:
138
+ optimizer_state: The optimizer's state dict (self.state in Muon).
139
+ group: Parameter group dict.
140
+ """
141
+ params = group["params"]
142
+
143
+ # group params with its type and placement
144
+ placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list)
145
+ for p in params:
146
+ match p:
147
+ case DTensor():
148
+ placement_to_params[tuple([p.placements,
149
+ p.device_mesh])].append(p)
150
+ case torch.Tensor():
151
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
152
+
153
+ for group_params in placement_to_params.values():
154
+ step_adamw_params(optimizer_state, group_params, group)
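For intuition about the state `step_adamw_params` accumulates (`moment1`, `moment2`, `step`), here is a single-tensor, non-fused sketch of one decoupled-weight-decay Adam step. The helper name `adamw_step_reference` is ours, not part of the module:

```python
import torch


def adamw_step_reference(p, g, state, lr=1e-3, betas=(0.9, 0.999),
                         eps=1e-8, weight_decay=0.01):
    """One AdamW step on a single tensor, mirroring what the fused kernel
    does per parameter: decoupled decay, then bias-corrected Adam."""
    beta1, beta2 = betas
    state["step"] = state.get("step", 0) + 1
    m = state.setdefault("moment1", torch.zeros_like(g))
    v = state.setdefault("moment2", torch.zeros_like(g))
    p.mul_(1 - lr * weight_decay)                  # decoupled weight decay
    m.mul_(beta1).add_(g, alpha=1 - beta1)         # first moment
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # second moment
    bc1 = 1 - beta1 ** state["step"]
    bc2 = 1 - beta2 ** state["step"]
    p.addcdiv_(m, (v / bc2).sqrt().add_(eps), value=-lr / bc1)
```

A single step of this sketch should agree with `torch.optim.AdamW` under the same hyperparameters, which is a cheap sanity check when porting optimizer state.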
build/torch210-cxx11-rocm70-x86_64-linux/async_utils.py ADDED
@@ -0,0 +1,77 @@
1
+ import logging
2
+ from typing import Generator
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+
7
+ class _Task:
8
+ """Internal: wraps a generator, advances one yield at a time."""
9
+
10
+ def __init__(self, generator: Generator[None, None, None], index: int):
11
+ self._generator = generator
12
+ self._index = index
13
+ self._steps_completed = 0
14
+ self.step() # run to first yield
15
+
16
+ def step(self) -> bool:
17
+ try:
18
+ next(self._generator)
19
+ self._steps_completed += 1
20
+ logger.debug("pipeline[%d] completed stage %d", self._index,
21
+ self._steps_completed)
22
+ return True
23
+ except StopIteration:
24
+ logger.debug("pipeline[%d] finished after %d stages", self._index,
25
+ self._steps_completed)
26
+ return False
27
+
28
+ def close(self):
29
+ self._generator.close()
30
+
31
+
32
+ def run_pipeline(
33
+ pipelines: Generator[Generator[None, None, None], None, None],
34
+ max_concurrent: int,
35
+ ) -> None:
36
+ """Run generator-based pipelines with bounded concurrency.
37
+
38
+ Each pipeline is a generator that yields at stage boundaries.
39
+ The runtime interleaves pipelines so communication and computation
40
+ overlap across chunks.
41
+ """
42
+ if max_concurrent <= 0:
43
+ raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}")
44
+
45
+ have_new = True
46
+ task_index = 0
47
+ previous_tasks: list[_Task] = []
48
+
49
+ try:
50
+ while have_new or previous_tasks:
51
+ running_tasks: list[_Task] = []
52
+
53
+ # Admit one new pipeline per iteration (staggered admission).
54
+ # Admitting one at a time ensures that while chunk N does NS
55
+ # compute on the default stream, chunk N+1's NCCL all-to-all
56
+ # runs concurrently on the NCCL stream — creating real
57
+ # communication/computation overlap on the GPU.
58
+ if have_new and len(previous_tasks) < max_concurrent:
59
+ try:
60
+ gen = next(pipelines)
61
+ task = _Task(gen, task_index)
62
+ task_index += 1
63
+ running_tasks.append(task)
64
+ except StopIteration:
65
+ have_new = False
66
+
67
+ # Advance every previously-yielded task by one step.
68
+ for task in previous_tasks:
69
+ if task.step():
70
+ running_tasks.append(task)
71
+
72
+ previous_tasks = running_tasks
73
+ except BaseException:
74
+ # Clean up all in-flight generators to release GPU resources.
75
+ for task in previous_tasks:
76
+ task.close()
77
+ raise
build/torch210-cxx11-rocm70-x86_64-linux/core.py ADDED
@@ -0,0 +1,116 @@
 
+ import math
+ from dataclasses import dataclass
+
+ import torch
+ import torch.distributed as dist
+ from torch.distributed import ProcessGroup
+ from torch.distributed.tensor import DTensor
+
+
+ @dataclass
+ class _muon_state:
+     worker_rank: int
+     process_group: ProcessGroup
+     rank_indices: dict[int, tuple]  # local_rank -> per-dim indices
+     rank_numels: dict[int, int]  # local_rank -> numel
+     name: str
+     qk_clip_state: torch.Tensor | None = None
+
+
+ def update_g(optimizer_state, p, g, group, momentum):
+     """Apply momentum update to gradient.
+
+     Args:
+         optimizer_state: The optimizer's state dict (self.state in Muon).
+         p: Parameter tensor.
+         g: Gradient tensor.
+         group: Parameter group dict.
+         momentum: Momentum coefficient.
+
+     Returns:
+         Momentum-updated gradient tensor.
+     """
+     state = optimizer_state[p]
+     buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
+     torch.add(g, buf, alpha=momentum, out=buf)
+     if group["nesterov"]:
+         g.add_(buf, alpha=momentum)
+         return g
+     return buf
+
+
+ def update_p(p, u, lr, adjusted_lr, weight_decay):
+     """Apply weight decay and orthogonalized update to parameter.
+
+     Args:
+         p: Parameter (torch.nn.Parameter or DTensor).
+         u: Orthogonalized update tensor.
+         lr: Base learning rate.
+         adjusted_lr: Size-adjusted learning rate.
+         weight_decay: Weight decay coefficient.
+     """
+     if isinstance(p, torch.nn.Parameter):
+         # apply weight decay
+         p.data.mul_(1 - lr * weight_decay)
+         # apply update
+         p.data.add_(u, alpha=-adjusted_lr)
+     else:
+         p.mul_(1 - lr * weight_decay)
+         p.add_(u, alpha=-adjusted_lr)
+
+
+ def adjust_lr_for_muon(lr, param_shape):
+     """Scale learning rate based on parameter matrix dimensions.
+
+     Args:
+         lr: Base learning rate.
+         param_shape: Shape of the parameter tensor.
+
+     Returns:
+         Adjusted learning rate.
+     """
+     A, B = param_shape[:2]
+     # We adjust the learning rate and weight decay based on the size of the
+     # parameter matrix as described in the paper
+     adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+     adjusted_lr = lr * adjusted_ratio
+     return adjusted_lr
+
+
+ def default_is_muon(name, x, expert_keys=None):
+     skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
+     if any(key in name for key in skip_keys):
+         return False
+     effective_ndim = x.ndim
+     if expert_keys and any(key in name for key in expert_keys):
+         effective_ndim -= 1
+     return effective_ndim >= 2
+
+
+ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
+     if is_muon_func is None:
+         is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
+
+     muon_params, muon_names = [], []
+     non_muon_params = []
+
+     for n, p in model.named_parameters():
+         if not p.requires_grad:
+             continue
+         if is_muon_func(n, p):
+             muon_params.append(p)
+             muon_names.append(n)
+         else:
+             non_muon_params.append(p)
+
+     return [
+         {
+             "params": muon_params,
+             "names": muon_names,
+             "use_muon": True,
+         },
+         {
+             "params": non_muon_params,
+             "use_muon": False,
+         },
+     ]
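As a quick illustration of the two helpers in `core.py` above, here is a torch-free adaptation. It takes `ndim` directly instead of a tensor, so it is a sketch of the logic rather than the module's exact signatures:

```python
import math


def adjust_lr_for_muon(lr, param_shape):
    # lr is scaled by 0.2 * sqrt(max(fan_out, fan_in)), so larger
    # matrices get proportionally larger steps.
    A, B = param_shape[:2]
    return lr * 0.2 * math.sqrt(max(A, B))


def default_is_muon(name, ndim, expert_keys=None):
    # Embedding/output layers are excluded; MoE expert weights carry a
    # leading expert dimension, so one dim is discounted before the
    # ">= 2 dims" matrix check.
    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
    if any(key in name for key in skip_keys):
        return False
    effective_ndim = ndim
    if expert_keys and any(key in name for key in expert_keys):
        effective_ndim -= 1
    return effective_ndim >= 2


# A 1024x4096 projection: 0.02 * 0.2 * sqrt(4096) = 0.256
print(adjust_lr_for_muon(0.02, (1024, 4096)))
```

A stacked expert weight such as a 3-D `experts.w1` tensor still qualifies once its expert dimension is discounted, while 1-D biases and embedding tables do not.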
build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py CHANGED
@@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
  _StridedShard)
8
 
9
 
10
  def get_slices_of_dtensor(
11
  target: DTensor | torch.Tensor,
12
  local_rank: int,
13
  shard_mesh: DeviceMesh,
14
  shard_placements: tuple[Placement],
15
- ) -> tuple[slice]:
16
  """
17
- Get the slice of local tensor for a given rank from a tensor.
 
 
 
 
 
18
  Args:
19
- target (DTensor | torch.Tensor): The target tensor.
20
- rank (int): The local rank of the shard group.
21
- shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
  shard_placements (tuple[Placement]): The shard placements.
23
- """
24
 
25
- slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
 
 
 
 
26
 
27
  # find the global rank of the local rank in the shard mesh
28
  rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
@@ -34,34 +52,75 @@ def get_slices_of_dtensor(
34
 
35
  assert len(rank_coords) == len(shard_placements)
36
 
 
 
 
 
37
  # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
  # left-to-right sharding. This is ensured by the sorting logic of
39
  # construct_shard_mesh function.
40
- for i, (rank_coord,
41
- placement) in enumerate(zip(rank_coords, shard_placements)):
42
- assert isinstance(placement, Shard)
43
 
44
- num_ranks = shard_mesh.mesh.shape[i]
 
45
 
46
- dim = placement.dim
47
- dim_size = (slices[dim].stop - slices[dim].start)
 
 
 
48
 
49
- if dim_size % num_ranks != 0:
50
  raise NotImplementedError(
51
- f"Dimension size {dim_size} is not divisible "
52
- f"by number of ranks {num_ranks} for shard "
53
- f"placement on dim {dim}. (shape: {target.shape})")
54
-
55
- shard_size = dim_size // num_ranks
56
-
57
- start = slices[dim].start + rank_coord * shard_size
58
- end = start + shard_size
59
-
60
- assert start < end <= slices[dim].stop
61
-
62
- slices[dim] = slice(start, end)
63
 
64
- return tuple(slices)
65
 
66
 
67
  _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
@@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
71
  def construct_shard_mesh(
72
  placements: tuple[Placement],
73
  mesh: DeviceMesh,
74
- ) -> (DeviceMesh, ProcessGroup, tuple[Placement]):
75
- """
76
- Construct Shard Mesh and Placements for unsharding.
77
- It removes Replicate placements and constructs a new Mesh and ProcessGroup.
78
- """
79
- my_rank = dist.get_rank()
80
 
81
- assert mesh.mesh.device.type == 'cpu'
 
 
82
 
83
- # Copy mesh to avoid modifying the original mesh
84
- mesh = mesh.mesh.clone()
85
-
86
- # 1. Sort placements. Replicate first, then Shard by dim ascending.
87
-
88
- # For Shard, strided shard comes after regular shard on the same dim
89
- # to preserve left-to-right order of replicate-to-shard.
90
- # This is because that strided shard is using stride to represent
91
- # more fine-grained sharding on the same dim.
92
- # Please check the URL below for _StridedShard.
93
- # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
94
-
95
- def placement_sort_key(
96
- placement_with_index: tuple[float, Placement]
97
- ) -> tuple[int, float, int]: # (dim, split factor, original index)
98
- index, placement = placement_with_index
99
- is_replicate = placement.is_replicate()
100
- is_shard = placement.is_shard()
101
- is_partial = placement.is_partial()
102
-
103
- assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
104
- assert not is_partial, "Partial placement is not supported."
105
-
106
- if is_replicate:
107
- return (-1.0, 0, index)
108
- elif is_shard:
109
- if isinstance(placement, _StridedShard):
110
- return (placement.dim, 1 / placement.split_factor, index)
111
- return (placement.dim, 0, index)
112
- else:
113
- raise TypeError(f"Unknown placement type: {type(placement)}")
114
 
115
- placements_with_index: list[tuple[int,
116
- Placement]] = list(enumerate(placements))
117
- placements_with_index = sorted(placements_with_index,
118
- key=placement_sort_key)
119
 
120
- sorted_indices, sorted_placements = zip(*placements_with_index)
 
121
 
122
- # 2. Permute mesh according to sorted placements.
123
- sorted_mesh = mesh.permute(sorted_indices)
 
 
124
 
125
- # 3. Collect list of shard meshes by removing replicate dims
126
- # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
127
- # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
128
- num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
129
 
130
- # merge replicate dims
131
- # shard_meshes became a list of shard meshes with a length of replicate degree
132
- if num_replicates > 0:
133
- sorted_mesh = sorted_mesh.flatten(
134
- 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
135
  shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
136
  else:
137
  shard_meshes = [sorted_mesh]
138
- shard_placements = sorted_placements[num_replicates:]
139
-
140
- # assume all shard placements are different
141
  assert len(shard_placements) == len(set(shard_placements))
142
 
143
- # 4. Construct ProcessGroups
144
- # Caution: all groups should be created in the same order in all processes,
145
- # even though each process only needs its own group.
146
-
147
- # To use tensor as dict key, convert it to tuple
148
- def tensor_to_tuple(t):
149
- if isinstance(t, torch.Tensor):
150
- t = t.tolist()
151
- if isinstance(t, list):
152
- return tuple(tensor_to_tuple(x) for x in t)
153
- return t
154
-
155
- my_shard_mesh_as_tuple = None
156
- for shard_mesh in shard_meshes:
157
- assert isinstance(shard_mesh, torch.Tensor)
158
- shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
159
-
160
- if (my_rank == shard_mesh).any().item():
161
- assert my_shard_mesh_as_tuple is None
162
- my_shard_mesh_as_tuple = shard_mesh_as_tuple
163
-
164
- # update global cache
165
- if shard_mesh_as_tuple not in _ranks_to_dist_cache:
166
- shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
167
- _ranks_to_dist_cache[shard_mesh_as_tuple] = (
168
- DeviceMesh(device_type="cuda", mesh=shard_mesh),
169
- shard_process_group,
170
  )
171
 
172
- my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
173
- my_shard_mesh_as_tuple]
174
-
175
- return my_shard_mesh, my_shard_process_group, shard_placements
 
7
  _StridedShard)
8
 
9
 
10
+ def _is_shard(placement: Placement) -> bool:
11
+ """Check if a placement is a shard type (Shard or _StridedShard).
12
+
13
+ In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so
14
+ ``placement.is_shard()`` returns False for _StridedShard. This helper
15
+ handles both old and new hierarchies.
16
+ """
17
+ return isinstance(placement, (Shard, _StridedShard))
18
+
19
+
20
  def get_slices_of_dtensor(
21
  target: DTensor | torch.Tensor,
22
  local_rank: int,
23
  shard_mesh: DeviceMesh,
24
  shard_placements: tuple[Placement],
25
+ ) -> tuple[slice | torch.Tensor, ...]:
26
  """
27
+ Get per-dimension indices for a given rank's shard of the target tensor.
28
+
29
+ Uses ``Shard.local_shard_size_and_offset`` and
30
+ ``_StridedShard.local_shard_size_and_offset`` for correct handling of
31
+ both contiguous and strided (non-contiguous) sharding.
32
+
33
  Args:
34
+ target (DTensor | torch.Tensor): The target tensor (for its shape).
35
+ local_rank (int): The local rank within the shard group.
36
+ shard_mesh (DeviceMesh): The shard mesh (only shard dimensions).
37
  shard_placements (tuple[Placement]): The shard placements.
 
38
 
39
+ Returns:
40
+ A tuple of indices (one per tensor dim). Each element is either:
41
+ - A ``slice`` (for contiguous or unsharded dims)
42
+ - A 1-D ``torch.LongTensor`` of indices (for strided sharding)
43
+ """
44
 
45
  # find the global rank of the local rank in the shard mesh
46
  rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
 
52
 
53
  assert len(rank_coords) == len(shard_placements)
54
 
55
+ # Track per-shard-dim indices.
56
+ # None means "not yet sharded on this dim".
57
+ dim_indices: dict[int, torch.Tensor] = {}
58
+
59
  # Caution: Assuming replicate-to-shard of the shard mesh goes with
60
  # left-to-right sharding. This is ensured by the sorting logic of
61
  # construct_shard_mesh function.
62
+ for mesh_dim_idx, (rank_coord, placement) in enumerate(
63
+ zip(rank_coords, shard_placements)):
64
+ assert _is_shard(placement)
65
 
66
+ num_chunks = shard_mesh.mesh.shape[mesh_dim_idx]
67
+ shard_dim = placement.dim
68
 
69
+ # Current effective size on this dim (may already be sub-sharded)
70
+ if shard_dim in dim_indices:
71
+ curr_size = len(dim_indices[shard_dim])
72
+ else:
73
+ curr_size = target.size()[shard_dim]
74
 
75
+ if curr_size % num_chunks != 0:
76
  raise NotImplementedError(
77
+ f"Dimension size {curr_size} is not divisible "
78
+ f"by number of ranks {num_chunks} for shard "
79
+ f"placement on dim {shard_dim}. (shape: {target.shape})")
80
+
81
+ # Compute indices for this level of sharding
82
+ if isinstance(placement, _StridedShard):
83
+ _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
84
+ placement,
85
+ curr_size,
86
+ num_chunks,
87
+ rank_coord,
88
+ return_first_offset=False)
89
+ new_indices = torch.tensor(offsets, dtype=torch.long)
90
+ else:
91
+ shard_size, offset = Shard.local_shard_size_and_offset(
92
+ curr_size, num_chunks, rank_coord)
93
+ new_indices = torch.arange(offset,
94
+ offset + shard_size,
95
+ dtype=torch.long)
96
+
97
+ # Compose with previous indices on this dim
98
+ if shard_dim in dim_indices:
99
+ dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]
100
+ else:
101
+ dim_indices[shard_dim] = new_indices
102
 
103
+ # Build result tuple
104
+ result: list[slice | torch.Tensor] = []
105
+ for d in range(len(target.size())):
106
+ if d not in dim_indices:
107
+ result.append(slice(None))
108
+ else:
109
+ indices = dim_indices[d]
110
+ # Convert contiguous indices to slice for efficiency
111
+ if len(indices) > 0:
112
+ start = indices[0].item()
113
+ expected = torch.arange(start,
114
+ start + len(indices),
115
+ dtype=torch.long)
116
+ if torch.equal(indices, expected):
117
+ result.append(slice(start, start + len(indices)))
118
+ else:
119
+ result.append(indices)
120
+ else:
121
+ result.append(slice(0, 0))
122
+
123
+ return tuple(result)
124
 
125
 
126
  _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
 
130
  def construct_shard_mesh(
131
  placements: tuple[Placement],
132
  mesh: DeviceMesh,
133
+ ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
134
+ """Construct shard sub-mesh and ProcessGroup for all-to-all communication.
 
 
 
 
135
 
136
+ Given a DTensor's placements and device mesh, extracts the "shard group"
137
+ — the set of ranks that together hold all shards of the same replica —
138
+ and creates a ProcessGroup for all-to-all among them.
139
 
140
+ Steps:
141
+ 1. Sort placements: Replicate first, then Shard by (dim, granularity).
142
+ 2. Permute the mesh tensor to match the sorted order.
143
+ 3. Collapse Replicate dims into a list of shard sub-meshes (one per replica).
144
+ 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh.
 
145
 
146
+ Example — 8 GPUs, mesh shape (2, 2, 2),
147
+ placements ``[Shard(0), Replicate, _StridedShard(0)]``::
 
 
148
 
149
+ Step 1 Sort: [Replicate, _StridedShard(0), Shard(0)]
150
+ Permutation: [1, 2, 0]
151
 
152
+ Step 2 Permute mesh dims by [1, 2, 0]:
153
+ Original: Permuted:
154
+ [[[0,1],[2,3]], [[[0,2],[1,3]],
155
+ [[4,5],[6,7]]] [[4,6],[5,7]]]
156
 
157
+ Step 3 Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
158
+ sub-mesh 0 = [[0,2],[1,3]] (replica group 0)
159
+ sub-mesh 1 = [[4,6],[5,7]] (replica group 1)
160
+ shard_placements = (_StridedShard(0), Shard(0))
161
 
162
+ Step 4 Rank 0 → ProcessGroup([0,1,4,5])
163
+ Rank 2 → ProcessGroup([2,3,6,7])
164
+
165
+ Returns:
166
+ ``(shard_mesh, process_group, shard_placements)``
167
+ """
168
+ my_rank = dist.get_rank()
169
+ assert mesh.mesh.device.type == 'cpu'
170
+
171
+ # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
172
+ # This avoids a non-collective dist.new_group() call, which would
173
+ # deadlock when only a subset of ranks call this function (e.g. expert
174
+ # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
175
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
176
+ key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
177
+ if key not in _ranks_to_dist_cache:
178
+ _ranks_to_dist_cache[key] = (mesh, mesh.get_group())
179
+ return (*_ranks_to_dist_cache[key], tuple(placements))
180
+
181
+ mesh_tensor = mesh.mesh.clone()
182
+
183
+ # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------
184
+ # _StridedShard comes BEFORE regular Shard on the same dim so that
185
+ # get_slices_of_dtensor applies the outer sharding first, matching
186
+ # DTensor's left-to-right (outer-to-inner) composition order.
187
+ def _sort_key(item):
188
+ index, placement = item
189
+ assert not placement.is_partial(), "Partial placement not supported"
190
+ if placement.is_replicate():
191
+ return (-1, 0, index)
192
+ assert _is_shard(placement), f"Unsupported: {type(placement)}"
193
+ split = (-1 / placement.split_factor if isinstance(
194
+ placement, _StridedShard) else 0)
195
+ return (placement.dim, split, index)
196
+
197
+ indexed = sorted(enumerate(placements), key=_sort_key)
198
+ perm, sorted_placements = zip(*indexed)
199
+
200
+ # -- Step 2: Permute mesh to match sorted placement order. --------------
201
+ sorted_mesh = mesh_tensor.permute(perm)
202
+
203
+ # -- Step 3: Collapse replicate dims → list of shard sub-meshes. --------
204
+ # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4)
205
+ num_rep = sum(1 for p in sorted_placements if p.is_replicate())
206
+ if num_rep > 0:
207
+ if num_rep > 1:
208
+ sorted_mesh = sorted_mesh.flatten(0, num_rep - 1)
209
  shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
210
  else:
211
  shard_meshes = [sorted_mesh]
212
+ shard_placements = sorted_placements[num_rep:]
 
 
213
  assert len(shard_placements) == len(set(shard_placements))
214
 
215
+ # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
216
+ # All ranks must call dist.new_group in the same order, even though each
217
+ # rank only joins one group.
218
+ def _cache_key(t: torch.Tensor) -> tuple:
219
+ return (*t.shape, *t.flatten().tolist())
220
+
221
+ my_key = None
222
+ for sm in shard_meshes:
223
+ key = _cache_key(sm)
224
+ if (my_rank == sm).any().item():
225
+ assert my_key is None, "Rank appears in multiple shard groups"
226
+ my_key = key
227
+ if key not in _ranks_to_dist_cache:
228
+ pg = dist.new_group(sm.flatten().tolist())
229
+ _ranks_to_dist_cache[key] = (
230
+ DeviceMesh(device_type="cuda", mesh=sm),
231
+ pg,
 
232
  )
233
 
234
+ return (*_ranks_to_dist_cache[my_key], shard_placements)
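The placement ordering in Step 1 can be exercised in isolation. A torch-free sketch with simplified stand-in placement classes (illustrative, not the `torch.distributed` types), reproducing the docstring's 8-GPU example:

```python
from dataclasses import dataclass


# Simplified stand-ins for DTensor placements (illustrative only).
@dataclass(frozen=True)
class Replicate:
    pass


@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class StridedShard:
    dim: int
    split_factor: int


def sort_key(item):
    index, p = item
    if isinstance(p, Replicate):
        return (-1, 0, index)  # replicate dims sort first
    if isinstance(p, StridedShard):
        # Strided (coarser, outer) sharding precedes regular sharding
        # on the same tensor dim, since -1 / split_factor < 0.
        return (p.dim, -1 / p.split_factor, index)
    return (p.dim, 0, index)


placements = [Shard(0), Replicate(), StridedShard(0, split_factor=2)]
indexed = sorted(enumerate(placements), key=sort_key)
perm, ordered = zip(*indexed)
print(perm)  # (1, 2, 0): Replicate, then StridedShard(0), then Shard(0)
```

This matches the docstring example: the mesh is permuted by `[1, 2, 0]`, the replicate dim is collapsed away, and `(_StridedShard(0), Shard(0))` remain as the shard placements.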
 
 
 
build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py CHANGED
@@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out):
      with torch.cuda.device(d_in.device.index):
          mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
                           d_out.stride(0), d_out.stride(1))
-
-
- def matmul_transpose(d_in):
-     M, _ = d_in.shape
-     d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
-     matmul_transpose_assign(d_in, d_out)
-     return d_out
build/torch210-cxx11-rocm70-x86_64-linux/metadata.json CHANGED
@@ -1 +1,3 @@
- {"python-depends":[]}
+ {
+   "python-depends": []
+ }
build/torch210-cxx11-rocm70-x86_64-linux/muon.py CHANGED
@@ -1,536 +1,121 @@
1
  import logging
2
- import math
3
  import types
4
  from collections import defaultdict
5
- from dataclasses import dataclass
6
- from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
- from torch.distributed import ProcessGroup
11
- from torch.distributed.device_mesh import DeviceMesh
12
- from torch.distributed.tensor import DTensor, Replicate
13
- from torch.distributed.tensor.placement_types import Placement
14
-
15
- from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
- from .matmul_transpose_triton import matmul_transpose_assign
 
 
 
 
 
 
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
- COMM_DTYPE = torch.bfloat16
21
- DEFAULT_CHUNK_SIZE_RATIO = 4
22
-
23
-
24
- # This code snippet is a modified version adapted from the following GitHub repositories:
25
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
26
- # Muon's Newton–Schulz iteration causes high variance in singular values
27
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
28
- @torch.no_grad()
29
- # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
30
- def _zeropower_via_newtonschulz5(G, steps):
31
- """
32
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
33
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
34
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
35
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
36
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
37
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
38
- performance at all relative to UV^T, where USV^T = G is the SVD.
39
- """
40
- assert len(G.shape) == 2
41
- assert G.dtype == COMM_DTYPE
42
- X = G # no manual typecast
43
-
44
- if G.size(0) > G.size(1):
45
- X = X.T
46
- # Ensure spectral norm is at most 1
47
- X = X / (X.norm() + 1e-7)
48
- buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
49
- buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
50
- # Perform the NS iterations
51
- for a, b, c in [
52
- (4.0848, -6.8946, 2.9270),
53
- (3.9505, -6.3029, 2.6377),
54
- (3.7418, -5.5913, 2.3037),
55
- (2.8769, -3.1427, 1.2046),
56
- (2.8366, -3.0525, 1.2012),
57
- ]:
58
- matmul_transpose_assign(X, buf1)
59
- matmul_transpose_assign(buf1, buf2)
60
- buf1.mul_(b).add_(buf2, alpha=c)
61
- X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
62
-
63
- if G.size(0) > G.size(1):
64
- X = X.T
65
- return X
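For intuition about the removed `_zeropower_via_newtonschulz5`: when X is diagonal, X @ X.T is also diagonal, so each iteration applies the scalar quintic s -> a*s + b*s**3 + c*s**5 to every singular value independently. A torch-free sketch using the per-iteration coefficients from the diff (the helper name is illustrative):

```python
# Per-iteration coefficients, copied from the diff above.
NS_COEFFS = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]


def ns_singular_values(sigmas):
    """Track how the Newton-Schulz iteration transforms singular values.

    Valid for diagonal inputs, where each iteration reduces to the
    scalar map s -> a*s + b*s**3 + c*s**5 applied per singular value.
    """
    # Normalize by the Frobenius norm so the spectral norm is at most 1.
    norm = sum(s * s for s in sigmas) ** 0.5 + 1e-7
    xs = [s / norm for s in sigmas]
    for a, b, c in NS_COEFFS:
        xs = [a * s + b * s**3 + c * s**5 for s in xs]
    return xs


out = ns_singular_values([3.0, 1.0])
print([round(s, 3) for s in out])
```

Starting from singular values 3.0 and 1.0 (a 3:1 ratio), five iterations land both values in the (0.5, 1.5) band the docstring describes, which is the "approximate UV^T" behavior: singular values are flattened toward 1 rather than matched exactly.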
66
-
67
-
68
- @dataclass
69
- class _muon_state:
70
- # TODO: use Optional
71
- worker_rank: int
72
- process_group: ProcessGroup
73
- shard_mesh: DeviceMesh
74
- shard_placements: tuple[Placement, ...]
75
- name: str
76
- qk_clip_state: torch.Tensor | None = None
77
- gathered_grad: torch.Tensor | None = None
78
- scattered_u: DTensor | None = None
79
- computed_u: torch.Tensor | None = None
80
- gather_event: torch.cuda.Event | None = None
81
- compute_event: torch.cuda.Event | None = None
82
- scatter_event: torch.cuda.Event | None = None
83
-
84
-
85
- def numel_for_rank(
86
- param: DTensor,
87
- local_rank: int,
88
- state: _muon_state,
89
- ) -> int:
90
- slices = get_slices_of_dtensor(
91
- param,
92
- local_rank,
93
- state.shard_mesh,
94
- state.shard_placements,
95
- )
96
-
97
- numel = 1
98
- for s, dim in zip(slices, param.shape):
99
- start, stop, step = s.indices(dim)
100
- length = max(0, (stop - start + (step - 1)) // step)
101
- numel *= length
102
-
103
- return numel
104
-
105
-
106
- @torch.no_grad()
107
- def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
108
- """
109
- Pre-allocate gathered_grad buffer on compute_stream
110
- before launching all2all gather
111
- """
112
- with torch.cuda.stream(compute_stream):
113
- for p in params:
114
- state = param_to_state[id(p)]
115
- if rank == state.worker_rank:
116
- state.gathered_grad = torch.empty(p.shape,
117
- dtype=COMM_DTYPE,
118
- device="cuda")
119
- else:
120
- state.gathered_grad = None
121
-
122
- alloc_event = torch.cuda.Event()
123
- alloc_event.record(compute_stream)
124
- return alloc_event
125
-
126
-
127
- @torch.no_grad()
128
- def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
129
- alloc_event):
130
- """
131
- All2all gathers shards so each owner rank reconstructs its full gradient
132
- """
133
- with torch.cuda.stream(comm_stream):
134
- process_group = param_to_state[id(params[0])].process_group
135
- num_ranks = dist.get_world_size(group=process_group)
136
-
137
- # Construct sending buffers
138
- per_dst = [[] for _ in range(num_ranks)]
139
- send_counts = [0] * num_ranks
140
-
141
- for p in params:
142
- state = param_to_state[id(p)]
143
- dst = state.worker_rank
144
- assert dst < num_ranks
145
- shard_elems = numel_for_rank(p, rank, state)
146
- g = p.grad
147
- g = g.to_local().to(COMM_DTYPE).contiguous()
148
- assert g.numel() == shard_elems
149
- per_dst[dst].append(g.view(-1))
150
- send_counts[dst] += shard_elems
151
-
152
- assert any(
153
- len(v) > 0 for v in per_dst
154
- ), "At least one destination rank must receive a sharded tensor"
155
- # list[list[Tensor]] -> list[Tensor]
156
- per_dst = [t for dst in per_dst for t in dst]
157
-
158
- send_buf = torch.cat(per_dst, dim=0)
159
-
160
- owned_params = [
161
- p for p in params if param_to_state[id(p)].worker_rank == rank
162
- ]
163
-
164
- # Compute receive sizes and allocate receiving buffers
165
- recv_counts = [0] * num_ranks
166
-
167
- for src in range(num_ranks):
168
- total = 0
169
- for p in owned_params:
170
- state = param_to_state[id(p)]
171
- assert state.worker_rank == rank
172
- total += numel_for_rank(p, src, state)
173
- recv_counts[src] = total
174
-
175
- recv_total = sum(recv_counts)
176
- recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
-
178
- #All2All
179
- logger.debug(f"send_buf size: {send_buf.numel()}, "
180
- f"recv_buf size: {recv_buf.numel()}, "
181
- f"recv_counts: {recv_counts}, "
182
- f"send_counts: {send_counts}, "
183
- f"process_group: {str(process_group)}")
184
- dist.all_to_all_single(
185
- recv_buf,
186
- send_buf,
187
- output_split_sizes=recv_counts,
188
- input_split_sizes=send_counts,
189
- group=process_group,
190
- )
191
-
192
- # Reconstructs gathered grad from the received buffer
193
- #
194
- # recv_buf (num ranks = 3)
195
- #
196
- # From rank 0 From rank 1 From rank 2
197
- # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
198
- #
199
- # Outer loop:
200
- # rank 0 -> rank 1 -> rank2
201
- #
202
- # Inner loop:
203
- # p1_n -> p2_n -> p3_n
204
-
205
- comm_stream.wait_event(alloc_event)
206
-
207
- off = 0
208
- for src in range(num_ranks):
209
- if recv_counts[src] == 0:
210
- continue
211
-
212
- block = recv_counts[src]
213
- inner_off = 0
214
- for p in owned_params:
215
- state = param_to_state[id(p)]
216
- assert state.worker_rank == rank
217
-
218
- # get the slice of the full dtensor corresponding to rank src.
219
- slices = get_slices_of_dtensor(state.gathered_grad, src,
220
- state.shard_mesh,
221
- state.shard_placements)
222
-
223
- dst = state.gathered_grad[slices]
224
- assert dst._base is state.gathered_grad
225
-
226
- n = dst.numel()
227
- assert n > 0
228
-
229
- sg = recv_buf.narrow(0, off + inner_off, n)
230
- sg = sg.reshape_as(dst)
231
- dst.copy_(sg)
232
-
233
- inner_off += n
234
- off += block
235
-
236
- for p in params:
237
- state = param_to_state[id(p)]
238
- if state.worker_rank == rank:
239
- state.gather_event = torch.cuda.Event()
240
- state.gather_event.record(comm_stream)
241
- else:
242
- state.gathered_grad = None
243
- state.gather_event = None
244
- if none_grad:
245
- p.grad = None
246
-
247
-
248
- @torch.no_grad()
249
- def _compute_u(p, state, steps, rank, compute_stream):
250
- """
251
- On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
252
- """
253
- with torch.cuda.stream(compute_stream):
254
- if rank == state.worker_rank:
255
- if state.gather_event is None:
256
- raise RuntimeError("Gather event must be set before compute.")
257
- compute_stream.wait_event(state.gather_event)
258
- u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
259
- state.gathered_grad = None
260
- state.computed_u = u
261
- state.compute_event = torch.cuda.Event()
262
- state.compute_event.record()
263
- else:
264
- state.computed_u = None
265
- state.compute_event = None
266
-
267
-
268
- @torch.no_grad()
269
- def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
270
- """
271
- Pre-allocate scattered_u buffer on compute_stream
272
- before launching all2all gather
273
- """
274
- with torch.cuda.stream(compute_stream):
275
- for p in params:
276
- state = param_to_state[id(p)]
277
- state.scattered_u = torch.empty_like(p.to_local(),
278
- dtype=COMM_DTYPE)
279
-
280
- alloc_event = torch.cuda.Event()
281
- alloc_event.record(compute_stream)
282
- return alloc_event
283
-
284
-
285
- def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
286
- """
287
- All2all scatters full gradients to all ranks
288
- """
289
- with torch.cuda.stream(comm_stream):
290
- process_group = param_to_state[id(params[0])].process_group
291
- num_ranks = dist.get_world_size(group=process_group)
292
- owned_params = [
293
- p for p in params if param_to_state[id(p)].worker_rank == rank
294
- ]
295
-
296
- # Construct sending buffer
297
- per_dst = [[] for _ in range(num_ranks)]
298
- send_counts = [0] * num_ranks
299
-
300
- if owned_params:
301
- for p in owned_params:
302
- state = param_to_state[id(p)]
303
- if state.compute_event is None:
304
- raise RuntimeError(
305
- "Compute event must be set before scatter.")
306
- comm_stream.wait_event(state.compute_event)
307
- state.gathered_grad = None
308
-
309
- assert state.computed_u is not None
310
-
311
- u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
-
313
- offset = 0
314
- for dst in range(num_ranks):
315
- # get the slice of the full tensor corresponding to rank dst.
316
- slices = get_slices_of_dtensor(u_full, dst,
317
- state.shard_mesh,
318
- state.shard_placements)
319
- su = u_full[slices].flatten()
320
-
321
- n = su.numel()
322
- assert n > 0
323
-
324
- per_dst[dst].append(su)
325
- send_counts[dst] += n
326
- offset += n
327
-
328
- assert offset == u_full.numel()
329
-
330
- lengths = [len(v) for v in per_dst]
331
- if all(l > 0 for l in lengths):
332
- assert all(
333
- l == lengths[0] for l in lengths
334
- ), "All destination ranks must have the same number of sharded tensor"
335
- # list[list[Tensor]] -> list[Tensor]
336
- per_dst = [t for dst in per_dst for t in dst]
337
- send_buf = torch.cat(per_dst, dim=0)
338
- else:
339
- # all_to_all requires participation from all ranks
340
- # Even non-owner ranks must join the collective call
341
- send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
342
-
343
- # Compute receive sizes and allocate receiving buffers
344
- recv_counts = [0] * num_ranks
345
-
346
- for src in range(num_ranks):
347
- total = 0
348
- for p in params:
349
- state = param_to_state[id(p)]
350
- if state.worker_rank != src:
351
- continue
352
- total += numel_for_rank(p, rank, state)
353
- recv_counts[src] = total
354
-
355
- recv_total = sum(recv_counts)
356
- assert recv_total > 0
357
- recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
358
-
359
- #All2All
360
- dist.all_to_all_single(
361
- recv_buf,
362
- send_buf,
363
- output_split_sizes=recv_counts,
364
- input_split_sizes=send_counts,
365
- group=process_group,
366
- )
367
-
368
- # Copy to pre-allocated scattered_u buffer from the received buffer
369
- #
370
- # recv_buf (num ranks = 3, local_rank = 0)
371
- #
372
- # From rank 0 From rank 1 From rank 2
373
- # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
374
- #
375
- # Outer loop:
376
- # rank 0 -> rank 1 -> rank2
377
- #
378
- # Inner loop:
379
- # src(0) : p1_0 -> p2_0 -> p3_0
380
- # src(1) : p4_0
381
- # src(2) : p5_0 -> p6_0
382
-
383
- comm_stream.wait_event(alloc_event)
384
-
385
- off = 0
386
- for src in range(num_ranks):
387
- block = recv_counts[src]
388
- if block == 0:
389
- continue
390
-
391
- inner_off = 0
392
- for p in params:
393
- state = param_to_state[id(p)]
394
- if state.worker_rank != src:
395
- continue
396
- n = numel_for_rank(p, rank, state)
397
- assert n > 0
398
 
399
- flat_local = recv_buf.narrow(0, off + inner_off,
400
- n).view_as(p.to_local())
401
- state.scattered_u.copy_(flat_local)
402
 
403
- state.scatter_event = torch.cuda.Event()
404
- state.scatter_event.record(comm_stream)
405
- inner_off += n
 
 
406
 
407
- assert inner_off == block
408
- off += block
409
 
 
410
 
411
-def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
-                  compute_stream):
-    """
-    Update sharded parameter p with the scattered_u.
-    Only worker_rank frees computed_u.
-    """
-    with torch.cuda.stream(compute_stream):
-        if state.scatter_event is None:
-            raise RuntimeError("Scatter event must be set before update")
-        compute_stream.wait_event(state.scatter_event)
-        u_dtensor = DTensor.from_local(
-            state.scattered_u,
-            placements=p.placements,
-            device_mesh=p.device_mesh,
-        )
-
-        state.scattered_u = u_dtensor
-
-        if rank == state.worker_rank:
-            # Free computed_u
-            state.computed_u = None
-
-        Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
-        state.scattered_u = None
-        u_dtensor = None
-
-        scales_full = Muon._compute_scales(
-            p,
-            state.qk_clip_state) if state.qk_clip_state is not None else None
-        if scales_full is not None:
-            # Have to slice scales_full along dim 0
-            weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
-                                                  state.shard_placements)
-            ratio = p.shape[0] // scales_full.shape[0]
-            scales_slice = slice(
-                None if weight_slices[0].start is None else
-                weight_slices[0].start // ratio,
-                None if weight_slices[0].stop is None else
-                weight_slices[0].stop // ratio,
-                None,
-            )
-
-            scales_local = scales_full[scales_slice]
-            scales_local = DTensor.from_local(
-                scales_local,
-                placements=p.placements,
-                device_mesh=p.device_mesh,
-            )
-            Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
-
-
-def default_is_muon(name, x):
-    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
-    return x.ndim >= 2 and not any(key in name for key in skip_keys)
-
-
-def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
-    muon_params, muon_names = [], []
-    non_muon_params = []
-
-    for n, p in model.named_parameters():
-        if not p.requires_grad:
-            continue
-        if is_muon_func(n, p):
-            muon_params.append(p)
-            muon_names.append(n)
-        else:
-            non_muon_params.append(p)
-
-    return [
-        {
-            "params": muon_params,
-            "names": muon_names,
-            "use_muon": True,
-        },
-        {
-            "params": non_muon_params,
-            "use_muon": False,
-        },
-    ]
-
-
-def parse_qk_layer(name: str) -> tuple[str | None, int]:
-    """
-    Parse a parameter name to check if it is a query/key projection layer
-    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
-
-    Returns:
-        (kind, layer_idx) or (None, -1) if not matched.
-
-    Example:
-        'model.3.attn.wq.weight' -> ('wq', 3)
-        'model.5.attn.wk.weight' -> ('wk', 5)
-        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
-        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
-        'model.4.attn.v_proj.weight' -> (None, -1)
-    """
-    parts = name.split('.')
-    if len(parts) < 3:
-        return None, -1
-
-    kind = parts[-2]
-
-    layer_idx = -1
-    for part in reversed(parts):
-        if part.isdigit():
-            layer_idx = int(part)
-            break
-
-    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
-        return kind, layer_idx
-
-    return None, -1
 
-
-
-@dataclass
-class QKClipInfo:
-    """Per-parameter dynamic info computed from config + runtime logits."""
-    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
-    indices: list[int]  # which heads to consider for clipping
-    head_dim: int  # from config
-    threshold: float  # from config
-    logit: torch.Tensor | None
-
-
 class Muon(torch.optim.Optimizer):
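As a reference for the `threshold` field above: the scaling rule applied by `Muon._compute_scales` / `Muon._qk_clip` can be sketched in plain Python. `compute_head_scales` is a hypothetical name used only for this illustration:

```python
import math

def compute_head_scales(logits, threshold):
    """Per-head scale used by QK clip: a head whose max QK logit exceeds
    the threshold gets sqrt(threshold / logit); others stay at 1.0.
    Applying this factor to both W_q and W_k shrinks Q @ K^T by
    threshold / logit, i.e. back down to the threshold."""
    return [
        math.sqrt(threshold / v) if v > threshold else 1.0 for v in logits
    ]
```

For example, with `threshold=100`, a head whose logit reached 400 is scaled by 0.5 on both projections, bringing its QK product back to 100.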
@@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer):
         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
         ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
         weight_decay: The weight decay for Muon and AdamW.
-            {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
         adamw_lr: The learning rate for the internal AdamW.
         adamw_betas: The betas for the internal AdamW.
         adamw_eps: The epsilon for the internal AdamW.
@@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer):
             - "q_indices" (list[int]): Indices of query heads to consider.
             - "k_indices" (list[int]): Indices of key heads to consider.
             - "head_dim" (int): Dimensionality of each attention head.
-            - "threshold" (float): Threshold value; heads whose QK logits exceed
               this value will be scaled down.
           Default is:
           {
@@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer):
         use_distributed_muon: Use distributed muon by Liu et al. (2024).
             For testing purpose only.
         small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
     """

     def __init__(self,
@@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer):
                  adamw_eps=1e-8,
                  none_grad=True,
                  debug=False,
-                 clip_config={
-                     "q_indices": [],
-                     "k_indices": [],
-                     "head_dim": 128,
-                     "threshold": 100
-                 },
                  warmup_step=5,
                  chunk_size=-1,
                  use_distributed_muon=False,
-                 small_param_numel_threshold=65536):
         defaults = dict(
             lr=lr,
             weight_decay=weight_decay,
@@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer):

         super().__init__(params, defaults)

-        self.rank = None
-
-        self.comm_stream = torch.cuda.Stream()
-        self.compute_stream = torch.cuda.Stream()
         self.debug = debug
-        self.clip_config = clip_config
         self.warmup_step = warmup_step
         self.chunk_size = chunk_size
         self.use_distributed_muon = use_distributed_muon
         self.small_param_numel_threshold = small_param_numel_threshold

     def _calc_flops(self, G, steps):
@@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer):

         return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)

-    def adjust_lr_for_muon(self, lr, param_shape):
-        A, B = param_shape[:2]
-        # We adjust the learning rate and weight decay based on the size of the parameter matrix
-        # as described in the paper
-        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
-        adjusted_lr = lr * adjusted_ratio
-        return adjusted_lr
-
-    def set_rank_once(self, rank):
-        if self.rank is None:
-            self.rank = rank
-        else:
-            assert self.rank == rank
-
     def get_shard_mesh(self, p):
         """
         Get the shard mesh for a parameter p on the given rank.
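The `adjust_lr_for_muon` removed in the hunk above (relocated to `core.py` by this refactor) applies a shape-dependent learning-rate scale. A minimal standalone sketch of the same rule:

```python
import math

def adjust_lr_for_muon(lr, param_shape):
    """Shape-dependent lr scale: lr * 0.2 * sqrt(max(A, B)) for an
    (A, B) weight matrix, so larger matrices take larger steps."""
    A, B = param_shape[:2]
    return lr * 0.2 * math.sqrt(max(A, B))
```

For a 1024x256 matrix and base lr 0.02, this yields 0.02 * 0.2 * 32 = 0.128.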
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer):
         shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
             p.placements, p.device_mesh)

-        # set rank with the local rank in the shard process group
-        self.set_rank_once(dist.get_rank(group=shard_pg))
-
         return shard_mesh, shard_pg, shard_placements

     def init_state_and_assign_params(self, names, params, group, qk_logits):
@@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer):
             total_flops += flops

         if self.debug:
-            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
-                  flush=True)

         paired = list(zip(names, params))
@@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer):

             worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
             round_robin = (round_robin + 1) % len(shard_mesh_flattened)
-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)

             param_to_state[id(p)] = _muon_state(
                 worker_rank=worker_rank,
                 process_group=shard_pg,
-                shard_mesh=shard_mesh,
-                shard_placements=shard_placements,
                 name=n,
                 qk_clip_state=qk_clip_state,
             )

         return param_to_state, ordered_params

-    def base(self, names, params, group, lr, weight_decay, momentum,
-             qk_logits):
-        # generate weight updates in distributed fashion
         for n, p in zip(names, params):
             g = p.grad
             if g is None:
                 continue
-            if g.ndim > 2:
-                g = g.view(g.size(0), -1)
-            assert g is not None
-
-            g = self._update_g(p, g, group, momentum)

             u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
                                              steps=group["ns_steps"])

-            adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-            Muon._update_p(p, u, lr, adjusted_lr, weight_decay)

-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)

-            scales_full = self._compute_scales(
                 p, qk_clip_state) if qk_clip_state is not None else None
             if scales_full is not None:
-                Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)

     def distributed_muon(
         self,
@@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer):
         group: dict[str, Any],
         lr: float,
         weight_decay: float,
-        momentum: float,
         qk_logits: list[torch.Tensor | DTensor] | None,
     ):
         """ Implementation of Distributed Muon by Liu et al. """

         for n, p in zip(names, params):
             g = p.grad
             if g is None:
                 continue
-            if g.ndim > 2:
-                g = g.view(g.size(0), -1)
-            assert g is not None
-
-            g = self._update_g(p, g, group, momentum)

             # Gather G
             if isinstance(p.data, DTensor):
@@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer):
             u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
                                                   steps=group["ns_steps"])

-            adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
-            Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)

-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)

-            scales_full = self._compute_scales(
                 p_full, qk_clip_state) if qk_clip_state is not None else None

             if scales_full is not None:
-                Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)

             if isinstance(p.data, DTensor):
                 ndims = len(p.device_mesh.mesh.shape)
822
 
823
  p.copy_(p_sharded)
824
 
825
- def _update_g(self, p, g, group, momentum):
826
- # calc update
827
- state = self.state[p]
828
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
829
- torch.add(g, buf, alpha=momentum, out=buf)
830
- if group["nesterov"]:
831
- g.add_(buf, alpha=momentum)
832
- return g
833
- return buf
834
-
835
- @staticmethod
836
- def _update_p(p, u, lr, adjusted_lr, weight_decay):
837
- if isinstance(p, torch.nn.Parameter):
838
- # apply weight decay
839
- p.data.mul_(1 - lr * weight_decay)
840
- # apply update
841
- p.data.add_(u, alpha=-adjusted_lr)
842
- else:
843
- p.mul_(1 - lr * weight_decay)
844
- p.add_(u, alpha=-adjusted_lr)
845
-
846
- def get_qk_clip_info(self, n, qk_logits):
847
- if self.clip_config is None:
848
- return None
849
-
850
- head_dim = self.clip_config.get('head_dim')
851
- threshold = self.clip_config.get('threshold')
852
- kind, layer_idx = parse_qk_layer(n)
853
-
854
- logit, indices = None, []
855
- if qk_logits is not None and kind is not None:
856
- logit = qk_logits[layer_idx]
857
- indices_key = 'q_indices' if 'q' in kind else 'k_indices'
858
- indices = self.clip_config.get(indices_key, []) or []
859
-
860
- if isinstance(logit, DTensor):
861
- # In TP settings, qk_logits may be DTensor
862
- # We convert it to full tensor here for simplicity
863
- logit = logit.full_tensor()
864
-
865
- return QKClipInfo(
866
- kind=kind,
867
- indices=indices,
868
- head_dim=head_dim,
869
- threshold=threshold,
870
- logit=logit,
871
- )
872
-
873
- @staticmethod
874
- def _compute_scales(p, qk_clip_state):
875
- kind = qk_clip_state.kind
876
- indices = qk_clip_state.indices
877
- head_dim = qk_clip_state.head_dim
878
- threshold = qk_clip_state.threshold
879
- logit = qk_clip_state.logit
880
-
881
- H_global = p.shape[0] // head_dim
882
- scales_full = torch.ones(H_global, device=p.data.device)
883
- scaling = 0
884
-
885
- for logit_idx, head_idx in enumerate(indices):
886
- v_ele = float(logit[logit_idx])
887
- if v_ele > threshold:
888
- new_scale = math.sqrt(threshold / v_ele)
889
- if new_scale < scales_full[head_idx]:
890
- scales_full[head_idx] = new_scale
891
- logger.info(
892
- f"[{kind}] Head {head_idx} exceeded threshold "
893
- f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
894
- )
895
- scaling += 1
896
-
897
- return scales_full if scaling > 0 else None
898
-
899
- @staticmethod
900
- def _qk_clip(p, scales, head_dim):
901
- if isinstance(p, torch.nn.Parameter):
902
- W = p.data.view(-1, head_dim, p.data.shape[1])
903
- W.mul_(scales.view(-1, 1, 1))
904
- else:
905
- W = p.view(-1, head_dim, p.shape[1])
906
- W.mul_(scales.view(-1, 1, 1))
907
-
908
- def parallel(self, names, params, group, lr, weight_decay, momentum,
909
- qk_logits):
910
  """
911
  Perform a parallel optimization step using Muon.
912
- """
913
 
914
- for p in params:
915
- g = p.grad
916
- if g is None:
917
- continue
918
- if g.ndim > 2:
919
- g = g.view(g.size(0), -1)
920
 
921
- # Update g in the local rank
922
- g = self._update_g(
923
- p,
924
- g,
925
- group,
926
- momentum=momentum,
927
- )
928
- p.grad = g
929
 
930
  param_to_state, ordered_params = self.init_state_and_assign_params(
931
  names, params, group, qk_logits)
932
 
933
- assert self.rank is not None
934
-
935
- def enqueue_all2all_gather(start_idx, chunk_size):
936
- target_params = ordered_params[start_idx:start_idx + chunk_size]
937
- if target_params:
938
- alloc_event = _alloc_gathered_grad(target_params,
939
- param_to_state, self.rank,
940
- self.compute_stream)
941
- _all2all_gather(target_params, param_to_state, self.rank,
942
- self.comm_stream, group["none_grad"],
943
- alloc_event)
944
-
945
- def enqueue_computes(start_idx, chunk_size):
946
- for p in ordered_params[start_idx:start_idx + chunk_size]:
947
- state = param_to_state[id(p)]
948
- _compute_u(p, state, group["ns_steps"], self.rank,
949
- self.compute_stream)
950
-
951
- def enqueue_all2all_scatter(start_idx, chunk_size):
952
- target_params = ordered_params[start_idx:start_idx + chunk_size]
953
- if target_params:
954
- alloc_event = _alloc_scattered_u(target_params, param_to_state,
955
- self.rank,
956
- self.compute_stream)
957
- _all2all_scatter(target_params, param_to_state, self.rank,
958
- self.comm_stream, alloc_event)
959
-
960
- def enqueue_update_param(start_idx, chunk_size):
961
- for p in ordered_params[start_idx:start_idx + chunk_size]:
962
- state = param_to_state[id(p)]
963
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
964
- _update_param(p, state, lr, adjusted_lr, weight_decay,
965
- self.rank, self.compute_stream)
966
 
967
  if self.chunk_size == -1:
968
  shard_ranks = dist.get_world_size(param_to_state[id(
969
- params[0])].process_group)
970
  chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
971
  elif self.chunk_size > 0:
972
  chunk_size = self.chunk_size
973
  else:
974
  raise ValueError("chunk_size must be -1 or a positive integer.")
975
 
976
- # Wait grad update
977
- self.comm_stream.wait_stream(torch.cuda.current_stream())
978
-
979
- warmup_step = self.warmup_step
980
- for i in range(0, warmup_step):
981
- enqueue_all2all_gather(i * chunk_size, chunk_size)
982
- enqueue_computes(i * chunk_size, chunk_size)
983
-
984
- for i in range(0, len(params) + chunk_size - 1, chunk_size):
985
- enqueue_all2all_scatter(i, chunk_size)
986
- enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
987
- enqueue_update_param(i, chunk_size)
988
- enqueue_computes(i + warmup_step * chunk_size, chunk_size)
989
-
990
- # Wait the last update_param to finish
991
- torch.cuda.current_stream().wait_stream(self.compute_stream)
992
-
993
- @staticmethod
994
- def _fused_adamw(
995
- params: list[torch.Tensor],
996
- grads: list[torch.Tensor],
997
- exp_avgs: list[torch.Tensor],
998
- exp_avg_sqs: list[torch.Tensor],
999
- max_exp_avg_sqs: list[torch.Tensor],
1000
- state_steps: list[torch.Tensor],
1001
- amsgrad: bool,
1002
- beta1: float,
1003
- beta2: float,
1004
- lr: float | torch.Tensor,
1005
- weight_decay: float,
1006
- eps: float,
1007
- maximize: bool,
1008
- ) -> None:
1009
- if not params:
1010
- return
1011
 
1012
- # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1013
- # treating it as a scalar.
1014
- lr_dict: DeviceDict | None = ({
1015
- lr.device: lr
1016
- } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1017
- None)
1018
- grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1019
- [
1020
- params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
1021
- state_steps
1022
- ] # type: ignore[list-item]
1023
- )
1024
- for (device, _), (
1025
- (
1026
- device_params_,
1027
- device_grads_,
1028
- device_exp_avgs_,
1029
- device_exp_avg_sqs_,
1030
- device_max_exp_avg_sqs,
1031
- device_state_steps_,
1032
- ),
1033
- _,
1034
- ) in grouped_tensors.items():
1035
- device_params = cast(list[torch.Tensor], device_params_)
1036
- device_grads = cast(list[torch.Tensor], device_grads_)
1037
- device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
1038
- device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
1039
- device_state_steps = cast(list[torch.Tensor], device_state_steps_)
1040
-
1041
- if lr_dict is not None and device not in lr_dict:
1042
- lr_dict[device] = lr.to(
1043
- device=device,
1044
- non_blocking=True) # type: ignore[union-attr]
1045
- lr = lr_dict[device]
1046
- torch._foreach_add_(device_state_steps, 1)
1047
- func = torch._fused_adamw_
1048
- func(
1049
- device_params,
1050
- device_grads,
1051
- device_exp_avgs,
1052
- device_exp_avg_sqs,
1053
- device_max_exp_avg_sqs, # type: ignore[arg-type]
1054
- device_state_steps,
1055
- amsgrad=amsgrad,
1056
- lr=lr, # type: ignore[arg-type]
1057
- beta1=beta1,
1058
- beta2=beta2,
1059
- weight_decay=weight_decay,
1060
- eps=eps,
1061
- maximize=maximize,
1062
- )
1063
 
1064
  def _step_muon(self, group, qk_logits=None):
1065
  params = group["params"]
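The warmup + main-loop index scheduling removed in the hunk above (superseded by the async generator pipeline) can be sketched as a flat event list; `schedule` is a hypothetical helper for illustration, with out-of-range indices dropped the way empty parameter slices become no-ops in the original:

```python
def schedule(num_params, chunk_size, warmup_step):
    """Flatten the removed warmup + main-loop scheduling into a list of
    (op, start_idx) events so the interleaving is visible."""
    events = []

    # Warmup: prefetch gather + compute for the first warmup_step chunks.
    for i in range(warmup_step):
        events.append(("gather", i * chunk_size))
        events.append(("compute", i * chunk_size))

    # Main loop: scatter/update chunk i while gathering/computing the
    # chunk that is warmup_step chunks ahead.
    for i in range(0, num_params + chunk_size - 1, chunk_size):
        events.append(("scatter", i))
        events.append(("gather", i + warmup_step * chunk_size))
        events.append(("update", i))
        events.append(("compute", i + warmup_step * chunk_size))

    return [(op, s) for op, s in events if s < num_params]
```

With 4 params, chunk size 2, and one warmup step, chunk 1's gather/compute is enqueued between chunk 0's scatter and update, which is exactly the communication/compute overlap the new pipeline reproduces.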
@@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer):
         momentum = group["momentum"]
         names = group["names"]

         param_dtensors = []
         name_dtensors = []

@@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer):
                 group=group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits)
             return
@@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer):
         # and run parallel Muon on each group.

         placement_to_params = defaultdict(lambda: ([], []))
-        # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]

         assert len(dtensors) == len(names)
         for p, n in zip(dtensors, names):
@@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer):
                 group=group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits,
             )

@@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer):
             group,
             lr=lr,
             weight_decay=weight_decay,
-            momentum=momentum,
             qk_logits=qk_logits,
         )

@@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer):
             group,
             lr=lr,
             weight_decay=weight_decay,
-            momentum=momentum,
             qk_logits=qk_logits,
         )

-    def _step_adamw_params(self, params, group):
-        params_with_grads = []
-        grads = []
-        moment1 = []
-        moment2 = []
-        max_exp_avg_sqs = []
-        state_steps = []
-        lr = group["lr"]
-        beta1, beta2 = group["adamw_betas"]
-        eps = group["adamw_eps"]
-        weight_decay = group["weight_decay"]
-
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-            state = self.state[p]
-            params_with_grads.append(p)
-            grads.append(g)
-            if "step" not in state:
-                state["step"] = (torch.zeros((),
-                                             dtype=torch.float32,
-                                             device=p.device))
-                state["moment1"] = torch.zeros_like(g)
-                state["moment2"] = torch.zeros_like(g)
-            moment1.append(state["moment1"])
-            moment2.append(state["moment2"])
-            if not isinstance(state["step"], torch.Tensor):
-                step_tensor = torch.tensor(state["step"],
-                                           dtype=torch.float32,
-                                           device=p.device)
-            else:
-                step_tensor = state["step"]
-            state_steps.append(step_tensor)
-
-        self._fused_adamw(
-            params_with_grads,
-            grads,
-            moment1,
-            moment2,
-            max_exp_avg_sqs,
-            state_steps,
-            amsgrad=False,
-            beta1=beta1,
-            beta2=beta2,
-            lr=lr,
-            weight_decay=weight_decay,
-            eps=eps,
-            maximize=False,
-        )
-
-    def _step_adamw(self, group):
-        params = group["params"]
-
-        # group params with its type and placement
-        placement_to_params: dict[tuple[Placement | type,
-                                        DeviceMesh | None]] = defaultdict(list)
-        for p in params:
-            match p:
-                case DTensor():
-                    placement_to_params[tuple([p.placements,
-                                               p.device_mesh])].append(p)
-                case torch.Tensor():
-                    placement_to_params[tuple([torch.Tensor, None])].append(p)
-
-        for params in placement_to_params.values():
-            self._step_adamw_params(params, group)
-
     @torch.no_grad
     def step(self, closure=None, qk_logits=None):
         """Perform a single optimization step.
@@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer):
         Args:
            closure (Callable, optional): A closure that reevaluates the model
                 and returns the loss.
-            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
-                to 1D tensors of shape (num_heads,), representing the maximum
-                QK logits across all tokens, computed as
                 (1 / sqrt(head_dim)) * (Q @ K^T).
         """
         loss = None
@@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer):
             if group["use_muon"]:
                 self._step_muon(group, qk_logits=qk_logits)
             else:
-                self._step_adamw(group)

         return loss
 import logging
 import types
 from collections import defaultdict
+from typing import Any

 import torch
 import torch.distributed as dist
+from torch.distributed.tensor import DTensor, Replicate, Shard
+from torch.profiler import record_function
+
+from .adamw import step_adamw
+from .async_utils import run_pipeline
+from .core import (_muon_state, adjust_lr_for_muon,
+                   get_default_muon_param_groups, update_g, update_p)
+from .distributed.utils import (_is_shard, construct_shard_mesh,
+                                get_slices_of_dtensor)
+from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
+                            _zeropower_via_newtonschulz5)
+from .pipeline import muon_chunk_pipeline
+from .qk_clip import compute_scales, get_qk_clip_info, qk_clip

 logger = logging.getLogger(__name__)

+def _expand_expert_params(names, params, expert_keys):
+    """Expand expert params by splitting on dim 0 (expert dimension).
+
+    Params whose name matches any key in ``expert_keys`` are treated as
+    expert-parallel tensors. Their outermost dimension is the expert
+    dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D
+    ``nn.Parameter`` views so that in-place updates propagate back to
+    the original storage.
+
+    Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` --
+    if they are expert params, their key must be added to ``expert_keys``.
+
+    The grad must already be set on each expert param (e.g. after momentum).
+
+    For DTensor expert params, placements that shard on dim 0 (expert dim)
+    are consumed by the split. Non-dim-0 shard placements (e.g. TP) are
+    preserved: each 2D slice is wrapped as a DTensor on the corresponding
+    submesh so the parallel pipeline handles the TP communication.
+    """
+    expanded_names = []
+    expanded_params = []
+
+    for n, p in zip(names, params):
+        is_expert = expert_keys and any(key in n for key in expert_keys)
+        is_dtensor = isinstance(p.data, DTensor)
+
+        if not is_expert:
+            assert p.data.ndim <= 2, (
+                f"Param {n} has ndim={p.data.ndim} but does not match "
+                f"expert_keys={expert_keys}. If this is an expert param, "
+                f"add its key to expert_keys.")
+            expanded_names.append(n)
+            expanded_params.append(p)
+            continue
+
+        g = p.grad
+        assert g is not None, (
+            f"Expert param {n} must have grad set before expansion")
+
+        tp_mesh = None
+        tp_placements_2d = None
+
+        if is_dtensor:
+            local_data = p.to_local()
+            local_grad = g.to_local() if isinstance(g, DTensor) else g
+
+            # Find non-dim-0 shard placements (e.g. TP sharding).
+            # After splitting on dim 0, Shard(k) becomes Shard(k-1).
+            tp_dim_indices = []
+            tp_placements_2d = []
+            for i, pl in enumerate(p.placements):
+                if _is_shard(pl) and pl.dim != 0:
+                    tp_dim_indices.append(i)
+                    tp_placements_2d.append(Shard(pl.dim - 1))
+
+            if tp_dim_indices:
+                tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i]
+                                     for i in tp_dim_indices)
+                if len(tp_dim_names) == 1:
+                    tp_mesh = p.device_mesh[tp_dim_names[0]]
+                else:
+                    tp_mesh = p.device_mesh[tp_dim_names]
+        else:
+            local_data = p.data
+            local_grad = g
+
+        # Expand: split dim 0, reshape each slice to 2D.
+        num_local_experts = local_data.shape[0]
+        for i in range(num_local_experts):
+            slice_data = local_data[i]
+            slice_grad = local_grad[i]
+
+            if tp_mesh is not None:
+                # Wrap as DTensor on TP submesh so the pipeline handles
+                # TP communication (gather/scatter across TP ranks).
+                dt_data = DTensor.from_local(slice_data,
+                                             device_mesh=tp_mesh,
+                                             placements=tp_placements_2d)
+                dt_grad = DTensor.from_local(slice_grad,
+                                             device_mesh=tp_mesh,
+                                             placements=tp_placements_2d)
+                expert_param = torch.nn.Parameter(dt_data, requires_grad=False)
+                expert_param.grad = dt_grad
+            else:
+                expert_param = torch.nn.Parameter(slice_data,
+                                                  requires_grad=False)
+                expert_param.grad = slice_grad
+
+            expanded_names.append(f"{n}[{i}]")
+            expanded_params.append(expert_param)
+
+        p.grad = None  # allow expert grad storage to be freed after pipeline
+
+    return expanded_names, expanded_params


 class Muon(torch.optim.Optimizer):
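The name/shape bookkeeping of `_expand_expert_params` above can be sketched without torch. `expand_expert_names` is a hypothetical helper covering only the naming and shape logic, not the storage-sharing `nn.Parameter` views or DTensor handling:

```python
def expand_expert_names(named_shapes, expert_keys):
    """A 3D (E, out, in) expert tensor expands to E entries named
    'name[i]' with 2D shapes; non-expert params pass through and must
    already be at most 2D, mirroring the assertion in the real code."""
    expanded = []
    for name, shape in named_shapes:
        if expert_keys and any(key in name for key in expert_keys):
            num_experts = shape[0]
            expanded.extend(
                (f"{name}[{i}]", shape[1:]) for i in range(num_experts))
        else:
            # Mirrors the ndim <= 2 assertion on non-expert params.
            assert len(shape) <= 2, f"{name} needs an expert_keys entry"
            expanded.append((name, shape))
    return expanded
```

For example, a `(4, 8, 16)` tensor named `layers.0.experts.w1` expands to four `(8, 16)` entries `layers.0.experts.w1[0]` through `layers.0.experts.w1[3]`.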
 
     nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
     ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
     weight_decay: The weight decay for Muon and AdamW.
+        Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead.
     adamw_lr: The learning rate for the internal AdamW.
     adamw_betas: The betas for the internal AdamW.
     adamw_eps: The epsilon for the internal AdamW.

         - "q_indices" (list[int]): Indices of query heads to consider.
         - "k_indices" (list[int]): Indices of key heads to consider.
         - "head_dim" (int): Dimensionality of each attention head.
+        - "threshold" (float): Threshold value; heads whose QK logits exceed
           this value will be scaled down.
       Default is:
       {

     use_distributed_muon: Use distributed muon by Liu et al. (2024).
         For testing purpose only.
     small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
+    expert_keys: List of strings to identify expert-parallel parameters.
+        If any key appears in a parameter's name, its outermost
+        dimension is treated as the expert dimension and expanded
+        into per-expert 2D params for Muon. For example,
+        ``expert_keys=["experts"]`` matches any param whose name
+        contains "experts". 3D+ params not matched by any key
+        will raise an error.
     """

     def __init__(self,
 
                  adamw_eps=1e-8,
                  none_grad=True,
                  debug=False,
+                 clip_config=None,
                  warmup_step=5,
                  chunk_size=-1,
                  use_distributed_muon=False,
+                 small_param_numel_threshold=65536,
+                 expert_keys=None):
         defaults = dict(
             lr=lr,
             weight_decay=weight_decay,

         super().__init__(params, defaults)

         self.debug = debug
+        self.clip_config = clip_config if clip_config is not None else {
+            "q_indices": [],
+            "k_indices": [],
+            "head_dim": 128,
+            "threshold": 100,
+        }
         self.warmup_step = warmup_step
         self.chunk_size = chunk_size
         self.use_distributed_muon = use_distributed_muon
         self.small_param_numel_threshold = small_param_numel_threshold
+        self.expert_keys = expert_keys

     def _calc_flops(self, G, steps):
         assert len(G.shape) == 2

         return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)

     def get_shard_mesh(self, p):
         """
         Get the shard mesh for a parameter p on the given rank.

         shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
             p.placements, p.device_mesh)

         return shard_mesh, shard_pg, shard_placements

     def init_state_and_assign_params(self, names, params, group, qk_logits):
             total_flops += flops

         if self.debug:
+            logger.debug("Total TFLOPs for Muon: %.2f TFLOPs",
+                         total_flops / 1e12)

         paired = list(zip(names, params))


             worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
             round_robin = (round_robin + 1) % len(shard_mesh_flattened)
+            qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
+
+            # Precompute per-rank indices and numels for all-to-all.
+            rank_indices: dict[int, tuple] = {}
+            rank_numels: dict[int, int] = {}
+            for r in range(num_ranks):
+                indices = get_slices_of_dtensor(p, r, shard_mesh,
+                                                shard_placements)
+                rank_indices[r] = indices
+                numel = 1
+                for idx, dim_size in zip(indices, p.shape):
+                    if isinstance(idx, slice):
+                        start, stop, step = idx.indices(dim_size)
+                        numel *= max(0, (stop - start + (step - 1)) // step)
+                    else:
+                        numel *= len(idx)
+                rank_numels[r] = numel

             param_to_state[id(p)] = _muon_state(
                 worker_rank=worker_rank,
                 process_group=shard_pg,
+                rank_indices=rank_indices,
+                rank_numels=rank_numels,
                 name=n,
                 qk_clip_state=qk_clip_state,
             )

         return param_to_state, ordered_params

+    def base(self, names, params, group, lr, weight_decay, qk_logits):
+        # Momentum is already applied by _step_muon before this method.
         for n, p in zip(names, params):
             g = p.grad
             if g is None:
                 continue

             u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
                                              steps=group["ns_steps"])

+            adjusted_lr = adjust_lr_for_muon(lr, p.shape)
+            update_p(p, u, lr, adjusted_lr, weight_decay)

+            qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)

+            scales_full = compute_scales(
                 p, qk_clip_state) if qk_clip_state is not None else None
             if scales_full is not None:
+                qk_clip(p, scales_full, qk_clip_state.head_dim)

     def distributed_muon(
         self,
353
  group: dict[str, Any],
354
  lr: float,
355
  weight_decay: float,
 
356
  qk_logits: list[torch.Tensor | DTensor] | None,
357
  ):
358
  """ Implementation of Distributed Muon by Liu et al. """
359
 
360
+ # Momentum is already applied by _step_muon before this method.
361
  for n, p in zip(names, params):
362
  g = p.grad
363
  if g is None:
364
  continue
 
 
 
 
 
365
 
366
  # Gather G
367
  if isinstance(p.data, DTensor):
 
374
  u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
375
  steps=group["ns_steps"])
376
 
377
+ adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
378
+ update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
379
 
380
+ qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
381
 
382
+ scales_full = compute_scales(
383
  p_full, qk_clip_state) if qk_clip_state is not None else None
384
 
385
  if scales_full is not None:
386
+ qk_clip(p_full, scales_full, qk_clip_state.head_dim)
387
 
388
  if isinstance(p.data, DTensor):
389
  ndims = len(p.device_mesh.mesh.shape)
 
400
 
401
  p.copy_(p_sharded)
402
 
+     def parallel(self, names, params, group, lr, weight_decay, qk_logits):
          """
          Perform a parallel optimization step using Muon.
 
+         Parameters are chunked and each chunk is processed by a
+         :func:`muon_chunk_pipeline` generator. :func:`run_pipeline`
+         interleaves multiple chunks so that communication and computation
+         overlap across chunks (the same overlap previously achieved by the
+         warmup + main-loop index scheduling).
+         """
 
+         # Momentum is already applied by _step_muon before this method.
 
          param_to_state, ordered_params = self.init_state_and_assign_params(
              names, params, group, qk_logits)
 
+         # Compute local rank for this group's shard process group.
+         shard_pg = param_to_state[id(ordered_params[0])].process_group
+         rank = dist.get_rank(group=shard_pg)
 
          if self.chunk_size == -1:
              shard_ranks = dist.get_world_size(param_to_state[id(
+                 ordered_params[0])].process_group)
              chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
          elif self.chunk_size > 0:
              chunk_size = self.chunk_size
          else:
              raise ValueError("chunk_size must be -1 or a positive integer.")
 
+         def pipelines():
+             for start in range(0, len(ordered_params), chunk_size):
+                 chunk = ordered_params[start:start + chunk_size]
+                 if chunk:
+                     yield muon_chunk_pipeline(
+                         params=chunk,
+                         param_to_state=param_to_state,
+                         rank=rank,
+                         ns_steps=group["ns_steps"],
+                         lr=lr,
+                         weight_decay=weight_decay,
+                         none_grad=group["none_grad"],
+                     )
 
+         with record_function("muon::barrier"):
+             dist.barrier()
+         with record_function("muon::pipeline"):
+             run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
 
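`run_pipeline` itself lives in `async_utils.py` and is not shown in this chunk; the scheduling idea — admit up to `max_concurrent` generators, then round-robin `next()` over the live ones — can be sketched in plain Python. Names here (`run_pipeline_sketch`, the toy `chunk` generator) are illustrative, not the actual implementation:

```python
def run_pipeline_sketch(pipelines, max_concurrent):
    """Drive generators with staggered admission.

    Each generator yields right after launching async work; stepping
    another generator while the first's communication is in flight is
    what overlaps comm and compute across chunks.
    """
    pipelines = iter(pipelines)
    active, done = [], False
    while active or not done:
        # Admit at most one new pipeline per outer iteration, up to the cap.
        if not done and len(active) < max_concurrent:
            try:
                active.append(next(pipelines))
            except StopIteration:
                done = True
        # Step every live pipeline once (round-robin).
        for gen in list(active):
            try:
                next(gen)
            except StopIteration:
                active.remove(gen)

def chunk(i, log):
    log.append(f"{i}:gather"); yield    # launched async gather
    log.append(f"{i}:scatter"); yield   # launched async scatter
    log.append(f"{i}:update")           # applied parameter updates

log = []
run_pipeline_sketch((chunk(i, log) for i in range(3)), max_concurrent=2)
# Chunk 1's gather is launched before chunk 0 finishes, so its
# communication overlaps chunk 0's compute.
```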
      def _step_muon(self, group, qk_logits=None):
          params = group["params"]
 
          momentum = group["momentum"]
          names = group["names"]
 
+         # Apply momentum to all params before routing/expansion.
+         with record_function("muon::momentum"):
+             for n, p in zip(names, params):
+                 g = p.grad
+                 if g is None:
+                     continue
+                 g = update_g(self.state, p, g, group, momentum)
+                 p.grad = g
+
+         # Expand expert params by splitting on dim 0.
+         names, params = _expand_expert_params(names, params, self.expert_keys)
+
          param_dtensors = []
          name_dtensors = []
 
              group=group,
              lr=lr,
              weight_decay=weight_decay,
              qk_logits=qk_logits)
          return
 
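The expert-expansion step above splits stacked MoE weights of shape `(n_experts, m, n)` on dim 0 so each expert matrix is orthogonalized independently. A hypothetical sketch of that behavior using nested lists in place of tensors — the real `_expand_expert_params` in `muon.py` may differ in naming and matching details:

```python
def ndim(x):
    # Depth of nested lists, standing in for tensor rank.
    d = 0
    while isinstance(x, list):
        d += 1
        x = x[0]
    return d

def expand_expert_params(names, params, expert_keys):
    """Split any 3-D expert weight matching an expert key into
    per-expert 2-D entries (illustrative sketch)."""
    out_names, out_params = [], []
    for n, p in zip(names, params):
        if any(k in n for k in expert_keys) and ndim(p) == 3:
            for i, expert in enumerate(p):
                out_names.append(f"{n}.{i}")
                out_params.append(expert)
        else:
            out_names.append(n)
            out_params.append(p)
    return out_names, out_params

# 2 experts, each a 2x3 matrix, plus one ordinary dense weight.
experts = [[[0] * 3 for _ in range(2)] for _ in range(2)]
dense = [[0] * 4 for _ in range(4)]
new_names, new_params = expand_expert_params(
    ["blk.0.moe.experts.w1", "blk.0.attn.wq.weight"],
    [experts, dense], ["experts"])
# The expert weight becomes two 2-D entries; the dense one passes through.
```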
          # and run parallel Muon on each group.
 
          placement_to_params = defaultdict(lambda: ([], []))
 
          assert len(dtensors) == len(names)
          for p, n in zip(dtensors, names):
 
              group=group,
              lr=lr,
              weight_decay=weight_decay,
              qk_logits=qk_logits,
          )
 
          group,
          lr=lr,
          weight_decay=weight_decay,
          qk_logits=qk_logits,
          )
 
          group,
          lr=lr,
          weight_decay=weight_decay,
          qk_logits=qk_logits,
          )
 
      @torch.no_grad
      def step(self, closure=None, qk_logits=None):
          """Perform a single optimization step.
 
          Args:
              closure (Callable, optional): A closure that reevaluates the model
                  and returns the loss.
+             qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
+                 to 1D tensors of shape (num_heads,), representing the maximum
+                 QK logits across all tokens, computed as
                  (1 / sqrt(head_dim)) * (Q @ K^T).
          """
          loss = None
 
              if group["use_muon"]:
                  self._step_muon(group, qk_logits=qk_logits)
              else:
+                 step_adamw(self.state, group)
 
          return loss
build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+
+ from .matmul_transpose_triton import matmul_transpose_assign
+
+ COMM_DTYPE = torch.bfloat16
+ DEFAULT_CHUNK_SIZE_RATIO = 4
+
+
+ # This code snippet is a modified version adapted from the following GitHub repository:
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
+ # Muon's Newton-Schulz iteration causes high variance in singular values.
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
+ @torch.no_grad()
+ # matmul_transpose_assign from: https://github.com/nil0x9/flash-muon
+ def _zeropower_via_newtonschulz5(G, steps):
+     """
+     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+     zero even beyond the point where the iteration no longer converges all the way to one everywhere
+     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+     where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+     performance at all relative to UV^T, where USV^T = G is the SVD.
+     """
+     assert len(G.shape) == 2
+     assert G.dtype == COMM_DTYPE
+     X = G  # no manual typecast
+
+     if G.size(0) > G.size(1):
+         X = X.T
+     # Ensure spectral norm is at most 1
+     X = X / (X.norm() + 1e-7)
+     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
+     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
+     # Perform the NS iterations
+     for a, b, c in [
+         (4.0848, -6.8946, 2.9270),
+         (3.9505, -6.3029, 2.6377),
+         (3.7418, -5.5913, 2.3037),
+         (2.8769, -3.1427, 1.2046),
+         (2.8366, -3.0525, 1.2012),
+     ]:
+         matmul_transpose_assign(X, buf1)
+         matmul_transpose_assign(buf1, buf2)
+         buf1.mul_(b).add_(buf2, alpha=c)
+         X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
+
+     if G.size(0) > G.size(1):
+         X = X.T
+     return X
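The effect of the iteration can be checked numerically without the Triton kernel. In the loop above, `matmul_transpose_assign(X, buf1)` computes `A = X @ X.T`, so each step is `X <- a*X + (b*A + c*A@A) @ X`. A small NumPy sketch of the same recurrence with the same coefficients (an illustration under that reading, not the shipped kernel):

```python
import numpy as np

def newton_schulz5(G, coeffs):
    """Quintic Newton-Schulz orthogonalization (NumPy illustration)."""
    X = G.T if G.shape[0] > G.shape[1] else G
    X = X / (np.linalg.norm(X) + 1e-7)  # Frobenius norm bounds spectral norm
    for a, b, c in coeffs:
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if G.shape[0] > G.shape[1] else X

coeffs = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]
rng = np.random.default_rng(0)
G = rng.standard_normal((8, 16))
X = newton_schulz5(G, coeffs)
s = np.linalg.svd(X, compute_uv=False)
# Singular values land roughly in [0.5, 1.5] rather than exactly at 1,
# matching the US'V^T description in the docstring above.
```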
build/torch210-cxx11-rocm70-x86_64-linux/pipeline.py ADDED
@@ -0,0 +1,390 @@
+ import logging
+ from typing import Generator
+
+ import torch
+ import torch.distributed as dist
+ from torch.distributed.tensor import DTensor
+ from torch.profiler import record_function
+
+ from .core import _muon_state, adjust_lr_for_muon, update_p
+ from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5
+ from .qk_clip import compute_scales
+
+ logger = logging.getLogger(__name__)
+
+ # ======================================================================
+ # Stage helpers
+ # ======================================================================
+
+
+ def _launch_gather(
+     params: list[DTensor],
+     owned_params: list[DTensor],
+     param_to_state: dict[int, _muon_state],
+     rank: int,
+     num_ranks: int,
+     process_group: dist.ProcessGroup,
+ ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
+     """Allocate gather buffers, build send/recv, and launch async all-to-all.
+
+     Returns:
+         work: Async operation handle.
+         recv_buf: Flat receive buffer (needed by ``_complete_gather``).
+         gathered_grads: ``{id(p): empty_tensor}`` for owned params,
+             ``None`` for non-owned.
+         recv_counts: Per-source-rank element counts.
+     """
+     # Allocate gathered-grad buffers
+     gathered_grads: dict[int, torch.Tensor | None] = {}
+     for p in params:
+         state = param_to_state[id(p)]
+         if rank == state.worker_rank:
+             gathered_grads[id(p)] = torch.empty(p.shape,
+                                                 dtype=COMM_DTYPE,
+                                                 device="cuda")
+         else:
+             gathered_grads[id(p)] = None
+
+     # Build send buffer
+     per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+     send_counts = [0] * num_ranks
+
+     for p in params:
+         state = param_to_state[id(p)]
+         dst = state.worker_rank
+         assert dst < num_ranks
+         shard_elems = state.rank_numels[rank]
+         g = p.grad
+         g = g.to_local().to(COMM_DTYPE).contiguous()
+         assert g.numel() == shard_elems
+         per_dst[dst].append(g.view(-1))
+         send_counts[dst] += shard_elems
+
+     assert any(
+         len(v) > 0 for v in
+         per_dst), "At least one destination rank must receive a sharded tensor"
+     per_dst_flat = [t for dst in per_dst for t in dst]
+     send_buf = torch.cat(per_dst_flat, dim=0)
+
+     # Build recv buffer
+     recv_counts = [0] * num_ranks
+     for src in range(num_ranks):
+         total = 0
+         for p in owned_params:
+             state = param_to_state[id(p)]
+             assert state.worker_rank == rank
+             total += state.rank_numels[src]
+         recv_counts[src] = total
+
+     recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
+
+     # Launch async all-to-all
+     logger.debug(f"send_buf size: {send_buf.numel()}, "
+                  f"recv_buf size: {recv_buf.numel()}, "
+                  f"recv_counts: {recv_counts}, "
+                  f"send_counts: {send_counts}, "
+                  f"process_group: {str(process_group)}")
+     work = dist.all_to_all_single(
+         recv_buf,
+         send_buf,
+         output_split_sizes=recv_counts,
+         input_split_sizes=send_counts,
+         group=process_group,
+         async_op=True,
+     )
+
+     return work, recv_buf, gathered_grads, recv_counts
+
+
+ def _complete_gather(
+     recv_buf: torch.Tensor,
+     recv_counts: list[int],
+     owned_params: list[DTensor],
+     gathered_grads: dict[int, torch.Tensor | None],
+     param_to_state: dict[int, _muon_state],
+     rank: int,
+ ) -> None:
+     """Reconstruct gathered grads from the recv buffer (in-place)."""
+     off = 0
+     for src in range(len(recv_counts)):
+         if recv_counts[src] == 0:
+             continue
+
+         block = recv_counts[src]
+         inner_off = 0
+         for p in owned_params:
+             state = param_to_state[id(p)]
+             assert state.worker_rank == rank
+
+             indices = state.rank_indices[src]
+
+             shard_view = gathered_grads[id(p)][indices]
+             n = shard_view.numel()
+             assert n > 0
+
+             sg = recv_buf.narrow(0, off + inner_off, n)
+             sg = sg.reshape(shard_view.shape)
+             gathered_grads[id(p)][indices] = sg
+
+             inner_off += n
+         assert inner_off == block
+         off += block
+
+
+ def _compute_ns(
+     owned_params: list[DTensor],
+     gathered_grads: dict[int, torch.Tensor | None],
+     ns_steps: int,
+ ) -> dict[int, torch.Tensor | None]:
+     """Run Newton-Schulz orthogonalization on owned parameters.
+
+     Returns:
+         computed_us: ``{id(p): orthogonalized_update}`` for owned params.
+     """
+     computed_us: dict[int, torch.Tensor | None] = {}
+     for p in owned_params:
+         u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
+         gathered_grads[id(p)] = None  # free gathered grad
+         computed_us[id(p)] = u
+     return computed_us
+
+
+ def _launch_scatter(
+     params: list[DTensor],
+     owned_params: list[DTensor],
+     param_to_state: dict[int, _muon_state],
+     rank: int,
+     num_ranks: int,
+     process_group: dist.ProcessGroup,
+     computed_us: dict[int, torch.Tensor | None],
+ ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
+     """Allocate scatter buffers, build send/recv, and launch async all-to-all.
+
+     Returns:
+         work: Async operation handle.
+         recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
+         scattered_us: ``{id(p): empty_local_tensor}`` for all params.
+         recv_counts: Per-source-rank element counts.
+     """
+     # Allocate scattered-u buffers
+     scattered_us: dict[int, torch.Tensor] = {}
+     for p in params:
+         scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
+
+     # Build send buffer (from computed_us on owner ranks)
+     per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+     send_counts = [0] * num_ranks
+
+     if owned_params:
+         for p in owned_params:
+             state = param_to_state[id(p)]
+
+             assert computed_us[id(p)] is not None
+             u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
+
+             total_sent = 0
+             for dst_rank in range(num_ranks):
+                 indices = state.rank_indices[dst_rank]
+                 su = u_full[indices].flatten()
+
+                 n = su.numel()
+                 assert n > 0
+
+                 per_dst[dst_rank].append(su)
+                 send_counts[dst_rank] += n
+                 total_sent += n
+
+             assert total_sent == u_full.numel()
+
+     lengths = [len(v) for v in per_dst]
+     if all(l > 0 for l in lengths):
+         assert all(
+             l == lengths[0] for l in lengths
+         ), "All destination ranks must have the same number of sharded tensors"
+         per_dst_flat = [t for dst in per_dst for t in dst]
+         send_buf = torch.cat(per_dst_flat, dim=0)
+     else:
+         send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
+
+     # Build recv buffer
+     recv_counts = [0] * num_ranks
+     for src in range(num_ranks):
+         total = 0
+         for p in params:
+             state = param_to_state[id(p)]
+             if state.worker_rank != src:
+                 continue
+             total += state.rank_numels[rank]
+         recv_counts[src] = total
+
+     recv_total = sum(recv_counts)
+     assert recv_total > 0
+     recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
+
+     # Launch async all-to-all
+     work = dist.all_to_all_single(
+         recv_buf,
+         send_buf,
+         output_split_sizes=recv_counts,
+         input_split_sizes=send_counts,
+         group=process_group,
+         async_op=True,
+     )
+
+     return work, recv_buf, scattered_us, recv_counts
+
+
+ def _complete_scatter(
+     recv_buf: torch.Tensor,
+     recv_counts: list[int],
+     params: list[DTensor],
+     param_to_state: dict[int, _muon_state],
+     rank: int,
+     scattered_us: dict[int, torch.Tensor],
+ ) -> None:
+     """Copy recv buffer into scattered_us (in-place)."""
+     off = 0
+     for src in range(len(recv_counts)):
+         block = recv_counts[src]
+         if block == 0:
+             continue
+
+         inner_off = 0
+         for p in params:
+             state = param_to_state[id(p)]
+             if state.worker_rank != src:
+                 continue
+             n = state.rank_numels[rank]
+             assert n > 0
+
+             flat_local = recv_buf.narrow(0, off + inner_off,
+                                          n).view_as(p.to_local())
+             scattered_us[id(p)].copy_(flat_local)
+
+             inner_off += n
+
+         assert inner_off == block
+         off += block
+
+
+ def _update_params(
+     params: list[DTensor],
+     param_to_state: dict[int, _muon_state],
+     rank: int,
+     scattered_us: dict[int, torch.Tensor],
+     lr: float,
+     weight_decay: float,
+ ) -> None:
+     """Apply weight decay, Muon update, and optional QK clipping."""
+     for p in params:
+         state = param_to_state[id(p)]
+         u_dtensor = DTensor.from_local(
+             scattered_us[id(p)],
+             placements=p.placements,
+             device_mesh=p.device_mesh,
+         )
+
+         adjusted_lr = adjust_lr_for_muon(lr, p.shape)
+         update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
+
+         # QK clipping - applied directly on the local tensor to
+         # avoid DTensor sharding-propagation issues with _StridedShard.
+         scales_full = compute_scales(
+             p,
+             state.qk_clip_state) if state.qk_clip_state is not None else None
+         if scales_full is not None:
+             ratio = p.shape[0] // scales_full.shape[0]
+             idx0 = state.rank_indices[rank][0]
+             if isinstance(idx0, slice):
+                 start = idx0.start or 0
+                 idx0 = torch.arange(start,
+                                     idx0.stop,
+                                     device=scales_full.device)
+             row_scales = scales_full[idx0 // ratio]
+             p._local_tensor.mul_(row_scales.view(-1, 1))
+
+
+ # ======================================================================
+ # Main generator - thin orchestrator that wires stages together.
+ # ======================================================================
+
+
+ @torch.no_grad()
+ def muon_chunk_pipeline(
+     params: list[DTensor],
+     param_to_state: dict[int, _muon_state],
+     rank: int,
+     ns_steps: int,
+     lr: float,
+     weight_decay: float,
+     none_grad: bool,
+ ) -> Generator[None, None, None]:
+     """Process one chunk of parameters through the full Muon pipeline.
+
+     Stages: gather -> compute (Newton-Schulz) -> scatter -> update.
+
+     Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
+     that communication and computation overlap across chunks. Async
+     communication is launched via ``async_op=True`` and completed after
+     the yield with ``work.wait()``.
+
+     Overlap happens because :func:`run_pipeline` admits one new chunk
+     per iteration (staggered admission). While chunk *N* does NS
+     compute on the default CUDA stream, chunk *N+1*'s async all-to-all
+     runs concurrently on the NCCL stream - no separate ``comm_stream``
+     is required.
+
+     Yields exactly **2** times:
+
+     1. After launching the async all-to-all gather.
+     2. After launching the async all-to-all scatter.
+     """
+     process_group = param_to_state[id(params[0])].process_group
+     num_ranks = dist.get_world_size(group=process_group)
+     owned_params = [
+         p for p in params if param_to_state[id(p)].worker_rank == rank
+     ]
+
+     # Stages 1-2: launch async gather.
+     with record_function("muon::launch_gather"):
+         work, recv_buf, gathered_grads, recv_counts = _launch_gather(
+             params, owned_params, param_to_state, rank, num_ranks,
+             process_group)
+
+     if none_grad:
+         for p in params:
+             p.grad = None
+
+     yield  # --- YIELD 1: other chunks can launch their gather ---
+
+     with record_function("muon::wait_gather"):
+         work.wait()
+         _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
+                          param_to_state, rank)
+         del recv_buf
+
+     # Stage 3: Newton-Schulz orthogonalization.
+     with record_function("muon::newton_schulz"):
+         computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
+         gathered_grads.clear()
+
+     # Stages 4-5: launch async scatter.
+     with record_function("muon::launch_scatter"):
+         work, recv_buf, scattered_us, recv_counts = _launch_scatter(
+             params, owned_params, param_to_state, rank, num_ranks,
+             process_group, computed_us)
+         computed_us.clear()
+
+     yield  # --- YIELD 2: other chunks can launch their scatter ---
+
+     with record_function("muon::wait_scatter"):
+         work.wait()
+         _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
+                           scattered_us)
+         del recv_buf
+
+     # Stage 6: apply parameter updates.
+     with record_function("muon::update_params"):
+         _update_params(params, param_to_state, rank, scattered_us, lr,
+                        weight_decay)
+         scattered_us.clear()
build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py ADDED
@@ -0,0 +1,129 @@
+ import logging
+ import math
+ from dataclasses import dataclass
+
+ import torch
+ from torch.distributed.tensor import DTensor
+
+ logger = logging.getLogger(__name__)
+
+
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
+     """
+     Parse a parameter name to check if it is a query/key projection layer
+     ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+
+     Returns:
+         (kind, layer_idx) or (None, -1) if not matched.
+
+     Example:
+         'model.3.attn.wq.weight' -> ('wq', 3)
+         'model.5.attn.wk.weight' -> ('wk', 5)
+         'model.2.attn.q_proj.weight' -> ('q_proj', 2)
+         'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+         'model.4.attn.v_proj.weight' -> (None, -1)
+     """
+     parts = name.split('.')
+     if len(parts) < 3:
+         return None, -1
+
+     kind = parts[-2]
+
+     layer_idx = -1
+     for part in reversed(parts):
+         if part.isdigit():
+             layer_idx = int(part)
+             break
+
+     if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+         return kind, layer_idx
+
+     return None, -1
+
+
+ @dataclass
+ class QKClipInfo:
+     """Per-parameter dynamic info computed from config + runtime logits."""
+     kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+     indices: list[int]  # which heads to consider for clipping
+     head_dim: int  # from config
+     threshold: float  # from config
+     logit: torch.Tensor | None
+
+
+ def get_qk_clip_info(clip_config, n, qk_logits):
+     """Extract QK clipping info for a named parameter.
+
+     Args:
+         clip_config: QK clipping configuration dict (or None).
+         n: Parameter name string.
+         qk_logits: Dict mapping layer indices to logit tensors (or None).
+
+     Returns:
+         QKClipInfo instance with clipping configuration for this parameter.
+     """
+     if clip_config is None:
+         return None
+
+     head_dim = clip_config.get('head_dim')
+     threshold = clip_config.get('threshold')
+     kind, layer_idx = parse_qk_layer(n)
+
+     logit, indices = None, []
+     if qk_logits is not None and kind is not None:
+         logit = qk_logits[layer_idx]
+         indices_key = 'q_indices' if 'q' in kind else 'k_indices'
+         indices = clip_config.get(indices_key, []) or []
+
+     if isinstance(logit, DTensor):
+         # In TP settings, qk_logits may be a DTensor;
+         # we convert it to a full tensor here for simplicity.
+         logit = logit.full_tensor()
+
+     return QKClipInfo(
+         kind=kind,
+         indices=indices,
+         head_dim=head_dim,
+         threshold=threshold,
+         logit=logit,
+     )
+
+
+ def compute_scales(p, qk_clip_state):
+     """Compute per-head scaling factors for QK clipping.
+
+     Returns a scales tensor if any head exceeds the threshold, else None.
+     """
+     kind = qk_clip_state.kind
+     indices = qk_clip_state.indices
+     head_dim = qk_clip_state.head_dim
+     threshold = qk_clip_state.threshold
+     logit = qk_clip_state.logit
+
+     H_global = p.shape[0] // head_dim
+     scales_full = torch.ones(H_global, device=p.data.device)
+     scaling = 0
+
+     for logit_idx, head_idx in enumerate(indices):
+         v_ele = float(logit[logit_idx])
+         if v_ele > threshold:
+             new_scale = math.sqrt(threshold / v_ele)
+             if new_scale < scales_full[head_idx]:
+                 scales_full[head_idx] = new_scale
+                 logger.info(
+                     f"[{kind}] Head {head_idx} exceeded threshold "
+                     f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
+                 )
+                 scaling += 1
+
+     return scales_full if scaling > 0 else None
+
+
+ def qk_clip(p, scales, head_dim):
+     """Apply per-head scaling to a Q/K projection weight matrix."""
+     if isinstance(p, torch.nn.Parameter):
+         W = p.data.view(-1, head_dim, p.data.shape[1])
+         W.mul_(scales.view(-1, 1, 1))
+     else:
+         W = p.view(-1, head_dim, p.shape[1])
+         W.mul_(scales.view(-1, 1, 1))
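The clipping arithmetic is easy to sanity-check in isolation. A dependency-free sketch that reimplements the name parsing and the per-head scale formula in plain Python (a mirror of the functions above for illustration; the module itself needs torch):

```python
import math

def parse_qk_layer(name):
    # Mirrors qk_clip.parse_qk_layer: the second-to-last path component
    # is the projection kind; the layer index is the last numeric component.
    parts = name.split('.')
    if len(parts) < 3:
        return None, -1
    kind = parts[-2]
    layer_idx = -1
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break
    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
        return kind, layer_idx
    return None, -1

def head_scale(logit, threshold):
    # QK-clip scale: sqrt(threshold / logit) when a head's max QK logit
    # exceeds the threshold, else 1.0 (no change).
    return math.sqrt(threshold / logit) if logit > threshold else 1.0

print(parse_qk_layer('model.3.attn.wq.weight'))      # ('wq', 3)
print(parse_qk_layer('model.4.attn.v_proj.weight'))  # (None, -1)
print(round(head_scale(200.0, 100.0), 4))            # 0.7071
```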