Kernels
wyldecat Claude Opus 4.6 commited on
Commit
14040eb
Β·
1 Parent(s): 81f49fe

Add optimization docs and update implementation guide [skip-build]

Browse files
Files changed (2) hide show
  1. docs/implementation.md +63 -19
  2. docs/optimizations.md +125 -0
docs/implementation.md CHANGED
@@ -8,11 +8,12 @@ This document explains the internal architecture of the Muon optimizer for revie
8
  2. [Entry Point and Parameter Routing](#entry-point-and-parameter-routing)
9
  3. [Execution Paths](#execution-paths)
10
  4. [Parallel Pipeline (the core feature)](#parallel-pipeline)
11
- 5. [Distributed Utilities](#distributed-utilities)
12
- 6. [Newton-Schulz Orthogonalization](#newton-schulz-orthogonalization)
13
- 7. [QK Clipping](#qk-clipping)
14
- 8. [AdamW for Non-Muon Parameters](#adamw-for-non-muon-parameters)
15
- 9. [Source File Map](#source-file-map)
 
16
 
17
  ---
18
 
@@ -33,17 +34,19 @@ Users must provide parameter groups with `use_muon=True/False` flags (via `get_d
33
 
34
  ```
35
  _step_muon(group)
 
 
 
36
  |
37
  +-- DTensor, all Replicate placements --> base() (no sharding)
38
- +-- DTensor, numel <= threshold --> distributed_muon() (small param fallback)
39
  +-- DTensor, sharded --> parallel() (pipelined all-to-all)
40
  +-- plain Tensor --> base() (single device)
41
  ```
42
 
43
  Parameters are classified by their DTensor placements:
44
  - **Fully replicated** DTensors and plain tensors use `base()` &mdash; standard single-device Muon.
45
- - **Small sharded** DTensors (below `small_param_numel_threshold`, default 65536) use `distributed_muon()` &mdash; gathers the full tensor via `full_tensor()`, computes the update, then redistributes.
46
- - **Large sharded** DTensors use `parallel()` &mdash; the pipelined all-to-all approach described below.
47
 
48
  ## Execution Paths
49
 
@@ -51,9 +54,9 @@ Parameters are classified by their DTensor placements:
51
 
52
  Straightforward per-parameter loop: momentum update &rarr; Newton-Schulz orthogonalization &rarr; parameter update &rarr; optional QK clipping.
53
 
54
- ### distributed_muon() &mdash; Full Gather
55
 
56
- Each parameter's gradient is gathered to full via `g.full_tensor()`, orthogonalized on every rank, then the updated full parameter is redistributed back to the original sharded placement. Simple but communication-heavy &mdash; used only as a fallback for small parameters.
57
 
58
  ### parallel() &mdash; Pipelined All-to-All
59
 
@@ -171,6 +174,47 @@ Inverse of gather:
171
 
172
  Each rank applies weight decay and the Muon update to its local parameter shard. Also applies QK clipping if configured.
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  ## Distributed Utilities
175
 
176
  **File:** `distributed/utils.py`
@@ -181,7 +225,7 @@ These utilities solve the problem of mapping from a DTensor's arbitrary sharding
181
 
182
  Given a DTensor's placements and device mesh, this function:
183
 
184
- 1. **Sorts** placements: Replicate dims first, then Shard dims by dimension (with `_StridedShard` after regular `Shard` on the same dim).
185
  2. **Permutes** the mesh accordingly.
186
  3. **Separates** replicate dims from shard dims &mdash; each replicate group gets its own shard sub-mesh.
187
  4. **Creates** a ProcessGroup for the current rank's shard mesh.
@@ -214,7 +258,7 @@ def _is_shard(placement):
214
 
215
  **File:** `newton_schulz.py`
216
 
217
- `_zeropower_via_newtonschulz5()` computes the orthogonal approximation of a matrix using 5 quintic Newton-Schulz iterations with pre-optimized coefficients. The result approximates `US'V^T` where `S'` is near-uniform on `[0.5, 1.5]`, which empirically does not hurt model performance vs. exact `UV^T`.
218
 
219
  Each iteration uses `matmul_transpose_assign()` (a Triton kernel for `X @ X^T`) for efficiency.
220
 
@@ -248,15 +292,15 @@ Parameters not eligible for Muon (1D parameters, embeddings, LM head) are optimi
248
 
249
  | File | Lines | Purpose |
250
  |------|-------|---------|
251
- | `muon.py` | ~525 | Optimizer class, parameter routing, 3 execution paths |
252
- | `pipeline.py` | ~290 | Generator-based parallel pipeline (gather/compute/scatter/update) |
253
  | `async_utils.py` | ~75 | Pipeline scheduler with bounded concurrency |
254
- | `core.py` | ~110 | `_muon_state` dataclass, momentum/update helpers, param grouping |
255
  | `distributed/utils.py` | ~230 | Shard mesh construction, DTensor index computation |
256
- | `newton_schulz.py` | ~50 | Newton-Schulz iteration |
257
- | `matmul_transpose_triton.py` | ~120 | Triton kernel for symmetric matmul |
258
- | `qk_clip.py` | ~130 | QK logit clipping |
259
- | `adamw.py` | ~160 | Fused AdamW for non-Muon params |
260
 
261
  ### Dependency Graph
262
 
 
8
  2. [Entry Point and Parameter Routing](#entry-point-and-parameter-routing)
9
  3. [Execution Paths](#execution-paths)
10
  4. [Parallel Pipeline (the core feature)](#parallel-pipeline)
11
+ 5. [MoE Expert Weight Support](#moe-expert-weight-support-expert_keys)
12
+ 6. [Distributed Utilities](#distributed-utilities)
13
+ 7. [Newton-Schulz Orthogonalization](#newton-schulz-orthogonalization)
14
+ 8. [QK Clipping](#qk-clipping)
15
+ 9. [AdamW for Non-Muon Parameters](#adamw-for-non-muon-parameters)
16
+ 10. [Source File Map](#source-file-map)
17
 
18
  ---
19
 
 
34
 
35
  ```
36
  _step_muon(group)
37
+ |
38
+ +-- momentum update (batched _foreach_* ops)
39
+ +-- _expand_expert_params() -- 3D expert params β†’ per-expert 2D views (cached)
40
  |
41
  +-- DTensor, all Replicate placements --> base() (no sharding)
 
42
  +-- DTensor, sharded --> parallel() (pipelined all-to-all)
43
  +-- plain Tensor --> base() (single device)
44
  ```
45
 
46
  Parameters are classified by their DTensor placements:
47
  - **Fully replicated** DTensors and plain tensors use `base()` &mdash; standard single-device Muon.
48
+ - **Sharded** DTensors use `parallel()` &mdash; the pipelined all-to-all approach described below.
49
+ - `distributed_muon()` exists as a **test-only reference implementation** for correctness verification.
50
 
51
  ## Execution Paths
52
 
 
54
 
55
  Straightforward per-parameter loop: momentum update &rarr; Newton-Schulz orthogonalization &rarr; parameter update &rarr; optional QK clipping.
56
 
57
+ ### distributed_muon() &mdash; Full Gather (test-only)
58
 
59
+ Reference implementation for correctness verification. Uses batched all-gather to reconstruct full tensors, computes Newton-Schulz on the full grad, then slices back to local shards. Simple but communication-heavy &mdash; not used in production.
60
 
61
  ### parallel() &mdash; Pipelined All-to-All
62
 
 
174
 
175
  Each rank applies weight decay and the Muon update to its local parameter shard. Also applies QK clipping if configured.
176
 
177
+ ## MoE Expert Weight Support (`expert_keys`)
178
+
179
+ **File:** `muon.py` &mdash; `_expand_expert_params()`
180
+
181
+ MoE models have 3D expert weights with shape `(num_experts, out_dim, in_dim)`. Since Muon operates on 2D matrices, expert params need special handling.
182
+
183
+ ### Configuration
184
+
185
+ Pass `expert_keys` to both `get_default_muon_param_groups()` and `Muon()`:
186
+
187
+ ```python
188
+ params = get_default_muon_param_groups(model, expert_keys=["experts"])
189
+ optim = Muon(params, expert_keys=["experts"], ...)
190
+ ```
191
+
192
+ Any parameter whose name contains a string in `expert_keys` is treated as an expert-parallel parameter. Non-matching 3D+ parameters raise `AssertionError` to catch misconfiguration.
193
+
194
+ ### How It Works
195
+
196
+ `_expand_expert_params()` runs after momentum and before routing to `base()`/`parallel()`/`distributed_muon()`:
197
+
198
+ 1. **Split on dim 0**: A 3D `(E, out, in)` tensor becomes `E` separate 2D `(out, in)` `nn.Parameter` views. Views share storage with the original, so in-place updates propagate back.
199
+ 2. **Placement remapping**: When the original is a DTensor, `Shard(k)` on dim `k > 0` becomes `Shard(k-1)` on the 2D slice (since dim 0 is consumed by the split).
200
+ 3. **Submesh wrapping**: Non-dim-0 shard placements are preserved by wrapping each 2D slice as a DTensor on the corresponding submesh. This is **placement-agnostic** &mdash; the same logic handles TP `Shard(1/2)`, EFSDP `Shard(1)`, or any other non-dim-0 sharding.
201
+
202
+ ### Placement-Agnostic Design
203
+
204
+ The expansion logic does not care *why* a dimension is sharded &mdash; only whether it's on dim 0 (consumed by split) or not (preserved on submesh):
205
+
206
+ | Original Placement | After Expansion |
207
+ |-------------------|-----------------|
208
+ | `Shard(0)` (EP) | Consumed by split &rarr; plain tensor |
209
+ | `Shard(1)` (TP or EFSDP) | `Shard(0)` on submesh &rarr; 2D DTensor |
210
+ | `Shard(2)` (TP row-wise) | `Shard(1)` on submesh &rarr; 2D DTensor |
211
+ | `Replicate` | Ignored (not a shard) |
212
+ | `_StridedShard(0)` (EFSDP) | Consumed by split &rarr; plain tensor |
213
+
214
+ After expansion, the 2D params flow through the standard routing: DTensors with shard placements go to `parallel()`, plain tensors go to `base()`.
215
+
216
+ For EP/EFSDP background and torchtitan integration details, see [`docs/expert_parallel.md`](expert_parallel.md).
217
+
218
  ## Distributed Utilities
219
 
220
  **File:** `distributed/utils.py`
 
225
 
226
  Given a DTensor's placements and device mesh, this function:
227
 
228
+ 1. **Sorts** placements: Replicate dims first, then Shard dims by dimension (with `_StridedShard` before regular `Shard` on the same dim, so the outer sharding is applied first).
229
  2. **Permutes** the mesh accordingly.
230
  3. **Separates** replicate dims from shard dims &mdash; each replicate group gets its own shard sub-mesh.
231
  4. **Creates** a ProcessGroup for the current rank's shard mesh.
 
258
 
259
  **File:** `newton_schulz.py`
260
 
261
+ `_zeropower_via_newtonschulz5()` computes the polar factor of a matrix using the Polar Express method &mdash; quintic Newton-Schulz iterations with analytically optimal (minimax/Remez) coefficients precomputed by `_optimal_composition()`. The default configuration uses 10 iterations with `l=1e-3`, converging all singular values to 1 to produce the exact polar factor `UV^T`. Wrapped by `zeropower_via_newtonschulz5()` which adds per-shape `torch.compile` caching with CUDA graph support.
262
 
263
  Each iteration uses `matmul_transpose_assign()` (a Triton kernel for `X @ X^T`) for efficiency.
264
 
 
292
 
293
  | File | Lines | Purpose |
294
  |------|-------|---------|
295
+ | `muon.py` | ~815 | Optimizer class, parameter routing, 3 execution paths, MoE expert expansion + caching |
296
+ | `pipeline.py` | ~400 | Generator-based parallel pipeline (gather/compute/scatter/update) |
297
  | `async_utils.py` | ~75 | Pipeline scheduler with bounded concurrency |
298
+ | `core.py` | ~175 | `_muon_state` dataclass, batched momentum/update helpers, param grouping |
299
  | `distributed/utils.py` | ~230 | Shard mesh construction, DTensor index computation |
300
+ | `newton_schulz.py` | ~190 | Polar Express coefficients, Newton-Schulz iteration + compile/CUDA graph |
301
+ | `matmul_transpose_triton.py` | ~130 | Triton kernel for symmetric matmul |
302
+ | `qk_clip.py` | ~135 | QK logit clipping |
303
+ | `adamw.py` | ~170 | Fused AdamW for non-Muon params |
304
 
305
  ### Dependency Graph
306
 
docs/optimizations.md ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Performance Optimizations (vs. main)
2
+
3
+ Summary of optimizations on branch `perf/pipelined-distributed-muon-clean` relative to `main`.
4
+
5
+ ---
6
+
7
+ ## 1. Batched Momentum (`core.py`)
8
+
9
+ **Before:** Per-param `update_g()` β€” one `torch.add` + optional `torch.add_` per parameter.
10
+
11
+ **After:** `_batch_pre_ortho()` β€” `_foreach_mul_`, `_foreach_add_` on lists of local tensors (unwrapped from DTensor). Single fused kernel per batch instead of N individual kernels.
12
+
13
+ **Impact:** Eliminates N per-param Python-loop overhead + N small kernel launches. Scales with parameter count.
14
+
15
+ ---
16
+
17
+ ## 2. Pipeline Buffer Packing (`pipeline.py`)
18
+
19
+ ### Gather send buffer
20
+
21
+ **Before:** Per-param `.to(COMM_DTYPE).contiguous()` followed by per-destination `append` to list, then `torch.cat` on the per-dst lists.
22
+
23
+ **After:** Collect all grad slices in destination order in a single pass, then one `torch.cat` call. Avoids intermediate per-destination lists and redundant dtype conversions.
24
+
25
+ ### Scatter send buffer
26
+
27
+ **Before:** Per-param, per-destination-rank: index `u_full[indices].flatten()`, append to per-dst list, then flatten+cat.
28
+
29
+ **After:** Cache `u_full` conversions (avoid redundant `.to()` per dst_rank). Collect all slices in dst order in one pass, single `torch.cat`.
30
+
31
+ **Impact:** Fewer kernel launches, less Python overhead, reduced intermediate allocations.
32
+
33
+ ---
34
+
35
+ ## 3. Zero-Copy Scatter (`pipeline.py`)
36
+
37
+ **Before:** `_launch_scatter` pre-allocates `torch.empty_like(p.to_local())` for every param. `_complete_scatter` copies from recv_buf into these pre-allocated tensors via `copy_()`.
38
+
39
+ **After:** `_complete_scatter` assigns **views** into `recv_buf` directly (via `recv_buf.narrow(...).view_as(...)`). No pre-allocation, no copy. The recv_buf storage stays alive through the views until `_update_params` consumes them.
40
+
41
+ **Impact:** Eliminates N `empty_like` allocations + N `copy_` kernel launches per scatter stage.
42
+
43
+ ---
44
+
45
+ ## 4. Batched Parameter Update (`pipeline.py`)
46
+
47
+ **Before:** Per-param loop calling `update_p()` (which unwraps DTensor, applies weight decay, applies update individually).
48
+
49
+ **After:** Batched using `_foreach_mul_` (weight decay) and `_foreach_add_` (Muon update), grouped by `adjusted_lr` to preserve float32 alpha precision. Single kernel per group instead of per param.
50
+
51
+ **Impact:** Reduces N per-param kernel launches to 1-2 batched kernel launches.
52
+
53
+ ---
54
+
55
+ ## 5. Parallel Metadata Caching (`muon.py`)
56
+
57
+ **Before:** `init_state_and_assign_params()` called every step β€” sorts params by FLOP cost, assigns ownership via round-robin, precomputes per-rank indices/numels for all-to-all.
58
+
59
+ **After:** `_parallel_cache` keyed by `tuple(names)`. First call computes and caches `ordered_names`, `name_to_state`, `rank`, `chunk_size`. Subsequent calls reuse cached metadata, only rebuilding `param_to_state` with current `id(p)` keys (since param objects are stable but ids may change for QK clip updates).
60
+
61
+ **Impact:** Eliminates repeated sorting, mesh construction, and index precomputation on every step.
62
+
63
+ ---
64
+
65
+ ## 6. Expert Param Expansion Caching (`muon.py`)
66
+
67
+ **Before:** `_expand_expert_params()` called every step β€” for each expert param `(E, out, in)`, creates E `nn.Parameter` wrappers (triggers `aten::detach`), indexes data and grad (`aten::select`), and wraps in DTensor for TP.
68
+
69
+ **After:** `_expert_expand_cache` keyed by `tuple(id(p) for p in params)`. Cold path runs `_expand_expert_params` once and caches:
70
+
71
+ - `expanded_names` / `expanded_params` β€” the nn.Parameter wrappers with stable data views
72
+ - `grad_info` β€” per-expert-group metadata (orig param index, num experts, expanded start index, DTensor flag, TP mesh/placements)
73
+
74
+ Hot path reuses cached nn.Parameter objects (data views are stable since optimizer updates happen in-place on the same storage). Only updates `.grad` on each cached expert param by slicing the current step's gradient.
75
+
76
+ **Eliminated on hot path:**
77
+
78
+ - `nn.Parameter()` construction β€” removes `aten::detach`
79
+ - `local_data[i]` data slicing β€” removes half of `aten::select` + `aten::as_strided`
80
+ - `DTensor.from_local()` for data β€” only needed for grad now
81
+ - `is_expert_param()` name matching per step
82
+
83
+ **Still required per step:**
84
+
85
+ - `local_grad[i]` β€” grad tensor changes each step (nesterov)
86
+ - `DTensor.from_local(slice_grad, ...)` β€” for TP expert grads
87
+ - `p.grad = None` β€” freeing original 3D grad storage
88
+
89
+ **Impact:** ~8ms CPU overhead reduction per step at production scale (64 GPUs, 48 local experts).
90
+
91
+ ---
92
+
93
+ ## 7. Newton-Schulz Compile + CUDA Graph (`newton_schulz.py`)
94
+
95
+ **Before:** `_zeropower_via_newtonschulz5()` called directly every time.
96
+
97
+ **After:** `zeropower_via_newtonschulz5()` wrapper with per-shape `torch.compile` caching + CUDA graph (`triton.cudagraphs=True`). Each unique shape gets its own compiled function stored in `_ns_per_shape`. Toggled via `set_ns_compile(enabled)`.
98
+
99
+ **Impact:** After warmup, NS iterations run as CUDA graphs β€” eliminates per-step compilation overhead and CPU-GPU synchronization.
100
+
101
+ ---
102
+
103
+ ## 8. Removed `small_param_numel_threshold` (`muon.py`)
104
+
105
+ **Before:** Small sharded DTensors (below threshold, default 65536) fell back to `distributed_muon()` which used per-param `full_tensor()` + redistribute.
106
+
107
+ **After:** All sharded DTensors go to `parallel()`. `distributed_muon()` is retained as a test-only reference implementation. Uneven shard splits (e.g., MoE gate weights with fewer rows than shard ranks) are handled inline via `full_tensor()` fallback within the batched distributed_muon path.
108
+
109
+ **Impact:** Simpler routing, no silent fallback to slower path.
110
+
111
+ ---
112
+
113
+ ## Summary Table
114
+
115
+ | Optimization | Location | Category | Kernel Launches Saved |
116
+ |---|---|---|---|
117
+ | Batched momentum | `core.py` | CPU + GPU | N per-param β†’ 2-3 batched |
118
+ | Buffer packing (gather) | `pipeline.py` | CPU + GPU | N cat+cast β†’ 1 cat+cast |
119
+ | Buffer packing (scatter) | `pipeline.py` | CPU + GPU | N cat β†’ 1 cat |
120
+ | Zero-copy scatter | `pipeline.py` | GPU memory | N alloc+copy β†’ 0 |
121
+ | Batched param update | `pipeline.py` | CPU + GPU | N update β†’ 1-2 batched |
122
+ | Parallel metadata cache | `muon.py` | CPU | Sort+index per step β†’ once |
123
+ | Expert expand cache | `muon.py` | CPU | N detach+select β†’ grad-only |
124
+ | NS compile + CUDA graph | `newton_schulz.py` | GPU | JIT warmup β†’ graph replay |
125
+ | Remove small_param_threshold | `muon.py` | Routing | Simpler, unified path |