Add optimization docs and update implementation guide [skip-build]

- docs/implementation.md +63 -19
- docs/optimizations.md +125 -0
docs/implementation.md CHANGED
@@ -8,11 +8,12 @@ This document explains the internal architecture of the Muon optimizer for revie
2. [Entry Point and Parameter Routing](#entry-point-and-parameter-routing)
3. [Execution Paths](#execution-paths)
4. [Parallel Pipeline (the core feature)](#parallel-pipeline)
-5. [
-6. [
-7. [
-8. [
-9. [
+5. [MoE Expert Weight Support](#moe-expert-weight-support-expert_keys)
+6. [Distributed Utilities](#distributed-utilities)
+7. [Newton-Schulz Orthogonalization](#newton-schulz-orthogonalization)
+8. [QK Clipping](#qk-clipping)
+9. [AdamW for Non-Muon Parameters](#adamw-for-non-muon-parameters)
+10. [Source File Map](#source-file-map)

---

@@ -33,17 +34,19 @@ Users must provide parameter groups with `use_muon=True/False` flags (via `get_d

```
_step_muon(group)
+  |
+  +-- momentum update (batched _foreach_* ops)
+  +-- _expand_expert_params() -- 3D expert params → per-expert 2D views (cached)
  |
  +-- DTensor, all Replicate placements --> base() (no sharding)
-  +-- DTensor, numel <= threshold --> distributed_muon() (small param fallback)
  +-- DTensor, sharded --> parallel() (pipelined all-to-all)
  +-- plain Tensor --> base() (single device)
```

Parameters are classified by their DTensor placements:
- **Fully replicated** DTensors and plain tensors use `base()` — standard single-device Muon.
-- **
-- 
+- **Sharded** DTensors use `parallel()` — the pipelined all-to-all approach described below.
+- `distributed_muon()` exists as a **test-only reference implementation** for correctness verification.

## Execution Paths
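As a runnable sketch of the routing rule above (`route` is an illustrative helper for this doc, not the actual `_step_muon` code):

```python
import torch

try:  # public since torch 2.4; older versions expose the same types privately
    from torch.distributed.tensor import DTensor, Replicate
except ImportError:
    from torch.distributed._tensor import DTensor, Replicate

def route(p: torch.Tensor) -> str:
    # Hypothetical sketch of the routing predicate described above.
    if isinstance(p, DTensor):
        if all(isinstance(pl, Replicate) for pl in p.placements):
            return "base"      # fully replicated -> single-device Muon
        return "parallel"      # sharded -> pipelined all-to-all
    return "base"              # plain tensor -> single-device Muon

assert route(torch.randn(4, 4)) == "base"
```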
@@ -51,9 +54,9 @@ Parameters are classified by their DTensor placements:

Straightforward per-parameter loop: momentum update → Newton-Schulz orthogonalization → parameter update → optional QK clipping.

-### distributed_muon() — Full Gather
+### distributed_muon() — Full Gather (test-only)

-
+Reference implementation for correctness verification. Uses batched all-gather to reconstruct full tensors, computes Newton-Schulz on the full grad, then slices back to local shards. Simple but communication-heavy — not used in production.

### parallel() — Pipelined All-to-All
@@ -171,6 +174,47 @@ Inverse of gather:

Each rank applies weight decay and the Muon update to its local parameter shard. Also applies QK clipping if configured.

+## MoE Expert Weight Support (`expert_keys`)
+
+**File:** `muon.py` — `_expand_expert_params()`
+
+MoE models have 3D expert weights with shape `(num_experts, out_dim, in_dim)`. Since Muon operates on 2D matrices, expert params need special handling.
+
+### Configuration
+
+Pass `expert_keys` to both `get_default_muon_param_groups()` and `Muon()`:
+
+```python
+params = get_default_muon_param_groups(model, expert_keys=["experts"])
+optim = Muon(params, expert_keys=["experts"], ...)
+```
+
+Any parameter whose name contains a string in `expert_keys` is treated as an expert-parallel parameter. Non-matching 3D+ parameters raise `AssertionError` to catch misconfiguration.
+
+### How It Works
+
+`_expand_expert_params()` runs after momentum and before routing to `base()`/`parallel()`/`distributed_muon()`:
+
+1. **Split on dim 0**: A 3D `(E, out, in)` tensor becomes `E` separate 2D `(out, in)` `nn.Parameter` views. Views share storage with the original, so in-place updates propagate back.
+2. **Placement remapping**: When the original is a DTensor, `Shard(k)` on dim `k > 0` becomes `Shard(k-1)` on the 2D slice (since dim 0 is consumed by the split).
+3. **Submesh wrapping**: Non-dim-0 shard placements are preserved by wrapping each 2D slice as a DTensor on the corresponding submesh. This is **placement-agnostic** — the same logic handles TP `Shard(1/2)`, EFSDP `Shard(1)`, or any other non-dim-0 sharding.
+
+### Placement-Agnostic Design
+
+The expansion logic does not care *why* a dimension is sharded — only whether it's on dim 0 (consumed by split) or not (preserved on submesh):
+
+| Original Placement | After Expansion |
+|-------------------|-----------------|
+| `Shard(0)` (EP) | Consumed by split → plain tensor |
+| `Shard(1)` (TP or EFSDP) | `Shard(0)` on submesh → 2D DTensor |
+| `Shard(2)` (TP row-wise) | `Shard(1)` on submesh → 2D DTensor |
+| `Replicate` | Ignored (not a shard) |
+| `_StridedShard(0)` (EFSDP) | Consumed by split → plain tensor |
+
+After expansion, the 2D params flow through the standard routing: DTensors with shard placements go to `parallel()`, plain tensors go to `base()`.
+
+For EP/EFSDP background and torchtitan integration details, see [`docs/expert_parallel.md`](expert_parallel.md).
+
## Distributed Utilities

**File:** `distributed/utils.py`
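The remapping rule from step 2 above can be sketched in isolation; `remap_placement` is an illustrative helper for this doc, not the library's API:

```python
try:  # public since torch 2.4; older versions expose the same types privately
    from torch.distributed.tensor import Replicate, Shard
except ImportError:
    from torch.distributed._tensor import Replicate, Shard

def remap_placement(pl):
    # Hypothetical helper for the dim-0-split rule described above:
    # Shard(0) is consumed by the expert split; Shard(k>0) shifts down one
    # dim on the 2D slice; Replicate passes through untouched.
    if isinstance(pl, Shard):
        if pl.dim == 0:
            return None              # consumed: slice becomes a plain tensor
        return Shard(pl.dim - 1)     # e.g. TP Shard(2) -> Shard(1) on submesh
    return pl

assert remap_placement(Shard(2)) == Shard(1)
assert remap_placement(Shard(0)) is None
```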
@@ -181,7 +225,7 @@ These utilities solve the problem of mapping from a DTensor's arbitrary sharding

Given a DTensor's placements and device mesh, this function:

-1. **Sorts** placements: Replicate dims first, then Shard dims by dimension (with `_StridedShard`
+1. **Sorts** placements: Replicate dims first, then Shard dims by dimension (with `_StridedShard` before regular `Shard` on the same dim, so the outer sharding is applied first).
2. **Permutes** the mesh accordingly.
3. **Separates** replicate dims from shard dims — each replicate group gets its own shard sub-mesh.
4. **Creates** a ProcessGroup for the current rank's shard mesh.
@@ -214,7 +258,7 @@ def _is_shard(placement):

**File:** `newton_schulz.py`

-`_zeropower_via_newtonschulz5()` computes the
+`_zeropower_via_newtonschulz5()` computes the polar factor of a matrix using the Polar Express method — quintic Newton-Schulz iterations with analytically optimal (minimax/Remez) coefficients precomputed by `_optimal_composition()`. The default configuration uses 10 iterations with `l=1e-3`, converging all singular values to 1 to produce the exact polar factor `UV^T`. Wrapped by `zeropower_via_newtonschulz5()` which adds per-shape `torch.compile` caching with CUDA graph support.

Each iteration uses `matmul_transpose_assign()` (a Triton kernel for `X @ X^T`) for efficiency.
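For illustration, here is a minimal runnable quintic Newton-Schulz sketch. It uses the classic fixed coefficients from the original Muon implementation (3.4445, -4.7750, 2.0315), NOT the Polar Express schedule this branch precomputes, so singular values are only pushed into a band around 1 rather than converged exactly:

```python
import torch

def newton_schulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Classic fixed quintic coefficients (original Muon) -- an assumption for
    # this sketch; the branch's _optimal_composition() coefficients differ.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)        # Frobenius-normalize so spectral norm <= 1
    tall = X.shape[0] > X.shape[1]
    if tall:                          # iterate on the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if tall else X

torch.manual_seed(0)
O = newton_schulz5(torch.randn(4, 6))
s = torch.linalg.svdvals(O)           # singular values pushed toward 1
assert 0.2 < float(s.min()) and float(s.max()) < 1.4
```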
@@ -248,15 +292,15 @@ Parameters not eligible for Muon (1D parameters, embeddings, LM head) are optimi

| File | Lines | Purpose |
|------|-------|---------|
-| `muon.py` | ~
-| `pipeline.py` | ~
+| `muon.py` | ~815 | Optimizer class, parameter routing, 3 execution paths, MoE expert expansion + caching |
+| `pipeline.py` | ~400 | Generator-based parallel pipeline (gather/compute/scatter/update) |
| `async_utils.py` | ~75 | Pipeline scheduler with bounded concurrency |
-| `core.py` | ~
+| `core.py` | ~175 | `_muon_state` dataclass, batched momentum/update helpers, param grouping |
| `distributed/utils.py` | ~230 | Shard mesh construction, DTensor index computation |
-| `newton_schulz.py` | ~
-| `matmul_transpose_triton.py` | ~
-| `qk_clip.py` | ~
-| `adamw.py` | ~
+| `newton_schulz.py` | ~190 | Polar Express coefficients, Newton-Schulz iteration + compile/CUDA graph |
+| `matmul_transpose_triton.py` | ~130 | Triton kernel for symmetric matmul |
+| `qk_clip.py` | ~135 | QK logit clipping |
+| `adamw.py` | ~170 | Fused AdamW for non-Muon params |

### Dependency Graph
docs/optimizations.md ADDED

@@ -0,0 +1,125 @@
# Performance Optimizations (vs. main)

Summary of optimizations on branch `perf/pipelined-distributed-muon-clean` relative to `main`.

---

## 1. Batched Momentum (`core.py`)

**Before:** Per-param `update_g()` — one `torch.add` + optional `torch.add_` per parameter.

**After:** `_batch_pre_ortho()` — `_foreach_mul_`, `_foreach_add_` on lists of local tensors (unwrapped from DTensor). Single fused kernel per batch instead of N individual kernels.

**Impact:** Eliminates N per-param Python-loop overhead + N small kernel launches. Scales with parameter count.

---
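The pattern can be sketched on plain tensors (function names here follow this doc; the exact `core.py` signatures may differ):

```python
import torch

torch.manual_seed(0)
params_n = 16
grads = [torch.randn(8, 8) for _ in range(params_n)]
bufs  = [torch.zeros(8, 8) for _ in range(params_n)]
momentum = 0.95

# Before: one small kernel per parameter, launched from a Python loop.
expected = [momentum * b + g for b, g in zip(bufs, grads)]

# After: two horizontally fused kernels covering the whole list.
torch._foreach_mul_(bufs, momentum)
torch._foreach_add_(bufs, grads)

assert all(torch.allclose(b, e) for b, e in zip(bufs, expected))
```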

## 2. Pipeline Buffer Packing (`pipeline.py`)

### Gather send buffer

**Before:** Per-param `.to(COMM_DTYPE).contiguous()` followed by per-destination `append` to list, then `torch.cat` on the per-dst lists.

**After:** Collect all grad slices in destination order in a single pass, then one `torch.cat` call. Avoids intermediate per-destination lists and redundant dtype conversions.

### Scatter send buffer

**Before:** Per-param, per-destination-rank: index `u_full[indices].flatten()`, append to per-dst list, then flatten+cat.

**After:** Cache `u_full` conversions (avoid redundant `.to()` per dst_rank). Collect all slices in dst order in one pass, single `torch.cat`.

**Impact:** Fewer kernel launches, less Python overhead, reduced intermediate allocations.

---
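A toy sketch of the packing change (the row-striding stand-in below is illustrative, not the pipeline's real index computation):

```python
import torch

torch.manual_seed(0)
world = 4
grads = [torch.randn(6, 4) for _ in range(3)]

# Before: per-destination lists, then a cat per destination plus a final cat.
per_dst = [[g[d::world].reshape(-1) for g in grads] for d in range(world)]
before = torch.cat([torch.cat(chunks) for chunks in per_dst])

# After: one pass collecting slices in destination order, a single cat.
after = torch.cat([g[d::world].reshape(-1)
                   for d in range(world) for g in grads])

assert torch.equal(before, after)
```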

## 3. Zero-Copy Scatter (`pipeline.py`)

**Before:** `_launch_scatter` pre-allocates `torch.empty_like(p.to_local())` for every param. `_complete_scatter` copies from recv_buf into these pre-allocated tensors via `copy_()`.

**After:** `_complete_scatter` assigns **views** into `recv_buf` directly (via `recv_buf.narrow(...).view_as(...)`). No pre-allocation, no copy. The recv_buf storage stays alive through the views until `_update_params` consumes them.

**Impact:** Eliminates N `empty_like` allocations + N `copy_` kernel launches per scatter stage.

---
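The view-carving trick in miniature (shapes are illustrative):

```python
import torch

recv_buf = torch.arange(24, dtype=torch.float32)
shapes = [(2, 3), (3, 4), (6,)]

views, offset = [], 0
for shape in shapes:
    n = 1
    for d in shape:
        n *= d
    # narrow + view: no allocation, no copy -- aliases recv_buf's storage
    views.append(recv_buf.narrow(0, offset, n).view(shape))
    offset += n

# Mutating the buffer shows through every view (shared storage).
recv_buf.mul_(2)
assert views[0].data_ptr() == recv_buf.data_ptr()
assert torch.equal(views[1].reshape(-1), recv_buf[6:18])
```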

## 4. Batched Parameter Update (`pipeline.py`)

**Before:** Per-param loop calling `update_p()` (which unwraps DTensor, applies weight decay, applies update individually).

**After:** Batched using `_foreach_mul_` (weight decay) and `_foreach_add_` (Muon update), grouped by `adjusted_lr` to preserve float32 alpha precision. Single kernel per group instead of per param.

**Impact:** Reduces N per-param kernel launches to 1-2 batched kernel launches.

---
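A sketch of the grouped update on plain tensors (the grouping key and weight-decay handling are illustrative):

```python
import torch
from collections import defaultdict

torch.manual_seed(0)
params  = [torch.randn(4, 4) for _ in range(6)]
updates = [torch.randn(4, 4) for _ in range(6)]
lrs     = [0.02, 0.02, 0.01, 0.02, 0.01, 0.01]   # per-param adjusted lr
wd = 0.1

# Reference: per-param decoupled weight decay + update.
expected = [(1 - lr * wd) * p + (-lr) * u
            for p, u, lr in zip(params, updates, lrs)]

# Grouped: one _foreach pair per distinct lr, mutating params in place.
groups = defaultdict(lambda: ([], []))
for p, u, lr in zip(params, updates, lrs):
    groups[lr][0].append(p)
    groups[lr][1].append(u)
for lr, (ps, us) in groups.items():
    torch._foreach_mul_(ps, 1 - lr * wd)      # decoupled weight decay
    torch._foreach_add_(ps, us, alpha=-lr)    # Muon update

assert all(torch.allclose(p, e) for p, e in zip(params, expected))
```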

## 5. Parallel Metadata Caching (`muon.py`)

**Before:** `init_state_and_assign_params()` called every step — sorts params by FLOP cost, assigns ownership via round-robin, precomputes per-rank indices/numels for all-to-all.

**After:** `_parallel_cache` keyed by `tuple(names)`. First call computes and caches `ordered_names`, `name_to_state`, `rank`, `chunk_size`. Subsequent calls reuse cached metadata, only rebuilding `param_to_state` with current `id(p)` keys (since param objects are stable but ids may change for QK clip updates).

**Impact:** Eliminates repeated sorting, mesh construction, and index precomputation on every step.

---

## 6. Expert Param Expansion Caching (`muon.py`)

**Before:** `_expand_expert_params()` called every step — for each expert param `(E, out, in)`, creates E `nn.Parameter` wrappers (triggers `aten::detach`), indexes data and grad (`aten::select`), and wraps in DTensor for TP.

**After:** `_expert_expand_cache` keyed by `tuple(id(p) for p in params)`. Cold path runs `_expand_expert_params` once and caches:

- `expanded_names` / `expanded_params` — the nn.Parameter wrappers with stable data views
- `grad_info` — per-expert-group metadata (orig param index, num experts, expanded start index, DTensor flag, TP mesh/placements)

Hot path reuses cached nn.Parameter objects (data views are stable since optimizer updates happen in-place on the same storage). Only updates `.grad` on each cached expert param by slicing the current step's gradient.

**Eliminated on hot path:**

- `nn.Parameter()` construction — removes `aten::detach`
- `local_data[i]` data slicing — removes half of `aten::select` + `aten::as_strided`
- `DTensor.from_local()` for data — only needed for grad now
- `is_expert_param()` name matching per step

**Still required per step:**

- `local_grad[i]` — grad tensor changes each step (nesterov)
- `DTensor.from_local(slice_grad, ...)` — for TP expert grads
- `p.grad = None` — freeing original 3D grad storage

**Impact:** ~8 ms CPU overhead reduction per step at production scale (64 GPUs, 48 local experts).

---
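Why cached views stay valid: the dim-0 split produces 2D views that share storage with the 3D expert tensor, so in-place updates propagate without re-slicing. A plain-tensor sketch (the `nn.Parameter`/DTensor wrapping from the doc is omitted):

```python
import torch

E, out_dim, in_dim = 4, 3, 5
w = torch.zeros(E, out_dim, in_dim)      # 3D expert weight

experts = [w[i] for i in range(E)]       # E 2D views, no data copied

experts[2].add_(1.0)                     # in-place update on one cached view

assert experts[0].data_ptr() == w.data_ptr()           # shared storage
assert torch.equal(w[2], torch.ones(out_dim, in_dim))  # propagated back
```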

## 7. Newton-Schulz Compile + CUDA Graph (`newton_schulz.py`)

**Before:** `_zeropower_via_newtonschulz5()` called directly every time.

**After:** `zeropower_via_newtonschulz5()` wrapper with per-shape `torch.compile` caching + CUDA graph (`triton.cudagraphs=True`). Each unique shape gets its own compiled function stored in `_ns_per_shape`. Toggled via `set_ns_compile(enabled)`.

**Impact:** After warmup, NS iterations run as CUDA graphs — eliminates per-step compilation overhead and CPU-GPU synchronization.

---

## 8. Removed `small_param_numel_threshold` (`muon.py`)

**Before:** Small sharded DTensors (below threshold, default 65536) fell back to `distributed_muon()`, which used per-param `full_tensor()` + redistribute.

**After:** All sharded DTensors go to `parallel()`. `distributed_muon()` is retained as a test-only reference implementation. Uneven shard splits (e.g., MoE gate weights with fewer rows than shard ranks) are handled inline via `full_tensor()` fallback within the batched distributed_muon path.

**Impact:** Simpler routing, no silent fallback to slower path.

---

## Summary Table

| Optimization | Location | Category | Kernel Launches Saved |
|---|---|---|---|
| Batched momentum | `core.py` | CPU + GPU | N per-param → 2-3 batched |
| Buffer packing (gather) | `pipeline.py` | CPU + GPU | N cat+cast → 1 cat+cast |
| Buffer packing (scatter) | `pipeline.py` | CPU + GPU | N cat → 1 cat |
| Zero-copy scatter | `pipeline.py` | GPU memory | N alloc+copy → 0 |
| Batched param update | `pipeline.py` | CPU + GPU | N update → 1-2 batched |
| Parallel metadata cache | `muon.py` | CPU | Sort+index per step → once |
| Expert expand cache | `muon.py` | CPU | N detach+select → grad-only |
| NS compile + CUDA graph | `newton_schulz.py` | GPU | JIT warmup → graph replay |
| Remove small_param_threshold | `muon.py` | Routing | Simpler, unified path |