Kernels
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .github/actionlint.yaml +0 -3
  2. .github/workflows/build-and-commit.yml +0 -120
  3. .github/workflows/pre-commit.yml +0 -30
  4. .github/workflows/push-to-hf.yml +0 -40
  5. .gitignore +0 -21
  6. .pre-commit-config.yaml +0 -33
  7. CLAUDE.md +0 -108
  8. README.md +4 -75
  9. _typos.toml +0 -3
  10. build.toml +14 -24
  11. build/torch210-cxx11-cu126-x86_64-linux/adamw.py +0 -271
  12. build/torch210-cxx11-cu126-x86_64-linux/async_utils.py +0 -77
  13. build/torch210-cxx11-cu126-x86_64-linux/core.py +0 -219
  14. build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py +0 -206
  15. build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py +0 -232
  16. build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +0 -122
  17. build/torch210-cxx11-cu126-x86_64-linux/metadata.json +0 -3
  18. build/torch210-cxx11-cu126-x86_64-linux/muon.py +0 -1068
  19. build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py +0 -240
  20. build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py +0 -26
  21. build/torch210-cxx11-cu126-x86_64-linux/pipeline.py +0 -468
  22. build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py +0 -198
  23. build/torch210-cxx11-cu128-x86_64-linux/adamw.py +0 -271
  24. build/torch210-cxx11-cu128-x86_64-linux/async_utils.py +0 -77
  25. build/torch210-cxx11-cu128-x86_64-linux/core.py +0 -219
  26. build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py +0 -206
  27. build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py +0 -232
  28. build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +0 -122
  29. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +0 -3
  30. build/torch210-cxx11-cu128-x86_64-linux/muon.py +0 -1068
  31. build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py +0 -240
  32. build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py +0 -26
  33. build/torch210-cxx11-cu128-x86_64-linux/pipeline.py +0 -468
  34. build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py +0 -198
  35. build/torch210-cxx11-cu130-x86_64-linux/adamw.py +0 -271
  36. build/torch210-cxx11-cu130-x86_64-linux/async_utils.py +0 -77
  37. build/torch210-cxx11-cu130-x86_64-linux/core.py +0 -219
  38. build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py +0 -206
  39. build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py +0 -232
  40. build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py +0 -122
  41. build/torch210-cxx11-cu130-x86_64-linux/metadata.json +0 -3
  42. build/torch210-cxx11-cu130-x86_64-linux/muon.py +0 -1068
  43. build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py +0 -240
  44. build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py +0 -26
  45. build/torch210-cxx11-cu130-x86_64-linux/pipeline.py +0 -468
  46. build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py +0 -198
  47. build/torch210-cxx11-rocm70-x86_64-linux/adamw.py +0 -271
  48. build/torch210-cxx11-rocm70-x86_64-linux/async_utils.py +0 -77
  49. build/torch210-cxx11-rocm70-x86_64-linux/core.py +0 -219
  50. build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py +0 -206
.github/actionlint.yaml DELETED
@@ -1,3 +0,0 @@
-self-hosted-runner:
-  labels:
-    - docker-builder-01

.github/workflows/build-and-commit.yml DELETED
@@ -1,120 +0,0 @@
-name: Nix build and commit
-
-on:
-  pull_request:
-    types: [opened, synchronize, reopened]
-  workflow_dispatch:
-
-permissions:
-  contents: write
-
-jobs:
-  check-commit:
-    runs-on: ubuntu-latest
-    outputs:
-      skip: ${{ steps.check.outputs.skip }}
-    steps:
-      - uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-      - id: check
-        run: |
-          if [ "${{ github.event_name }}" = "pull_request" ]; then
-            msg=$(git log -1 --pretty=%B "${{ github.event.pull_request.head.sha }}")
-          else
-            msg="manual dispatch"
-          fi
-          echo "Commit message: $msg"
-          if echo "$msg" | grep -q '\[skip-build\]'; then
-            echo "skip=true" >> "$GITHUB_OUTPUT"
-          else
-            echo "skip=false" >> "$GITHUB_OUTPUT"
-          fi
-
-  build_and_commit:
-    needs: check-commit
-    if: needs.check-commit.outputs.skip == 'false'
-    runs-on: docker-builder-01
-    steps:
-      - name: Show disk usage
-        run: df -h
-
-      - name: Notify build start on Slack
-        id: slack_start
-        run: |
-          msg="*Build started* for \`${{ github.repository }}\`\nBranch: \`${{ github.ref_name }}\`\n<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Workflow>"
-          response=$(curl -s -X POST \
-            -H "Authorization: Bearer ${{ secrets.SLACK_TOKEN }}" \
-            -H "Content-type: application/json; charset=utf-8" \
-            --data "{\"channel\":\"${{ secrets.SLACK_CHANNEL_ID }}\",\"text\":\"$msg\"}" \
-            https://slack.com/api/chat.postMessage)
-          ts=$(echo "$response" | jq -r '.ts')
-          echo "thread_ts=$ts" >> "$GITHUB_OUTPUT"
-          echo "$response"
-
-      - name: Checkout repository
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-          lfs: true
-          ref: ${{ github.head_ref || github.ref }}
-
-      - name: Install Nix
-        uses: cachix/install-nix-action@v31
-
-      - name: Setup huggingface cachix
-        uses: cachix/cachix-action@v15
-        with:
-          name: huggingface
-
-      - name: Clean build directory
-        run: |
-          rm -rf build
-
-      - name: Build with Nix
-        run: |
-          nix run .#build-and-copy \
-            --override-input kernel-builder github:huggingface/kernel-builder \
-            --max-jobs 8 \
-            -j 8 \
-            -L
-
-      - name: List built binaries
-        run: |
-          ls build
-
-      - name: Commit build artifact
-        run: |
-          git config user.name "github-actions[bot]"
-          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
-          git add build/*
-          git commit -m "Add built binary [skip-build]"
-
-      - name: Push changes
-        run: |
-          git push origin HEAD:"$HEAD_REF"
-        env:
-          HEAD_REF: ${{ github.head_ref || github.ref }}
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
-      - name: Notify success on Slack (thread)
-        if: success()
-        run: |
-          ts="${{ steps.slack_start.outputs.thread_ts }}"
-          msg="*Build succeeded* for \`${{ github.repository }}\`\nBranch: \`${{ github.ref_name }}\`\n<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Workflow>"
-          curl -s -X POST \
-            -H "Authorization: Bearer ${{ secrets.SLACK_TOKEN }}" \
-            -H "Content-type: application/json; charset=utf-8" \
-            --data "{\"channel\":\"${{ secrets.SLACK_CHANNEL_ID }}\",\"text\":\"$msg\",\"thread_ts\":\"$ts\"}" \
-            https://slack.com/api/chat.postMessage
-
-      - name: Notify failure on Slack (thread)
-        if: failure()
-        run: |
-          ts="${{ steps.slack_start.outputs.thread_ts }}"
-          msg="*Build failed* for \`${{ github.repository }}\`\nBranch: \`${{ github.ref_name }}\`\n<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Workflow>"
-          curl -s -X POST \
-            -H "Authorization: Bearer ${{ secrets.SLACK_TOKEN }}" \
-            -H "Content-type: application/json; charset=utf-8" \
-            --data "{\"channel\":\"${{ secrets.SLACK_CHANNEL_ID }}\",\"text\":\"$msg\",\"thread_ts\":\"$ts\"}" \
-            https://slack.com/api/chat.postMessage

.github/workflows/pre-commit.yml DELETED
@@ -1,30 +0,0 @@
-name: pre-commit
-
-on:
-  pull_request:
-  push:
-    branches: [ main, master ]
-
-jobs:
-  run-pre-commit:
-    runs-on: ubuntu-latest
-    permissions:
-      contents: read
-      pull-requests: read
-    steps:
-      - uses: actions/checkout@v4
-
-      - uses: actions/setup-python@v5
-        with:
-          python-version: "3.11"
-
-      - name: Cache pre-commit
-        uses: actions/cache@v4
-        with:
-          path: ~/.cache/pre-commit
-          key: pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}
-          restore-keys: |
-            pre-commit-${{ runner.os }}-
-
-      - name: Run pre-commit
-        uses: pre-commit/action@v3.0.1

.github/workflows/push-to-hf.yml DELETED
@@ -1,40 +0,0 @@
-name: Push to HF Repo
-
-on:
-  push:
-    branches:
-      - main
-  workflow_dispatch:
-
-jobs:
-  push_to_hf:
-    runs-on: ubuntu-latest
-    steps:
-      # 1. Checkout the repo
-      - name: Checkout repository
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-      - name: Install Git LFS
-        run: |
-          git lfs install
-          git lfs fetch --all
-          git lfs pull
-      # 2. Set up Git
-      - name: Configure Git
-        run: |
-          git config user.name "MotifTech"
-          git config user.email "huggingface@motiftech.io"
-
-      # 3. Add HF remote
-      - name: Add Hugging Face remote
-        run: |
-          git remote add hf https://huggingface.co/Motif-Technologies/optimizer
-          git fetch hf || true
-
-      # 4. Push to HF repo
-      - name: Push to Hugging Face
-        env:
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
-        run: |
-          git push "https://hf_token:${HF_TOKEN}@huggingface.co/Motif-Technologies/optimizer" HEAD:main

.gitignore DELETED
@@ -1,21 +0,0 @@
-__pycache__
-.idea
-.DS_Store
-*.egg-info
-outputs
-dist/*
-.vscode
-
-# data
-data
-out
-wandb
-
-torchtitan/datasets/**/*.model
-torchtitan/experiments/flux/assets/*
-
-# temp files
-*.log
-error.json
-_remote_module_non_scriptable.py
-.git_disabled/

.pre-commit-config.yaml DELETED
@@ -1,33 +0,0 @@
-default_install_hook_types:
-  - pre-commit
-  - commit-msg
-default_stages:
-  - pre-commit  # Run locally
-  - manual  # Run in CI
-exclude: '(build|result)/.*|__pycache__/.*|.*\.(png|html)$'
-repos:
-  - repo: https://github.com/google/yapf
-    rev: v0.43.0
-    hooks:
-      - id: yapf
-        args: [--in-place, --verbose]
-  - repo: https://github.com/crate-ci/typos
-    rev: v1.34.0
-    hooks:
-      - id: typos
-        exclude: '.gitattributes'
-  - repo: https://github.com/PyCQA/isort
-    rev: 6.0.1
-    hooks:
-      - id: isort
-  - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v20.1.3
-    hooks:
-      - id: clang-format
-        types_or: [c++, cuda]
-        args: [--style=file, --verbose]
-  - repo: https://github.com/jackdewinter/pymarkdown
-    rev: v0.9.29
-    hooks:
-      - id: pymarkdown
-        args: [fix]

CLAUDE.md DELETED
@@ -1,108 +0,0 @@
-# CLAUDE.md
-
-This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
-
-## Project Overview
-
-Optimizer is a PyTorch package implementing the **Muon optimizer** with support for N-D sharding parallelism for large-scale distributed training. Based on the paper at https://arxiv.org/abs/2511.07464. It supports general N-D sharding configurations (FSDP2 through hybrid setups like 2 TP + 2 DP-Replicate + 2 DP-Shard).
-
-## Commands
-
-### Lint & Format
-
-```bash
-pre-commit run --all-files        # Run all pre-commit hooks
-pre-commit run isort --all-files  # Run a specific hook (e.g., isort)
-```
-
-Hooks: yapf (Python formatter), isort (import sorter), typos (spell checker), clang-format (C++/CUDA), pymarkdown (Markdown linter), actionlint (GitHub Actions).
-
-### Tests
-
-Tests require **8 GPUs**, access to `Motif-Technologies/Motif-2.6B-4layer-random` on HuggingFace (`HF_TOKEN` env var), and PyTorch >= 2.8.0.
-
-```bash
-cd test && ./run_test.sh
-# Equivalent to:
-cd test && torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon.py
-```
-
-Useful pytest flags: `--measure-perf` (timing/memory), `--do-profile` (profiling, requires `--measure-perf`), `--skip-verify` (skip correctness check against sequential implementation).
-
-### Build
-
-Uses kernel-builder infrastructure (`build.toml`, `flake.nix`). Pre-built binaries for various PyTorch/CUDA/ROCm combinations are stored in `build/`.
-
-### Commit Convention
-
-**Always append `[skip-build]` to every commit message.** This prevents CI from triggering unnecessary build jobs on development branches.
-
-## Architecture
-
-### Source Layout
-
-```
-torch-ext/optimizer/
-├── __init__.py                 # Public API: exports Muon
-├── muon.py                     # Muon optimizer class (~430 lines)
-├── newton_schulz.py            # Newton-Schulz iteration (~50 lines)
-├── qk_clip.py                  # QK clipping for attention heads (~130 lines)
-├── core.py                     # Shared state, helpers, param grouping (~110 lines)
-├── pipeline.py                 # Async generator pipeline for parallel mode (~290 lines)
-├── async_utils.py              # AsyncTask / AsyncRuntime scheduling (~75 lines)
-├── adamw.py                    # Fused AdamW for non-Muon parameters (~160 lines)
-├── matmul_transpose_triton.py  # Triton kernel for X @ X.T (~130 lines)
-└── distributed/
-    └── utils.py                # Shard mesh construction, DTensor slicing (~175 lines)
-```
-
-### Optimizer Modes
-
-The `Muon` optimizer has three execution paths selected per-parameter based on its tensor type and mesh structure:
-
-1. **Base mode** (`base()`) — Single-device / non-sharded tensors. Standard Muon with Newton-Schulz orthogonalization.
-2. **Distributed mode** (`distributed_muon()`) — Gathers full tensors via all-gather, computes updates, redistributes. Used for small parameters or fallback.
-3. **Parallel mode** (`parallel()`) — Pipelined all2all communication overlapped with compute. Uses an async generator pipeline scheduled by `run_pipeline()`. This is the main advanced feature.
-
-### Parallel Mode Pipeline
-
-The parallel pipeline is implemented as a single generator function `muon_chunk_pipeline()` in `pipeline.py`. Parameters are split into chunks, and each chunk flows through:
-
-```
-build bufs + async all2all_gather → yield → wait + Newton-Schulz compute + async all2all_scatter → yield → wait + update_param
-```
-
-The generator yields 2 times (after launching async gather and async scatter via `async_op=True`), allowing `run_pipeline()` to interleave multiple chunks for communication overlap. `work.wait()` completes each async operation after the yield.
-
-`warmup_step` maps to `max_concurrent_tasks = warmup_step + 1` in `run_pipeline()`.
-
-For detailed implementation documentation (pipeline internals, distributed utilities, QK clipping with strided sharding, etc.), see [`docs/implementation.md`](docs/implementation.md).
-
-### Key Abstractions
-
-- **`get_default_muon_param_groups(model, is_muon_func)`** (`core.py`) — Separates parameters into Muon-optimizable (2D+) and AdamW groups. Skips embeddings and output layers by default.
-- **`_muon_state` dataclass** (`core.py`) — Per-parameter config: rank ownership (`worker_rank`), process group, precomputed shard indices (`rank_indices`, `rank_numels`), and optional QK clip state. Config-only; no transient pipeline state.
-- **`muon_chunk_pipeline()` generator** (`pipeline.py`) — Processes one chunk through the full gather→compute→scatter→update pipeline. Uses `async_op=True` for non-blocking all-to-all and yields to allow chunk interleaving. All intermediate buffers are generator-local variables.
-- **`run_pipeline()`** (`async_utils.py`) — Generator-based pipeline scheduling with bounded concurrency. Interleaves multiple chunk pipelines at yield points.
-- **`construct_shard_mesh()` / `get_slices_of_dtensor()`** (`distributed/utils.py`) — Utilities for building shard meshes from DTensor placements and computing per-rank local slices. Handles both `Shard` and `_StridedShard` (PyTorch 2.10+).
-- **Newton-Schulz iteration** (`newton_schulz.py`) — `_zeropower_via_newtonschulz5()`: 5 quintic iterations in bfloat16 with pre-optimized coefficients for gradient orthogonalization. Uses Triton kernel `matmul_transpose_assign` for efficient X @ X.T.
-- **QK Clipping** (`qk_clip.py`) — Optional dynamic clipping of attention head projections when QK logits exceed a threshold. Configured via `q_indices`, `k_indices`, `head_dim`, `threshold`.
-- **Fused AdamW** (`adamw.py`) — Uses PyTorch's `torch._fused_adamw_` for non-Muon parameters, grouping tensors by device/dtype and DTensor placement.
-
-### Dependency Graph
-
-```
-matmul_transpose_triton.py (leaf)
-        │
-newton_schulz.py (leaf + triton)
-        │
-core.py ──── qk_clip.py (leaf, distributed/utils)
-  │      │          │
-  │   pipeline.py ─── async_utils.py
-  │      │
-  │   adamw.py
-  │      │
-muon.py (all above)
-  │
-__init__.py
-```

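CLAUDE.md's Key Abstractions list names `_zeropower_via_newtonschulz5()`, but `newton_schulz.py` itself falls outside this 50-file view. For orientation, here is a minimal sketch of the commonly published quintic Newton-Schulz iteration from the Muon literature; the deleted kernel fuses the `X @ X.T` product through the Triton `matmul_transpose_assign` kernel, and its exact pre-optimized coefficients may differ from the widely circulated ones used below.

```python
import torch


def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Coefficients of the widely published quintic Muon iteration; the
    # repository's own coefficients are assumed, not verified here.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.mT  # iterate on the wide orientation
    X = X / (X.norm() + 1e-7)  # Frobenius normalization bounds the spectral norm
    for _ in range(steps):
        A = X @ X.mT  # the X @ X.T product the Triton kernel accelerates
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.mT
    return X
```
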
README.md CHANGED
@@ -1,7 +1,6 @@
 ---
 tags:
-- kernels
-license: apache-2.0
+- kernel
 ---

 # Optimizer
@@ -10,14 +9,8 @@ Optimizer is a python package that provides:
 - PyTorch implementation of recent optimizer algorithms
 - with support for parallelism techniques for efficient large-scale training.

-## Currently implemented
-- Parallel Muon with N-D sharding
-  - [arxiv URL](https://arxiv.org/abs/2511.07464)
-- Supports **general N-D sharding configurations**
-  - The implementation is not tied to any specific parallel strategy.
-  - Verified from basic FSDP2 setups up to hybrid configurations such as
-    **(2 TP + 2 DP-Replicate + 2 DP-Shard)**.
-- Verified configurations can be found in [test_muon.py](./test/test_muon.py)
+### Currently implemented
+- [Parallel Muon with FSDP2](./docs/muon/parallel_muon.pdf)

 ## Usage

@@ -27,78 +20,14 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from kernels import get_kernel

 optimizer = get_kernel("motif-technologies/optimizer")
-get_default_muon_param_groups = optimizer.muon.get_default_muon_param_groups

 model = None  # your model here
 fsdp_model = FSDP(model)

-# muon, in nature, cannot use 1-d tensor
-# we provide helper function to group such tensors
-# you can use your own function, if necessary
-params = get_default_muon_param_groups(model)  # user can write own is_muon_func, if necessary
-
 optim = optimizer.Muon(
-    params,
+    fsdp_model.parameters(),
     lr=0.01,
     momentum=0.9,
     weight_decay=1e-4,
 )
 ```
-
-## Documentation
-
-- [Implementation Guide](./docs/implementation.md) — Detailed walkthrough of the internal architecture, parallel pipeline, distributed utilities, and QK clipping. Recommended for code reviewers and new contributors.
-- [PyTorch 2.10 TP Fix](./docs/pytorch-2.10-tp-fix.md) — Root cause analysis and fixes for `_StridedShard` compatibility with PyTorch 2.10+.
-
-## Test
-
-- Check [test/README.md](./test/README.md) for how to run the tests.
-
-## Pre-commit Hooks
-
-This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
-
-### Setup
-
-1. Install pre-commit:
-
-   ```bash
-   pip install pre-commit
-   ```
-
-2. Install the git hooks:
-
-   ```bash
-   pre-commit install
-   ```
-
-Once installed, the configured hooks will run automatically on each commit.
-
-### Included Hooks
-
-The following tools are run via pre-commit:
-
-- **[yapf](https://github.com/google/yapf)** – Python code formatter
-- **[typos](https://github.com/crate-ci/typos)** – Spell checker for common typos
-- **[isort](https://github.com/PyCQA/isort)** – Organizes and sorts Python imports
-- **[clang-format](https://clang.llvm.org/docs/ClangFormat.html)** – Formats C++/CUDA code (`--style=file`)
-- **[pymarkdown](https://github.com/jackdewinter/pymarkdown)** – Lints and auto-fixes Markdown files
-- **[actionlint](https://github.com/rhysd/actionlint)** – Validates GitHub Actions workflows
-
-### Usage
-
-- Run all checks on the entire codebase:
-
-  ```bash
-  pre-commit run --all-files
-  ```
-
-- Run a specific hook (example: isort):
-
-  ```bash
-  pre-commit run isort --all-files
-  ```
-
-### Test
-
-- There is a [simple unittest for Parallel Muon](./test/test_muon/README.md)

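One behavioral note on the simplified Usage block: the removed text explained that Muon by nature cannot handle 1-D tensors and routed them to AdamW through a helper, while the new example passes `fsdp_model.parameters()` directly. For reference, the pre-change pattern (names exactly as in the deleted README, not re-verified against the current export) looked like this:

```python
from kernels import get_kernel

optimizer = get_kernel("motif-technologies/optimizer")
get_default_muon_param_groups = optimizer.muon.get_default_muon_param_groups

model = None  # your model here

# Route 1-D tensors (biases, norms) to the AdamW group; 2-D+ go to Muon.
params = get_default_muon_param_groups(model)  # or pass your own is_muon_func
optim = optimizer.Muon(params, lr=0.01, momentum=0.9, weight_decay=1e-4)
```
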
_typos.toml DELETED
@@ -1,3 +0,0 @@
-[default.extend-words]
-# Math notation used in docs/muon-clip.md (O subscript t, update step output)
-Ot = "Ot"

build.toml CHANGED
@@ -1,33 +1,23 @@
 [general]
 name = "optimizer"
-backends = [
-    "cuda",
-    "rocm",
-]
+universal = false

 [torch]
 src = [
     "torch-ext/torch_binding.cpp",
     "torch-ext/torch_binding.h",
 ]

-[kernel.optimizer]
-backend = "cuda"
-depends = ["torch"]
-src = ["optimizer/dummy.cu"]
-
-[kernel.optimizer_rocm]
+[kernel.activation]
 backend = "rocm"
-rocm-archs = [
-    "gfx906",
-    "gfx908",
-    "gfx90a",
-    "gfx940",
-    "gfx941",
-    "gfx942",
-    "gfx1030",
-    "gfx1100",
-    "gfx1101",
-]
-depends = ["torch"]
-src = ["optimizer/dummy.cu"]
+src = [
+    "optimizer/dummy.cu",
+]
+depends = [ "torch" ]
+
+[kernel.activation_cuda]
+backend = "cuda"
+src = [
+    "optimizer/dummy.cu",
+]
+depends = [ "torch" ]

build/torch210-cxx11-cu126-x86_64-linux/adamw.py DELETED
@@ -1,271 +0,0 @@
-import logging
-from collections import defaultdict
-from typing import cast
-
-import torch
-from torch.distributed.tensor import DTensor
-from torch.profiler import record_function
-
-logger = logging.getLogger(__name__)
-
-
-def fused_adamw(
-    params: list[torch.Tensor],
-    grads: list[torch.Tensor],
-    exp_avgs: list[torch.Tensor],
-    exp_avg_sqs: list[torch.Tensor],
-    max_exp_avg_sqs: list[torch.Tensor],
-    state_steps: list[torch.Tensor],
-    amsgrad: bool,
-    beta1: float,
-    beta2: float,
-    lr: float | torch.Tensor,
-    weight_decay: float,
-    eps: float,
-    maximize: bool,
-) -> None:
-    if not params:
-        return
-
-    # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
-    # treating it as a scalar.
-    lr_dict: dict | None = ({
-        lr.device: lr
-    } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None)
-    grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
-        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
-         state_steps]  # type: ignore[list-item]
-    )
-    for (device, _), (
-        (
-            device_params_,
-            device_grads_,
-            device_exp_avgs_,
-            device_exp_avg_sqs_,
-            device_max_exp_avg_sqs,
-            device_state_steps_,
-        ),
-        _,
-    ) in grouped_tensors.items():
-        device_params = cast(list[torch.Tensor], device_params_)
-        device_grads = cast(list[torch.Tensor], device_grads_)
-        device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
-        device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
-        device_state_steps = cast(list[torch.Tensor], device_state_steps_)
-
-        if lr_dict is not None and device not in lr_dict:
-            lr_dict[device] = lr.to(
-                device=device, non_blocking=True)  # type: ignore[union-attr]
-            lr = lr_dict[device]
-        torch._foreach_add_(device_state_steps, 1)
-        func = torch._fused_adamw_
-        func(
-            device_params,
-            device_grads,
-            device_exp_avgs,
-            device_exp_avg_sqs,
-            device_max_exp_avg_sqs,  # type: ignore[arg-type]
-            device_state_steps,
-            amsgrad=amsgrad,
-            lr=lr,  # type: ignore[arg-type]
-            beta1=beta1,
-            beta2=beta2,
-            weight_decay=weight_decay,
-            eps=eps,
-            maximize=maximize,
-        )
-
-
-def _to_local(t):
-    """Unwrap DTensor to local tensor for fused ops."""
-    return t._local_tensor if isinstance(t, DTensor) else t
-
-
-# ---------------------------------------------------------------------------
-# Caches for eliminating per-step Python overhead.
-#
-# Placement grouping and tensor list assembly are identical every step
-# (params don't change placement, moment/step tensors are the same objects
-# after initialisation). We cache them keyed by id() of the param list
-# stored in param_groups (stable across steps).
-#
-# Only gradients change each step and must be collected fresh.
-# ---------------------------------------------------------------------------
-
-# id(group["params"]) → dict[placement_key, list[param]]
-_placement_cache: dict[int, dict[tuple, list]] = {}
-
-# id(placement_group_list) → (params_local, moment1, moment2, state_steps)
-_tensor_cache: dict[int, tuple[list, list, list, list]] = {}
-
-
-def _step_adamw_params_slow(optimizer_state, params, group):
-    """Uncached fallback for the rare case where some params lack grads."""
-    params_with_grads = []
-    grads = []
-    moment1 = []
-    moment2 = []
-    state_steps = []
-
-    for p in params:
-        g = p.grad
-        if g is None:
-            continue
-        state = optimizer_state[p]
-        params_with_grads.append(_to_local(p))
-        grads.append(_to_local(g))
-        if "step" not in state:
-            state["step"] = torch.zeros((),
-                                        dtype=torch.float32,
-                                        device=p.device)
-            state["moment1"] = torch.zeros_like(g)
-            state["moment2"] = torch.zeros_like(g)
-        moment1.append(_to_local(state["moment1"]))
-        moment2.append(_to_local(state["moment2"]))
-        if not isinstance(state["step"], torch.Tensor):
-            state["step"] = torch.tensor(state["step"],
-                                         dtype=torch.float32,
-                                         device=p.device)
-        state_steps.append(state["step"])
-
-    if not params_with_grads:
-        return
-
-    lr = group["lr"]
-    beta1, beta2 = group["adamw_betas"]
-    eps = group["adamw_eps"]
-    weight_decay = group["weight_decay"]
-
-    fused_adamw(
-        params_with_grads,
-        grads,
-        moment1,
-        moment2,
-        [],
-        state_steps,
-        amsgrad=False,
-        beta1=beta1,
-        beta2=beta2,
-        lr=lr,
-        weight_decay=weight_decay,
-        eps=eps,
-        maximize=False,
-    )
-
-
-def step_adamw_params(optimizer_state, params, group):
-    """Run fused AdamW on a list of parameters sharing the same placement.
-
-    After the first call, cached tensor lists (params_local, moment1,
-    moment2, state_steps) are reused — only gradients are collected fresh.
-
-    Args:
-        optimizer_state: The optimizer's state dict (self.state in Muon).
-        params: List of parameters to update.
-        group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
-    """
-    # Collect grads — the only thing that changes each step.
-    with record_function("adamw::collect_grads"):
-        grads = []
-        for p in params:
-            g = p.grad
-            if g is None:
-                # Rare: fall back to slow path that filters per-param.
-                _step_adamw_params_slow(optimizer_state, params, group)
-                return
-            grads.append(_to_local(g))
-
-    tensor_key = id(params)
-    if tensor_key not in _tensor_cache:
-        with record_function("adamw::init_tensor_cache"):
-            params_local = []
-            moment1 = []
-            moment2 = []
-            state_steps = []
-
-            for p in params:
-                state = optimizer_state[p]
-                params_local.append(_to_local(p))
-                if "step" not in state:
-                    state["step"] = torch.zeros((),
-                                                dtype=torch.float32,
-                                                device=p.device)
-                    state["moment1"] = torch.zeros_like(p.grad)
-                    state["moment2"] = torch.zeros_like(p.grad)
-                moment1.append(_to_local(state["moment1"]))
-                moment2.append(_to_local(state["moment2"]))
-                if not isinstance(state["step"], torch.Tensor):
-                    state["step"] = torch.tensor(state["step"],
-                                                 dtype=torch.float32,
-                                                 device=p.device)
-                state_steps.append(state["step"])
-
-            _tensor_cache[tensor_key] = (params_local, moment1, moment2,
-                                         state_steps)
-
-    params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key]
-
-    lr = group["lr"]
-    beta1, beta2 = group["adamw_betas"]
-    eps = group["adamw_eps"]
-    weight_decay = group["weight_decay"]
-
-    with record_function("adamw::fused_adamw"):
-        fused_adamw(
-            params_local,
-            grads,
-            moment1,
-            moment2,
-            [],
-            state_steps,
-            amsgrad=False,
-            beta1=beta1,
-            beta2=beta2,
-            lr=lr,
-            weight_decay=weight_decay,
-            eps=eps,
-            maximize=False,
-        )
-
-
-def step_adamw(optimizer_state, group):
-    """Dispatch AdamW step, grouping parameters by type and placement.
-
-    Placement grouping is cached after the first call since params never
-    change their placement between steps.
-
-    Args:
-        optimizer_state: The optimizer's state dict (self.state in Muon).
-        group: Parameter group dict.
-    """
-    params = group["params"]
-    placement_key = id(params)
-
-    if placement_key not in _placement_cache:
-        with record_function("adamw::group_by_placement"):
-            placement_to_params: dict[tuple,
-                                      list[torch.Tensor]] = defaultdict(list)
-            for p in params:
-                match p:
-                    case DTensor():
-                        logger.debug(
-                            "[AdamW] DTensor param: shape=%s, placements=%s, "
-                            "mesh=%s, grad=%s", p.shape, p.placements,
-                            p.device_mesh.mesh_dim_names,
-                            p.grad.shape if p.grad is not None else None)
-                        placement_to_params[tuple(
-                            [p.placements, p.device_mesh])].append(p)
-                    case torch.Tensor():
-                        logger.debug(
-                            "[AdamW] plain param: shape=%s, grad=%s", p.shape,
-                            p.grad.shape if p.grad is not None else None)
-                        placement_to_params[tuple([torch.Tensor,
-                                                   None])].append(p)
-
-            logger.debug("[AdamW] %d placement groups, %d total params",
-                         len(placement_to_params), len(params))
-
-        _placement_cache[placement_key] = dict(placement_to_params)
-
-    for group_params in _placement_cache[placement_key].values():
-        step_adamw_params(optimizer_state, group_params, group)

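A hedged usage sketch for the deleted `step_adamw` above: a plain dict stands in for Muon's `self.state`, and the group dict carries exactly the keys the function reads (`params`, `lr`, `adamw_betas`, `adamw_eps`, `weight_decay`). A CUDA device is required, since `torch._fused_adamw_` is a fused GPU op; the hyperparameter values are illustrative.

```python
import torch

layer = torch.nn.Linear(16, 16, device="cuda")
layer(torch.randn(4, 16, device="cuda")).sum().backward()

optimizer_state = {p: {} for p in layer.parameters()}  # per-param state dicts
group = {
    "params": list(layer.parameters()),
    "lr": 1e-3,
    "adamw_betas": (0.9, 0.95),  # illustrative values
    "adamw_eps": 1e-8,
    "weight_decay": 0.1,
}
step_adamw(optimizer_state, group)  # groups by placement, runs fused AdamW
```
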
build/torch210-cxx11-cu126-x86_64-linux/async_utils.py DELETED
@@ -1,77 +0,0 @@
-import logging
-from typing import Generator
-
-logger = logging.getLogger(__name__)
-
-
-class _Task:
-    """Internal: wraps a generator, advances one yield at a time."""
-
-    def __init__(self, generator: Generator[None, None, None], index: int):
-        self._generator = generator
-        self._index = index
-        self._steps_completed = 0
-        self.step()  # run to first yield
-
-    def step(self) -> bool:
-        try:
-            next(self._generator)
-            self._steps_completed += 1
-            logger.debug("pipeline[%d] completed stage %d", self._index,
-                         self._steps_completed)
-            return True
-        except StopIteration:
-            logger.debug("pipeline[%d] finished after %d stages", self._index,
-                         self._steps_completed)
-            return False
-
-    def close(self):
-        self._generator.close()
-
-
-def run_pipeline(
-    pipelines: Generator[Generator[None, None, None], None, None],
-    max_concurrent: int,
-) -> None:
-    """Run generator-based pipelines with bounded concurrency.
-
-    Each pipeline is a generator that yields at stage boundaries.
-    The runtime interleaves pipelines so communication and computation
-    overlap across chunks.
-    """
-    if max_concurrent <= 0:
-        raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}")
-
-    have_new = True
-    task_index = 0
-    previous_tasks: list[_Task] = []
-
-    try:
-        while have_new or previous_tasks:
-            running_tasks: list[_Task] = []
-
-            # Admit one new pipeline per iteration (staggered admission).
-            # Admitting one at a time ensures that while chunk N does NS
-            # compute on the default stream, chunk N+1's NCCL all-to-all
-            # runs concurrently on the NCCL stream — creating real
-            # communication/computation overlap on the GPU.
-            if have_new and len(previous_tasks) < max_concurrent:
-                try:
-                    gen = next(pipelines)
-                    task = _Task(gen, task_index)
-                    task_index += 1
-                    running_tasks.append(task)
-                except StopIteration:
-                    have_new = False
-
-            # Advance every previously-yielded task by one step.
-            for task in previous_tasks:
-                if task.step():
-                    running_tasks.append(task)
-
-            previous_tasks = running_tasks
-    except BaseException:
-        # Clean up all in-flight generators to release GPU resources.
-        for task in previous_tasks:
-            task.close()
-        raise

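A CPU-only toy showing how `run_pipeline` above interleaves chunk generators: each stand-in chunk yields twice, exactly where the real `muon_chunk_pipeline()` yields after launching its async gather and scatter.

```python
def toy_chunk(i: int):
    print(f"chunk {i}: stage 1 (launch async gather)")
    yield
    print(f"chunk {i}: stage 2 (wait, compute, launch async scatter)")
    yield
    print(f"chunk {i}: stage 3 (wait, update param)")


# With max_concurrent=2, chunk 1's stage 1 runs while chunk 0 is mid-pipeline,
# which is what creates the communication/compute overlap on real GPUs.
run_pipeline((toy_chunk(i) for i in range(4)), max_concurrent=2)
```
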
build/torch210-cxx11-cu126-x86_64-linux/core.py DELETED
@@ -1,219 +0,0 @@
-import logging
-import math
-from dataclasses import dataclass
-from typing import List
-
-import torch
-from torch.distributed import ProcessGroup
-from torch.distributed.tensor import DTensor
-
-# torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
-# parameter FQNs. Activation checkpointing similarly inserts
-# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
-# expert_keys, QK layer parsing) works regardless of wrapper nesting.
-_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})
-
-logger = logging.getLogger(__name__)
-
-
-def normalize_fqn(name: str) -> str:
-    """Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
-    return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS)
-
-
-@dataclass
-class _muon_state:
-    worker_rank: int
-    process_group: ProcessGroup
-    rank_indices: dict[int, tuple]  # local_rank -> per-dim indices
-    rank_numels: dict[int, int]  # local_rank -> numel
-    name: str
-    qk_clip_state: torch.Tensor | None = None
-
-
-def _batch_momentum(
-    grads: List[torch.Tensor],
-    momentum_bufs: List[torch.Tensor],
-    momentum: torch.Tensor,
-) -> None:
-    """Batched momentum update (no nesterov)."""
-    torch._foreach_mul_(momentum_bufs, momentum)
-    torch._foreach_add_(momentum_bufs, grads)
-
-
-def _batch_momentum_nesterov(
-    grads: List[torch.Tensor],
-    momentum_bufs: List[torch.Tensor],
-    momentum: torch.Tensor,
-) -> None:
-    """Batched momentum update with nesterov correction."""
-    torch._foreach_mul_(momentum_bufs, momentum)
-    torch._foreach_add_(momentum_bufs, grads)
-    nesterov_terms = torch._foreach_mul(momentum_bufs, momentum)
-    torch._foreach_add_(grads, nesterov_terms)
-
-
-_compiled_momentum: dict[bool, callable] = {}
-_use_momentum_compile = True
-
-
-def set_momentum_compile(enabled: bool):
-    """Toggle torch.compile for batched momentum."""
-    global _use_momentum_compile
-    _use_momentum_compile = enabled
-
-
-def batch_pre_ortho(
-    grads: List[torch.Tensor],
-    momentum_bufs: List[torch.Tensor],
-    momentum: torch.Tensor,
-    nesterov: bool,
-) -> None:
-    """Batched momentum update on lists of plain tensors.
-
-    Mirrors dion's ``muon_update_pre_orthogonalize``.
-    Inputs must be plain CUDA tensors (not DTensor).
-    Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place.
-
-    When compile is enabled, uses separately compiled functions for
-    nesterov=True/False to avoid graph breaks from the branch.
-    """
-    fn = _batch_momentum_nesterov if nesterov else _batch_momentum
-    if _use_momentum_compile:
-        if nesterov not in _compiled_momentum:
-            _compiled_momentum[nesterov] = torch.compile(fn)
-        fn = _compiled_momentum[nesterov]
-    fn(grads, momentum_bufs, momentum)
-
-
-def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay):
-    """Weight-decay + update on plain tensors.
-
-    Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache
-    lookup per call × 256+ params = massive overhead. The pipeline path uses
-    batched _foreach_* ops instead; this function remains for base() and
-    distributed_muon().
-    """
-    p_data.mul_(1 - lr * weight_decay)
-    p_data.add_(u_data, alpha=-adjusted_lr)
-
-
-def update_p(p, u, lr, adjusted_lr, weight_decay):
-    """Apply weight decay and orthogonalized update to parameter.
-
-    Args:
-        p: Parameter (torch.nn.Parameter or DTensor).
-        u: Orthogonalized update tensor.
-        lr: Base learning rate.
-        adjusted_lr: Size-adjusted learning rate.
-        weight_decay: Weight decay coefficient.
-    """
-    # Unwrap Parameter -> underlying data tensor.
-    p_data = p.data if isinstance(p, torch.nn.Parameter) else p
-    # Unwrap DTensor -> local CUDA tensor for compiled kernel.
-    if isinstance(p_data, DTensor):
-        p_data = p_data._local_tensor
-    u_data = u._local_tensor if isinstance(u, DTensor) else u
-    _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay)
-
-
-def adjust_lr_for_muon(lr, param_shape):
-    """Scale learning rate based on parameter matrix dimensions.
-
-    Args:
-        lr: Base learning rate.
-        param_shape: Shape of the parameter tensor.
-
-    Returns:
-        Adjusted learning rate.
-    """
-    A, B = param_shape[:2]
-    # We adjust the learning rate and weight decay based on the size of the
-    # parameter matrix as described in the paper.
-    adjusted_ratio = 0.2 * math.sqrt(max(A, B))
-    adjusted_lr = lr * adjusted_ratio
-    return adjusted_lr
-
-
-def _match_key(parts, key):
-    """Check if key matches as contiguous components in parts.
-
-    Single-component keys (e.g. "experts") match any single component.
-    Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence.
-    """
-    key_parts = key.split(".")
-    key_len = len(key_parts)
-    if key_len == 1:
-        return key in parts
-    return any(parts[i:i + key_len] == key_parts
-               for i in range(len(parts) - key_len + 1))
-
-
-def is_expert_param(name, expert_keys):
-    """Check if a parameter name matches any expert key (component-level)."""
-    if not expert_keys:
-        return False
-    parts = normalize_fqn(name).split(".")
-    return any(_match_key(parts, key) for key in expert_keys)
-
-
-def default_is_muon(name, x, expert_keys=None):
-    normalized = normalize_fqn(name)
-    parts = normalized.split(".")
-    skip_keys = [
-        "embed_tokens",
-        "lm_head",
-        "tok_embeddings",
-        "output",
-        "mhc_attn",
-        "mhc_ffn",
-        "lambda_proj",
-    ]
-    if any(key in parts for key in skip_keys):
-        logger.info(
-            "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
-            normalized, name, x.ndim)
-        return False
-    effective_ndim = x.ndim
-    is_expert = is_expert_param(name, expert_keys)
-    if is_expert:
-        effective_ndim -= 1
-    result = effective_ndim >= 2
-    logger.info(
-        "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s",
-        normalized, name, x.ndim, is_expert, effective_ndim,
-        "Muon" if result else "AdamW")
-    return result
-
-
-def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
-    if is_muon_func is None:
-        is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
-
-    muon_params, muon_names = [], []
-    non_muon_params, non_muon_names = [], []
-
-    for n, p in model.named_parameters():
-        if not p.requires_grad:
-            continue
-        if is_muon_func(n, p):
-            muon_params.append(p)
-            muon_names.append(n)
-        else:
-            non_muon_params.append(p)
-            non_muon_names.append(n)
-
-    logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d",
-                expert_keys, len(muon_names), len(non_muon_names))
-
-    return [
-        {
-            "params": muon_params,
-            "names": muon_names,
-            "use_muon": True,
-        },
-        {
-            "params": non_muon_params,
-            "use_muon": False,
-        },
-    ]

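A quick illustration of the grouping helpers above (the grouping itself needs no GPU). The skip list matches on FQN components, so a module named `embed_tokens` lands in the AdamW group even though its weight is 2-D:

```python
import torch
from torch import nn


class Toy(nn.Module):

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(100, 32)  # matches a skip_key → AdamW
        self.proj = nn.Linear(32, 32)  # 2-D weight → Muon; 1-D bias → AdamW


muon_group, adamw_group = get_default_muon_param_groups(Toy())
print(muon_group["names"])  # ['proj.weight']
print(len(adamw_group["params"]))  # 2: embedding weight + bias

# adjust_lr_for_muon(0.01, (32, 32)) = 0.01 * 0.2 * sqrt(32) ≈ 0.0113
print(adjust_lr_for_muon(0.01, (32, 32)))
```
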
build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py DELETED
@@ -1,206 +0,0 @@
-"""CPU offloading for optimizer states.
-
-Manages a pinned CPU memory pool and async CUDA streams to offload
-optimizer state tensors (momentum buffers, Adam moments) to CPU between
-optimizer steps, freeing GPU memory.
-
-All tracked tensors are packed into a single flat pinned CPU buffer
-(per dtype). D2H and H2D copies are performed per-tensor directly
-between individual GPU tensors and their slice of the CPU flat buffer
-— no GPU staging buffer is allocated, so there is **no temporary GPU
-memory spike** during offload or reload.
-
-Individual tensor storages are freed after offload via
-``untyped_storage().resize_(0)``, preserving tensor identity so
-downstream caches remain valid.
-"""
-
-import logging
-from collections import defaultdict
-
-import torch
-from torch.distributed.tensor import DTensor
-
-logger = logging.getLogger(__name__)
-
-
-class CPUOffloadPool:
-    """Pinned CPU memory pool for async optimizer state offloading.
-
-    Tracked tensors are grouped by dtype. Each group gets a single flat
-    pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of
-    the flat buffer) to avoid allocating a GPU staging buffer.
-    """
-
-    def __init__(self):
-        self._managed: list[torch.Tensor] = []
-        self._storage_nbytes: dict[int, int] = {}  # id(t) → bytes
-
-        # Per-dtype group: populated on first offload.
-        # dtype → dict with keys:
-        #   "indices"  : list[int]             managed-list indices
-        #   "offsets"  : list[tuple[int,int]]  (start, numel) in flat buf
-        #   "total"    : int                   total numel
-        #   "cpu_flat" : Tensor                pinned CPU buffer
-        self._groups: dict[torch.dtype, dict] = {}
-
-        self._offload_stream: torch.cuda.Stream | None = None
-        self._device: torch.device | None = None
-        self._initialized: bool = False
-        self._logged: bool = False
-
-    # ------------------------------------------------------------------
-    @staticmethod
-    def _local(t: torch.Tensor) -> torch.Tensor:
-        """Unwrap DTensor to its local CUDA tensor."""
-        return t._local_tensor if isinstance(t, DTensor) else t
-
-    def _ensure_stream(self):
-        if self._offload_stream is None:
-            self._offload_stream = torch.cuda.Stream(device=self._device)
-
-    # ------------------------------------------------------------------
-    def track(self, tensor: torch.Tensor):
-        """Register a GPU tensor for CPU offloading. Idempotent."""
-        tid = id(tensor)
-        if tid in self._storage_nbytes:
-            return
-        local = self._local(tensor)
-        if self._device is None:
-            self._device = local.device
-        storage = local.untyped_storage()
-        # Skip tensors with empty storage (e.g. empty FSDP shards)
-        if storage.size() == 0:
-            return
-        self._storage_nbytes[tid] = storage.size()
-        self._managed.append(tensor)
-
-    # ------------------------------------------------------------------
-    def _init_buffers(self):
-        """Build per-dtype flat buffers on first offload."""
-        # Group managed tensors by dtype.
-        dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list)
-        for idx, t in enumerate(self._managed):
-            local = self._local(t)
-            dtype_map[local.dtype].append((idx, local.numel()))
-
-        total_cpu_bytes = 0
-        for dtype, entries in dtype_map.items():
-            offsets: list[tuple[int, int]] = []
-            indices: list[int] = []
-            off = 0
-            for idx, n in entries:
-                indices.append(idx)
-                offsets.append((off, n))
-                off += n
-            cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
-            self._groups[dtype] = {
-                "indices": indices,
-                "offsets": offsets,
-                "total": off,
-                "cpu_flat": cpu_flat,
-            }
-            total_cpu_bytes += off * cpu_flat.element_size()
-
-        self._initialized = True
-        logger.info(
-            "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), "
-            "%.2f MB pinned CPU memory",
-            len(self._managed),
-            len(self._groups),
-            total_cpu_bytes / (1024**2),
-        )
-
-    # ------------------------------------------------------------------
-    def offload(self):
-        """Per-tensor async D2H into CPU flat buffer, then free GPU storage."""
-        if not self._managed:
-            return
-        if not self._initialized:
-            self._init_buffers()
-        self._ensure_stream()
-
-        # Offload stream waits for compute to finish.
-        compute_event = torch.cuda.current_stream(self._device).record_event()
-        self._offload_stream.wait_event(compute_event)
-
-        offloaded_bytes = 0
-
-        # Per-tensor D2H copies directly into CPU flat buffer slices.
-        # No GPU staging buffer → no temporary GPU memory spike.
-        with torch.cuda.stream(self._offload_stream):
-            for dtype, grp in self._groups.items():
-                indices = grp["indices"]
-                offsets = grp["offsets"]
-                cpu_flat = grp["cpu_flat"]
-
-                for i, mgd_idx in enumerate(indices):
-                    local = self._local(self._managed[mgd_idx])
-                    off, n = offsets[i]
-                    cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
-
-                offloaded_bytes += grp["total"] * cpu_flat.element_size()
-
-        # Wait for all D2H copies to land, then free GPU storage.
-        self._offload_stream.synchronize()
-        for t in self._managed:
-            storage = self._local(t).untyped_storage()
-            if storage.size() != 0:
-                storage.resize_(0)
-            else:
-                raise RuntimeError(
-                    f"Tensor storage is already freed (size=0) before offload. "
-                    f"This indicates a double-free or external interference. "
-                    f"Tensor shape: {t.shape}, dtype: {t.dtype}"
-                )
-
-        if not self._logged:
-            logger.info(
-                "[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
-                offloaded_bytes / (1024**2),
-            )
-
-    # ------------------------------------------------------------------
-    def reload(self):
-        """Per-tensor H2D from CPU flat buffer on the default stream.
-
-        Runs on the current (default) CUDA stream to avoid stream
-        interaction issues with the parallel Muon pipeline. Since
-        pinned CPU memory is the source, the copies overlap with
-        GPU idle time between steps.
-        """
-        if not self._managed or not self._initialized:
-            return
-
-        reloaded_bytes = 0
-
-        # Re-allocate all GPU storages first.
-        for t in self._managed:
-            local = self._local(t)
-            storage = local.untyped_storage()
-            if storage.size() != 0:
-                raise RuntimeError(
-                    f"Storage should have been freed (size=0) before reload, "
-                    f"but got size={storage.size()}. "
-                    f"Tensor shape: {t.shape}, dtype: {t.dtype}"
-                )
-            storage.resize_(self._storage_nbytes[id(t)])
-
-        # Per-tensor H2D copies from CPU flat buffer slices.
-        # non_blocking=True with pinned source allows DMA overlap.
-        for dtype, grp in self._groups.items():
-            indices = grp["indices"]
-            offsets = grp["offsets"]
-            cpu_flat = grp["cpu_flat"]
-
-            for i, mgd_idx in enumerate(indices):
-                local = self._local(self._managed[mgd_idx])
-                off, n = offsets[i]
-                local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
-
-            reloaded_bytes += grp["total"] * cpu_flat.element_size()
-
-        if not self._logged:
-            logger.info(
-                "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
-            )

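A hedged round-trip sketch for `CPUOffloadPool` above (needs a CUDA device). Note that tensor identity survives the round trip; only the storage is released and restored:

```python
import torch

pool = CPUOffloadPool()
momentum = torch.randn(1024, 1024, device="cuda")  # e.g. a momentum buffer
pool.track(momentum)

pool.offload()  # async D2H into the pinned flat buffer, then storage freed
assert momentum.untyped_storage().size() == 0  # same tensor object, no storage

pool.reload()  # storage re-allocated, values copied back H2D
assert momentum.untyped_storage().size() > 0
```
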
build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py DELETED
@@ -1,232 +0,0 @@
1
- import torch
2
- import torch.distributed as dist
3
- from torch.distributed import ProcessGroup
4
- from torch.distributed.device_mesh import DeviceMesh
5
- from torch.distributed.tensor import DTensor
6
- from torch.distributed.tensor.placement_types import (Placement, Shard,
7
- _StridedShard)
8
-
9
-
10
- def _is_shard(placement: Placement) -> bool:
11
- """Check if a placement is a shard type (Shard or _StridedShard).
12
-
13
- In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so
14
- ``placement.is_shard()`` returns False for _StridedShard. This helper
15
- handles both old and new hierarchies.
16
- """
17
- return isinstance(placement, (Shard, _StridedShard))
18
-
19
-
20
- def get_slices_of_dtensor(
21
- target: DTensor | torch.Tensor,
22
- local_rank: int,
23
- shard_mesh: DeviceMesh,
24
- shard_placements: tuple[Placement],
25
- ) -> tuple[slice | torch.Tensor, ...]:
26
- """
27
- Get per-dimension indices for a given rank's shard of the target tensor.
28
-
29
- Uses ``Shard.local_shard_size_and_offset`` and
30
- ``_StridedShard.local_shard_size_and_offset`` for correct handling of
31
- both contiguous and strided (non-contiguous) sharding.
32
-
33
- Args:
34
- target (DTensor | torch.Tensor): The target tensor (for its shape).
35
- local_rank (int): The local rank within the shard group.
36
- shard_mesh (DeviceMesh): The shard mesh (only shard dimensions).
37
- shard_placements (tuple[Placement]): The shard placements.
38
-
39
- Returns:
40
- A tuple of indices (one per tensor dim). Each element is either:
41
- - A ``slice`` (for contiguous or unsharded dims)
42
- - A 1-D ``torch.LongTensor`` of indices (for strided sharding)
43
- """
44
-
45
- # find the global rank of the local rank in the shard mesh
46
- rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
47
-
48
- rank_coords = (shard_mesh.mesh == rank).nonzero()
49
-
50
- assert len(rank_coords) == 1
51
- rank_coords = tuple(rank_coords[0].tolist())
52
-
53
- assert len(rank_coords) == len(shard_placements)
54
-
55
- # Track per-shard-dim indices.
56
- # None means "not yet sharded on this dim".
57
- dim_indices: dict[int, torch.Tensor] = {}
58
-
59
- # Caution: Assuming replicate-to-shard of the shard mesh goes with
60
- # left-to-right sharding. This is ensured by the sorting logic of
61
- # construct_shard_mesh function.
62
- for mesh_dim_idx, (rank_coord, placement) in enumerate(
63
- zip(rank_coords, shard_placements)):
64
- assert _is_shard(placement)
65
-
66
- num_chunks = shard_mesh.mesh.shape[mesh_dim_idx]
67
- shard_dim = placement.dim
68
-
69
- # Current effective size on this dim (may already be sub-sharded)
70
- if shard_dim in dim_indices:
71
- curr_size = len(dim_indices[shard_dim])
72
- else:
73
- curr_size = target.size()[shard_dim]
74
-
75
- # Compute indices for this level of sharding
76
- if isinstance(placement, _StridedShard):
77
- _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
78
- placement,
79
- curr_size,
80
- num_chunks,
81
- rank_coord,
82
- return_first_offset=False)
83
- new_indices = torch.tensor(offsets, dtype=torch.long)
84
- else:
85
- shard_size, offset = Shard.local_shard_size_and_offset(
86
- curr_size, num_chunks, rank_coord)
87
- new_indices = torch.arange(offset,
88
- offset + shard_size,
89
- dtype=torch.long)
90
-
91
- # Compose with previous indices on this dim
92
- if shard_dim in dim_indices:
93
- dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]
94
- else:
95
- dim_indices[shard_dim] = new_indices
96
-
97
- # Build result tuple
98
- result: list[slice | torch.Tensor] = []
99
- for d in range(len(target.size())):
100
- if d not in dim_indices:
101
- result.append(slice(None))
102
- else:
103
- indices = dim_indices[d]
104
- # Convert contiguous indices to slice for efficiency
105
- if len(indices) > 0:
106
- start = indices[0].item()
107
- expected = torch.arange(start,
108
- start + len(indices),
109
- dtype=torch.long)
110
- if torch.equal(indices, expected):
111
- result.append(slice(start, start + len(indices)))
112
- else:
113
- result.append(indices)
114
- else:
115
- result.append(slice(0, 0))
116
-
117
- return tuple(result)
118
-
119
-
120
- _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
121
- ProcessGroup]] = dict()
122
-
123
-
124
- def construct_shard_mesh(
125
- placements: tuple[Placement],
126
- mesh: DeviceMesh,
127
- ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
128
- """Construct shard sub-mesh and ProcessGroup for all-to-all communication.
129
-
130
- Given a DTensor's placements and device mesh, extracts the "shard group"
131
- — the set of ranks that together hold all shards of the same replica —
132
- and creates a ProcessGroup for all-to-all among them.
133
-
134
- Steps:
135
- 1. Sort placements: Replicate first, then Shard by (dim, granularity).
136
- 2. Permute the mesh tensor to match the sorted order.
137
- 3. Collapse Replicate dims → list of shard sub-meshes (one per replica).
138
- 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh.
139
-
140
- Example — 8 GPUs, mesh shape (2, 2, 2),
141
- placements ``[Shard(0), Replicate, _StridedShard(0)]``::
142
-
143
- Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)]
144
- Permutation: [1, 2, 0]
145
-
146
- Step 2 — Permute mesh dims by [1, 2, 0]:
147
- Original: Permuted:
148
- [[[0,1],[2,3]], [[[0,2],[1,3]],
149
- [[4,5],[6,7]]] [[4,6],[5,7]]]
150
-
151
- Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
152
- sub-mesh 0 = [[0,2],[1,3]] (replica group 0)
153
- sub-mesh 1 = [[4,6],[5,7]] (replica group 1)
154
- shard_placements = (_StridedShard(0), Shard(0))
155
-
156
- Step 4 — Rank 0 → ProcessGroup([0,1,4,5])
157
- Rank 2 → ProcessGroup([2,3,6,7])
158
-
159
- Returns:
160
- ``(shard_mesh, process_group, shard_placements)``
161
- """
162
- my_rank = dist.get_rank()
163
- assert mesh.mesh.device.type == 'cpu'
164
-
165
- # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
- # overhead of dist.new_group(). The standard path below also handles
168
- # subset calls safely via use_local_synchronization=True, but this
169
- # fast path is still beneficial for the common 1D shard case.
170
- if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
- key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
- if key not in _ranks_to_dist_cache:
173
- _ranks_to_dist_cache[key] = (mesh, mesh.get_group())
174
- return (*_ranks_to_dist_cache[key], tuple(placements))
175
-
176
- mesh_tensor = mesh.mesh.clone()
177
-
178
- # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------
179
- # _StridedShard comes BEFORE regular Shard on the same dim so that
180
- # get_slices_of_dtensor applies the outer sharding first, matching
181
- # DTensor's left-to-right (outer-to-inner) composition order.
182
- def _sort_key(item):
183
- index, placement = item
184
- assert not placement.is_partial(), "Partial placement not supported"
185
- if placement.is_replicate():
186
- return (-1, 0, index)
187
- assert _is_shard(placement), f"Unsupported: {type(placement)}"
188
- split = (-1 / placement.split_factor if isinstance(
189
- placement, _StridedShard) else 0)
190
- return (placement.dim, split, index)
191
-
192
- indexed = sorted(enumerate(placements), key=_sort_key)
193
- perm, sorted_placements = zip(*indexed)
194
-
195
- # -- Step 2: Permute mesh to match sorted placement order. --------------
196
- sorted_mesh = mesh_tensor.permute(perm)
197
-
198
- # -- Step 3: Collapse replicate dims → list of shard sub-meshes. --------
199
- # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4)
200
- num_rep = sum(1 for p in sorted_placements if p.is_replicate())
201
- if num_rep > 0:
202
- if num_rep > 1:
203
- sorted_mesh = sorted_mesh.flatten(0, num_rep - 1)
204
- shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
205
- else:
206
- shard_meshes = [sorted_mesh]
207
- shard_placements = sorted_placements[num_rep:]
208
- assert len(shard_placements) == len(set(shard_placements))
209
-
210
- # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
- # Each rank only creates the group it belongs to, using
212
- # use_local_synchronization=True so that only group members need to
213
- # coordinate. This avoids deadlocks when different PP stages call
214
- # construct_shard_mesh for different parameters.
215
- def _cache_key(t: torch.Tensor) -> tuple:
216
- return (*t.shape, *t.flatten().tolist())
217
-
218
- my_key = None
219
- for sm in shard_meshes:
220
- if (my_rank == sm).any().item():
221
- key = _cache_key(sm)
222
- assert my_key is None, "Rank appears in multiple shard groups"
223
- my_key = key
224
- if key not in _ranks_to_dist_cache:
225
- pg = dist.new_group(sm.flatten().tolist(),
226
- use_local_synchronization=True)
227
- _ranks_to_dist_cache[key] = (
228
- DeviceMesh(device_type="cuda", mesh=sm),
229
- pg,
230
- )
231
-
232
- return (*_ranks_to_dist_cache[my_key], shard_placements)
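A sanity check on the sort → permute → unbind sequence above: the docstring's 8-GPU example can be reproduced with plain tensor ops (a standalone sketch, no DTensor machinery involved):

import torch

# Mesh of 8 ranks with shape (2, 2, 2), as in the construct_shard_mesh docstring.
mesh = torch.arange(8).reshape(2, 2, 2)

# placements [Shard(0), Replicate, _StridedShard(0)] sort to
# [Replicate, _StridedShard(0), Shard(0)], i.e. permutation (1, 2, 0).
sorted_mesh = mesh.permute(1, 2, 0)

# Unbind the single replicate dim -> one shard sub-mesh per replica group.
shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
print(shard_meshes[0])  # tensor([[0, 4], [1, 5]])
print(shard_meshes[1])  # tensor([[2, 6], [3, 7]])

# Rank 0's all-to-all group is its flattened sub-mesh: ranks {0, 1, 4, 5}.
print(sorted(shard_meshes[0].flatten().tolist()))  # [0, 1, 4, 5]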
 
 
 
 
build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py DELETED
@@ -1,122 +0,0 @@
1
- # MIT License
2
- #
3
- # Copyright (c) 2025 Tianyang Lin
4
- #
5
- # Permission is hereby granted, free of charge, to any person obtaining a copy
6
- # of this software and associated documentation files (the "Software"), to deal
7
- # in the Software without restriction, including without limitation the rights
8
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- # copies of the Software, and to permit persons to whom the Software is
10
- # furnished to do so, subject to the following conditions:
11
- #
12
- # The above copyright notice and this permission notice shall be included in all
13
- # copies or substantial portions of the Software.
14
- #
15
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- # SOFTWARE.
22
-
23
- import torch
24
- import triton
25
- import triton.language as tl
26
-
27
-
28
- def get_autotune_config():
29
- return [
30
- triton.Config(
31
- {
32
- 'BLOCK_SIZE_M': blk_m,
33
- 'BLOCK_SIZE_K': blk_k,
34
- 'GROUP_SIZE_M': grp_sz
35
- },
36
- num_stages=n_stages,
37
- num_warps=n_warps) for blk_m in [32, 64, 128]
38
- for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
- for n_warps in [4, 8]
40
- ]
41
-
42
-
43
- @triton.autotune(
44
- configs=get_autotune_config(),
45
- key=['M', 'K'],
46
- restore_value=['y'],
47
- )
48
- @triton.jit
49
- def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
50
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
51
- GROUP_SIZE_M: tl.constexpr):
52
- """
53
- Core kernel jit function of matmul_transpose that computes y = x @ x.T
54
- The code is a simple adaptation from the triton `matmul` tutorial:
55
- https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
56
- """
57
- pid = tl.program_id(axis=0)
58
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
59
- num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
60
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
61
- group_id = pid // num_pid_in_group
62
- first_pid_m = group_id * GROUP_SIZE_M
63
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
64
- pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
65
- pid_n = (pid % num_pid_in_group) // group_size_m
66
- if pid_m > pid_n:
67
- return
68
-
69
- offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
- offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
71
- offs_k = tl.arange(0, BLOCK_SIZE_K)
72
- # we use a & b ptrs to denote different rows of x.
73
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
- b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
75
-
76
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
77
-
78
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
79
- a = tl.load(a_ptrs,
80
- mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
81
- other=0.0)
82
- b = tl.load(b_ptrs,
83
- mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
84
- other=0.0)
85
- accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
86
- a_ptrs += BLOCK_SIZE_K * stride_xk
87
- b_ptrs += BLOCK_SIZE_K * stride_xk
88
- # use dtype.element_ty to accommodate different input datatypes as in cpp templates
89
- # https://github.com/triton-lang/triton/issues/2252
90
- c = accumulator.to(x.dtype.element_ty)
91
-
92
- offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
- offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
94
- c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
95
- c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
96
- tl.store(c_ptrs, c, mask=c_mask)
97
-
98
- # transpose and copy
99
- if pid_m < pid_n:
100
- ct_ptrs = y + stride_ym * offs_cn[:,
101
- None] + stride_yn * offs_cm[None, :]
102
- ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
103
- tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
104
-
105
-
106
- @torch.library.custom_op("muon::matmul_transpose_assign",
107
- mutates_args=("d_out", ))
108
- def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
109
- """Compute d_out = d_in @ d_in.T using an optimized Triton kernel."""
110
- d_in = d_in.contiguous()
111
- M, K = d_in.shape
112
- grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
113
- M, META['BLOCK_SIZE_M']), )
114
- with torch.cuda.device(d_in.device.index):
115
- mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
116
- d_out.stride(0), d_out.stride(1))
117
-
118
-
119
- @matmul_transpose_assign.register_fake
120
- def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
121
- """FakeTensor impl: d_out is already allocated, mutation is declared."""
122
- pass
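For reference, a minimal usage sketch of the kernel above (assumes a CUDA device; inside the package the function is imported as `from .matmul_transpose_triton import matmul_transpose_assign`):

import torch
from matmul_transpose_triton import matmul_transpose_assign  # in-situ: relative import

x = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
y = torch.empty(512, 512, device="cuda", dtype=torch.bfloat16)

# y <- x @ x.T; only blocks with pid_m <= pid_n are computed, the rest mirrored.
matmul_transpose_assign(x, y)

# Loose tolerances: bf16 rounding plus nondeterministic accumulation order.
torch.testing.assert_close(y, x @ x.T, rtol=2e-2, atol=0.5)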
 
 
 
 
build/torch210-cxx11-cu126-x86_64-linux/metadata.json DELETED
@@ -1,3 +0,0 @@
1
- {
2
- "python-depends": []
3
- }
 
 
 
 
build/torch210-cxx11-cu126-x86_64-linux/muon.py DELETED
@@ -1,1068 +0,0 @@
1
- import logging
2
- import types
3
- from collections import defaultdict
4
- from typing import Any
5
-
6
- import torch
7
- import torch.distributed as dist
8
- from torch.distributed.tensor import DTensor, Replicate, Shard
9
- from torch.profiler import record_function
10
-
11
- from .adamw import _placement_cache, _tensor_cache, step_adamw
12
- from .async_utils import run_pipeline
13
- from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
- get_default_muon_param_groups, is_expert_param, update_p)
15
- from .cpu_offload import CPUOffloadPool
16
- from .distributed.utils import (_is_shard, construct_shard_mesh,
17
- get_slices_of_dtensor)
18
- from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
19
- _zeropower_via_newtonschulz5,
20
- zeropower_via_newtonschulz5,
21
- zeropower_via_newtonschulz5_batched)
22
- from .pipeline import muon_chunk_pipeline, prelaunch_first_gather
23
- from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
24
-
25
- logger = logging.getLogger(__name__)
26
-
27
-
28
- def _expand_expert_params(names, params, expert_keys):
29
- """Expand expert params by splitting on dim 0 (expert dimension).
30
-
31
- Params whose name matches any key in ``expert_keys`` are treated as
32
- expert-parallel tensors. Their outermost dimension is the expert
33
- dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D
34
- ``nn.Parameter`` views so that in-place updates propagate back to
35
- the original storage.
36
-
37
- Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` —
38
- if they are expert params, their key must be added to ``expert_keys``.
39
-
40
- The grad must already be set on each expert param (e.g. after momentum).
41
-
42
- For DTensor expert params, placements that shard on dim 0 (expert dim)
43
- are consumed by the split. Non-dim-0 shard placements (e.g. TP) are
44
- preserved: each 2D slice is wrapped as a DTensor on the corresponding
45
- submesh so the parallel pipeline handles the TP communication.
46
- """
47
- expanded_names = []
48
- expanded_params = []
49
-
50
- for n, p in zip(names, params):
51
- is_expert = is_expert_param(n, expert_keys)
52
- is_dtensor = isinstance(p.data, DTensor)
53
-
54
- if is_expert:
55
- if is_dtensor:
56
- logger.debug(
57
- "[expand_expert] %s: expert DTensor, shape=%s, "
58
- "placements=%s, mesh=%s, local_shape=%s", n, p.shape,
59
- p.placements, p.device_mesh.mesh_dim_names,
60
- p.to_local().shape)
61
- else:
62
- logger.debug(
63
- "[expand_expert] %s: expert plain tensor, shape=%s", n,
64
- p.data.shape)
65
-
66
- if not is_expert:
67
- assert p.data.ndim <= 2, (
68
- f"Param {n} has ndim={p.data.ndim} but does not match "
69
- f"expert_keys={expert_keys}. If this is an expert param, "
70
- f"add its key to expert_keys.")
71
- expanded_names.append(n)
72
- expanded_params.append(p)
73
- continue
74
-
75
- g = p.grad
76
- assert g is not None, (
77
- f"Expert param {n} must have grad set before expansion")
78
-
79
- tp_mesh = None
80
- tp_placements_2d = None
81
-
82
- if is_dtensor:
83
- local_data = p.to_local()
84
- local_grad = g.to_local() if isinstance(g, DTensor) else g
85
-
86
- # Find non-dim-0 shard placements (e.g. TP sharding).
87
- # After splitting on dim 0, Shard(k) becomes Shard(k-1).
88
- tp_dim_indices = []
89
- tp_placements_2d = []
90
- for i, pl in enumerate(p.placements):
91
- if _is_shard(pl) and pl.dim != 0:
92
- tp_dim_indices.append(i)
93
- tp_placements_2d.append(Shard(pl.dim - 1))
94
-
95
- if tp_dim_indices:
96
- tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i]
97
- for i in tp_dim_indices)
98
- if len(tp_dim_names) == 1:
99
- tp_mesh = p.device_mesh[tp_dim_names[0]]
100
- else:
101
- tp_mesh = p.device_mesh[tp_dim_names]
102
- else:
103
- local_data = p.data
104
- local_grad = g
105
-
106
- # Expand: split dim 0, reshape each slice to 2D.
107
- num_local_experts = local_data.shape[0]
108
- for i in range(num_local_experts):
109
- slice_data = local_data[i]
110
- slice_grad = local_grad[i]
111
-
112
- if tp_mesh is not None:
113
- # Wrap as DTensor on TP submesh so the pipeline handles
114
- # TP communication (gather/scatter across TP ranks).
115
- dt_data = DTensor.from_local(slice_data,
116
- device_mesh=tp_mesh,
117
- placements=tp_placements_2d)
118
- dt_grad = DTensor.from_local(slice_grad,
119
- device_mesh=tp_mesh,
120
- placements=tp_placements_2d)
121
- expert_param = torch.nn.Parameter(dt_data, requires_grad=False)
122
- expert_param.grad = dt_grad
123
- else:
124
- expert_param = torch.nn.Parameter(slice_data,
125
- requires_grad=False)
126
- expert_param.grad = slice_grad
127
-
128
- expanded_names.append(f"{n}[{i}]")
129
- expanded_params.append(expert_param)
130
-
131
- p.grad = None # allow expert grad storage to be freed after pipeline
132
-
133
- return expanded_names, expanded_params
134
-
135
-
136
- class Muon(torch.optim.Optimizer):
137
- """
138
- Muon - MomentUm Orthogonalized by Newton-schulz
139
-
140
- Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
141
- processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
142
- matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
143
- the advantage that it can be stably run in bfloat16 on the GPU.
144
-
145
- Some warnings:
146
- - We believe this optimizer is unlikely to work well for training with small batch size.
147
- - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
148
-
149
- Arguments:
150
-         params: Parameter groups to optimize. Each group must set the
151
-             "use_muon" key (see `get_default_muon_param_groups`).
152
- lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
153
- momentum: The momentum used by the internal SGD. (0.95 is a good default)
154
- nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
155
- ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
156
- weight_decay: The weight decay for Muon and AdamW.
157
- Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead.
158
- adamw_lr: The learning rate for the internal AdamW.
159
- adamw_betas: The betas for the internal AdamW.
160
- adamw_eps: The epsilon for the internal AdamW.
161
- none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
162
- debug: Whether to print debug information.
163
- clip_info : Configuration for QK clipping. Expected keys:
164
- - "q_indices" (list[int]): Indices of query heads to consider.
165
- - "k_indices" (list[int]): Indices of key heads to consider.
166
- - "head_dim" (int): Dimensionality of each attention head.
167
- - "threshold" (float): Threshold value; heads whose QK logits exceed
168
- this value will be scaled down.
169
- Default is:
170
- {
171
- "q_indices": [],
172
- "k_indices": [],
173
- "head_dim": 128,
174
- "threshold": 100
175
- }
176
-         warmup_step : How many all2all gather/compute operations are launched in advance
177
- before the corresponding all2all scatter steps begin.
178
- A higher warmup_step increases memory usage but can improve
179
- performance by overlapping communication.
180
- Parallel muon only.
181
- chunk_size : Batch size of parameters to process in each
182
- all2all gather/compute/scatter step.
183
- Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
184
- use_distributed_muon: Use distributed muon by Liu et al. (2024).
185
- For testing purpose only.
186
- expert_keys: List of strings to identify expert-parallel parameters.
187
- If any key appears in a parameter's name, its outermost
188
- dimension is treated as the expert dimension and expanded
189
- into per-expert 2D params for Muon. For example,
190
- ``expert_keys=["experts"]`` matches any param whose name
191
- contains "experts". 3D+ params not matched by any key
192
- will raise an error.
193
- """
194
-
195
- def __init__(self,
196
- params,
197
- lr=1e-3,
198
- momentum=0.95,
199
- nesterov=True,
200
- ns_steps=5,
201
- weight_decay=0.1,
202
- adamw_betas=(0.9, 0.95),
203
- adamw_eps=1e-8,
204
- none_grad=True,
205
- debug=False,
206
- clip_config=None,
207
- warmup_step=5,
208
- chunk_size=-1,
209
- use_distributed_muon=False,
210
- expert_keys=None):
211
- defaults = dict(
212
- lr=lr,
213
- weight_decay=weight_decay,
214
- momentum=momentum,
215
- nesterov=nesterov,
216
- ns_steps=ns_steps,
217
- adamw_betas=adamw_betas,
218
- adamw_eps=adamw_eps,
219
- none_grad=none_grad,
220
- use_muon=True,
221
- )
222
- error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
223
- instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
224
-
225
- if isinstance(params, types.GeneratorType):
226
- raise ValueError(error_message.format(idx=0) + instruction_code)
227
- for _idx, param_group in enumerate(params):
228
- if param_group.get("use_muon", None) is None:
229
- raise ValueError(
230
- error_message.format(idx=_idx) + instruction_code)
231
- super().__init__(params, defaults)
232
-
233
- self.debug = debug
234
- self.clip_config = clip_config if clip_config is not None else {
235
- "q_indices": [],
236
- "k_indices": [],
237
- "head_dim": 128,
238
- "threshold": 100,
239
- }
240
- self.warmup_step = warmup_step
241
- self.chunk_size = chunk_size
242
- self.use_distributed_muon = use_distributed_muon
243
- self.expert_keys = expert_keys
244
- self.cpu_offload = False
245
- self._cpu_offload_pool: CPUOffloadPool | None = None
246
- self._offload_initialized = False
247
- self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
- self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
249
-
250
- def _calc_flops(self, G, steps):
251
- assert len(G.shape) == 2
252
- M, N = G.shape
253
- if M > N:
254
- M, N = N, M
255
-
256
- return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
257
-
258
- def get_shard_mesh(self, p):
259
- """
260
- Get the shard mesh for a parameter p on the given rank.
261
- """
262
- assert isinstance(
263
- p, DTensor), "Parallel Muon only supports DTensor parameters."
264
-
265
- shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
266
- p.placements, p.device_mesh)
267
-
268
- return shard_mesh, shard_pg, shard_placements
269
-
270
- def init_state_and_assign_params(self, names, params, group, qk_logits):
271
- param_to_state = {}
272
- param_to_flops = {}
273
-
274
- total_flops = 0
275
- for p in params:
276
- g = p.grad
277
- if g is None:
278
- continue
279
- assert g.ndim == 2, "Muon only supports 2D parameters."
280
-
281
- flops = self._calc_flops(g, group["ns_steps"])
282
- param_to_flops[id(p)] = flops
283
- total_flops += flops
284
-
285
- if self.debug:
286
- logger.debug("Total TFLOPs for Muon: %.2f TFLOPs",
287
- total_flops / 1e12)
288
-
289
- paired = list(zip(names, params))
290
-
291
- paired_sorted = sorted(paired,
292
- key=lambda x: param_to_flops[id(x[1])],
293
- reverse=True)
294
-
295
- names_sorted, params_sorted = zip(*paired_sorted)
296
- ordered_names = list(names_sorted)
297
- ordered_params = list(params_sorted)
298
-
299
- round_robin = 0
300
- mesh = ordered_params[0].device_mesh
301
- placements = ordered_params[0].placements
302
-
303
- shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
304
- ordered_params[0])
305
- shard_mesh_flattened = shard_mesh.mesh.flatten()
306
- num_ranks = dist.get_world_size(group=shard_pg)
307
-
308
- for n, p in zip(ordered_names, ordered_params):
309
- if mesh != p.device_mesh:
310
- raise ValueError("All parameters must be on the same mesh.")
311
- if placements != p.placements:
312
- raise ValueError("All parameters must have same placements.")
313
-
314
- worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
315
- round_robin = (round_robin + 1) % len(shard_mesh_flattened)
316
- qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
317
-
318
- # Precompute per-rank indices and numels for all-to-all.
319
- rank_indices: dict[int, tuple] = {}
320
- rank_numels: dict[int, int] = {}
321
- for r in range(num_ranks):
322
- indices = get_slices_of_dtensor(p, r, shard_mesh,
323
- shard_placements)
324
- rank_indices[r] = indices
325
- numel = 1
326
- for idx, dim_size in zip(indices, p.shape):
327
- if isinstance(idx, slice):
328
- start, stop, step = idx.indices(dim_size)
329
- numel *= max(0, (stop - start + (step - 1)) // step)
330
- else:
331
- numel *= len(idx)
332
- rank_numels[r] = numel
333
-
334
- param_to_state[id(p)] = _muon_state(
335
- worker_rank=worker_rank,
336
- process_group=shard_pg,
337
- rank_indices=rank_indices,
338
- rank_numels=rank_numels,
339
- name=n,
340
- qk_clip_state=qk_clip_state,
341
- )
342
-
343
- return param_to_state, ordered_params
344
-
345
- def base(self, names, params, group, lr, weight_decay, qk_logits):
346
- # Momentum is already applied by _step_muon before this method.
347
- for n, p in zip(names, params):
348
- g = p.grad
349
- if g is None:
350
- continue
351
-
352
- u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
353
- steps=group["ns_steps"])
354
-
355
- adjusted_lr = adjust_lr_for_muon(lr, p.shape)
356
- update_p(p, u, lr, adjusted_lr, weight_decay)
357
-
358
- qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
359
-
360
- scales_full = compute_scales(
361
- p, qk_clip_state) if qk_clip_state is not None else None
362
- if scales_full is not None:
363
- qk_clip(p, scales_full, qk_clip_state)
364
-
365
- def distributed_muon(
366
- self,
367
- names: list[str],
368
- params: list[torch.nn.Parameter],
369
- group: dict[str, Any],
370
- lr: float,
371
- weight_decay: float,
372
- qk_logits: list[torch.Tensor | DTensor] | None,
373
- ):
374
- """Batched Distributed Muon — for testing/correctness verification only.
375
-
376
- Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
377
- the full grad, then slices back to local shards. This is simpler but
378
- slower than the parallel pipeline (all2all) path, so it serves as a
379
- reference implementation for verifying correctness.
380
- """
381
- with record_function("distributed_muon"):
382
- # Momentum is already applied by _step_muon before this method.
383
- ns_steps = group["ns_steps"]
384
-
385
- # Separate plain tensors (no communication) from DTensors.
386
- plain_names, plain_params = [], []
387
- dtensor_names, dtensor_params = [], []
388
- for n, p in zip(names, params):
389
- if p.grad is None:
390
- continue
391
- if isinstance(p.data, DTensor):
392
- dtensor_names.append(n)
393
- dtensor_params.append(p)
394
- else:
395
- plain_names.append(n)
396
- plain_params.append(p)
397
-
398
- # Process plain tensors per-param (no communication).
399
- for n, p in zip(plain_names, plain_params):
400
- u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
401
- steps=ns_steps)
402
- adjusted_lr = adjust_lr_for_muon(lr, p.shape)
403
- update_p(p, u, lr, adjusted_lr, weight_decay)
404
-
405
- qk_clip_state = get_qk_clip_info(self.clip_config, n,
406
- qk_logits)
407
- scales_full = compute_scales(
408
- p, qk_clip_state) if qk_clip_state is not None else None
409
- if scales_full is not None:
410
- qk_clip(p, scales_full, qk_clip_state)
411
-
412
- if not dtensor_params:
413
- return
414
-
415
- # Group DTensors by (placements, mesh) for batched all-gather.
416
- placement_groups: dict[tuple,
417
- tuple[list,
418
- list]] = defaultdict(lambda: ([], []))
419
- for n, p in zip(dtensor_names, dtensor_params):
420
- key = (p.placements, p.device_mesh)
421
- placement_groups[key][0].append(n)
422
- placement_groups[key][1].append(p)
423
-
424
- logger.info(
425
- "distributed_muon: %d placement groups, %d total dtensors",
426
- len(placement_groups), len(dtensor_params))
427
-
428
- for (placements, mesh), (grp_names,
429
- grp_params) in placement_groups.items():
430
- shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
431
- placements, mesh)
432
- rank = dist.get_rank(shard_pg)
433
- world_size = dist.get_world_size(shard_pg)
434
-
435
- logger.info(" group: %d params, placements=%s, world_size=%d",
436
- len(grp_params), placements, world_size)
437
-
438
- # Separate params that can be batched (all shard dims evenly
439
- # divisible) from those needing per-param full_tensor
440
- # (e.g. MoE gate weights with fewer rows than shard ranks).
441
- # all_gather_into_tensor requires equal buffer sizes across
442
- # ranks, so uneven splits must use DTensor full_tensor().
443
- batch_names, batch_params = [], []
444
- single_names, single_params = [], []
445
- for n, p in zip(grp_names, grp_params):
446
- even = all(p.shape[pl.dim] %
447
- shard_mesh.mesh.shape[dim_idx] == 0
448
- for dim_idx, pl in enumerate(shard_placements))
449
- if even:
450
- batch_names.append(n)
451
- batch_params.append(p)
452
- else:
453
- single_names.append(n)
454
- single_params.append(p)
455
-
456
- # Process uneven-split params per-param via full_tensor().
457
- for n, p in zip(single_names, single_params):
458
- with record_function("distributed_muon::newton_schulz"):
459
- g_full = p.grad.full_tensor().to(COMM_DTYPE)
460
- u_full = _zeropower_via_newtonschulz5(g_full,
461
- steps=ns_steps)
462
- del g_full
463
- with record_function("distributed_muon::update"):
464
- adjusted_lr = adjust_lr_for_muon(lr, p.shape)
465
- p._local_tensor.mul_(1 - lr * weight_decay)
466
- local_indices = get_slices_of_dtensor(
467
- p, rank, shard_mesh, shard_placements)
468
- u_local = u_full[local_indices]
469
- p._local_tensor.add_(u_local, alpha=-adjusted_lr)
470
- del u_full
471
-
472
- qk_clip_state = get_qk_clip_info(
473
- self.clip_config, n, qk_logits)
474
- scales_full = compute_scales(
475
- p, qk_clip_state
476
- ) if qk_clip_state is not None else None
477
- if scales_full is not None:
478
- ratio = p.shape[0] // scales_full.shape[0]
479
- idx0 = local_indices[0]
480
- if isinstance(idx0, slice):
481
- start = idx0.start or 0
482
- idx0 = torch.arange(start,
483
- idx0.stop,
484
- device=scales_full.device)
485
- row_scales = scales_full[idx0 // ratio]
486
- p._local_tensor.mul_(row_scales.view(-1, 1))
487
-
488
- if not batch_params:
489
- continue
490
-
491
- logger.info(" batched=%d, single=%d", len(batch_params),
492
- len(single_params))
493
-
494
- # Concat all local grad shards into a single flat buffer.
495
- with record_function("distributed_muon::gather"):
496
- grad_locals = [
497
- p.grad.to_local().to(COMM_DTYPE).flatten()
498
- for p in batch_params
499
- ]
500
- numels = [g.numel() for g in grad_locals]
501
- grad_concat = torch.cat(grad_locals)
502
- del grad_locals
503
-
504
- # Single all-gather (replaces N separate full_tensor).
505
- grad_gathered = torch.empty(
506
- grad_concat.numel() * world_size,
507
- dtype=COMM_DTYPE,
508
- device="cuda",
509
- )
510
- dist.all_gather_into_tensor(grad_gathered,
511
- grad_concat,
512
- group=shard_pg)
513
-
514
- total_numel = grad_concat.numel()
515
- del grad_concat
516
-
517
- # Precompute per-param offsets within the concat buffer.
518
- offsets = []
519
- off = 0
520
- for ne in numels:
521
- offsets.append(off)
522
- off += ne
523
-
524
- # Per-param: reconstruct full grad → NS → local update.
525
- for i, (n, p) in enumerate(zip(batch_names, batch_params)):
526
- with record_function("distributed_muon::newton_schulz"):
527
- g_full = torch.empty(p.shape,
528
- dtype=COMM_DTYPE,
529
- device="cuda")
530
- for r in range(world_size):
531
- r_start = r * total_numel + offsets[i]
532
- shard = grad_gathered[r_start:r_start + numels[i]]
533
- indices = get_slices_of_dtensor(
534
- p, r, shard_mesh, shard_placements)
535
- g_full[indices] = shard.reshape(
536
- g_full[indices].shape)
537
-
538
- u_full = _zeropower_via_newtonschulz5(g_full,
539
- steps=ns_steps)
540
- del g_full
541
-
542
- with record_function("distributed_muon::update"):
543
- adjusted_lr = adjust_lr_for_muon(lr, p.shape)
544
- p._local_tensor.mul_(1 - lr * weight_decay)
545
- local_indices = get_slices_of_dtensor(
546
- p, rank, shard_mesh, shard_placements)
547
- u_local = u_full[local_indices]
548
- p._local_tensor.add_(u_local, alpha=-adjusted_lr)
549
- del u_full
550
-
551
- qk_clip_state = get_qk_clip_info(
552
- self.clip_config, n, qk_logits)
553
- scales_full = compute_scales(
554
- p, qk_clip_state
555
- ) if qk_clip_state is not None else None
556
- if scales_full is not None:
557
- ratio = p.shape[0] // scales_full.shape[0]
558
- idx0 = local_indices[0]
559
- if isinstance(idx0, slice):
560
- start = idx0.start or 0
561
- idx0 = torch.arange(start,
562
- idx0.stop,
563
- device=scales_full.device)
564
- row_scales = scales_full[idx0 // ratio]
565
- p._local_tensor.mul_(row_scales.view(-1, 1))
566
-
567
- def _setup_parallel(self, names, params, group, qk_logits):
568
- """Compute (or retrieve cached) parallel pipeline metadata.
569
-
570
- Returns:
571
- (ordered_params, param_to_state, rank, chunk_size)
572
- """
573
- cache_key = tuple(names)
574
-
575
- if cache_key not in self._parallel_cache:
576
- # First call: compute metadata and populate cache.
577
- param_to_state, ordered_params = self.init_state_and_assign_params(
578
- names, params, group, qk_logits)
579
-
580
- shard_pg = param_to_state[id(ordered_params[0])].process_group
581
- rank = dist.get_rank(group=shard_pg)
582
-
583
- if self.chunk_size == -1:
584
- shard_ranks = dist.get_world_size(shard_pg)
585
- chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
586
- elif self.chunk_size > 0:
587
- chunk_size = self.chunk_size
588
- else:
589
- raise ValueError(
590
- "chunk_size must be -1 or a positive integer.")
591
-
592
- ordered_names = [
593
- param_to_state[id(p)].name for p in ordered_params
594
- ]
595
- name_to_state = {
596
- param_to_state[id(p)].name: param_to_state[id(p)]
597
- for p in ordered_params
598
- }
599
- self._parallel_cache[cache_key] = {
600
- 'ordered_names': ordered_names,
601
- 'name_to_state': name_to_state,
602
- 'rank': rank,
603
- 'chunk_size': chunk_size,
604
- }
605
- else:
606
- # Cached path: rebuild param_to_state with current id(p) keys.
607
- cache = self._parallel_cache[cache_key]
608
- rank = cache['rank']
609
- chunk_size = cache['chunk_size']
610
-
611
- name_to_param = dict(zip(names, params))
612
- ordered_params = [name_to_param[n] for n in cache['ordered_names']]
613
-
614
- param_to_state = {}
615
- for p, n in zip(ordered_params, cache['ordered_names']):
616
- cached_state = cache['name_to_state'][n]
617
- param_to_state[id(p)] = _muon_state(
618
- worker_rank=cached_state.worker_rank,
619
- process_group=cached_state.process_group,
620
- rank_indices=cached_state.rank_indices,
621
- rank_numels=cached_state.rank_numels,
622
- name=n,
623
- qk_clip_state=get_qk_clip_info(self.clip_config, n,
624
- qk_logits),
625
- )
626
-
627
- return ordered_params, param_to_state, rank, chunk_size
628
-
629
- def parallel(self,
630
- names,
631
- params,
632
- group,
633
- lr,
634
- weight_decay,
635
- qk_logits,
636
- prelaunch_gather=None):
637
- """
638
- Perform a parallel optimization step using Muon.
639
-
640
- Parameters are chunked and each chunk is processed by a
641
- :func:`muon_chunk_pipeline` generator. :func:`run_pipeline`
642
- interleaves multiple chunks so that communication and computation
643
- overlap across chunks (the same overlap previously achieved by the
644
- warmup + main-loop index scheduling).
645
-
646
- If ``prelaunch_gather`` is provided, it is passed to the first
647
- chunk's generator to skip re-launching the already in-flight
648
- A2A gather.
649
- """
650
-
651
- # Momentum is already applied by _step_muon before this method.
652
-
653
- ordered_params, param_to_state, rank, chunk_size = (
654
- self._setup_parallel(names, params, group, qk_logits))
655
-
656
- def pipelines():
657
- first = True
658
- for start in range(0, len(ordered_params), chunk_size):
659
- chunk = ordered_params[start:start + chunk_size]
660
- if chunk:
661
- kwargs = dict(
662
- params=chunk,
663
- param_to_state=param_to_state,
664
- rank=rank,
665
- ns_steps=group["ns_steps"],
666
- lr=lr,
667
- weight_decay=weight_decay,
668
- none_grad=group["none_grad"],
669
- )
670
- if first and prelaunch_gather is not None:
671
- kwargs['prelaunch_gather'] = prelaunch_gather
672
- first = False
673
- yield muon_chunk_pipeline(**kwargs)
674
-
675
- with record_function("muon::pipeline"):
676
- run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
677
-
678
- def _step_muon(self, group, qk_logits=None):
679
- params = group["params"]
680
- lr = group["lr"]
681
- weight_decay = group["weight_decay"]
682
- momentum = group["momentum"]
683
- names = group["names"]
684
-
685
- # Apply momentum to all params before routing/expansion.
686
- # Batched using _foreach_* ops (compiled, fullgraph=True).
687
- with record_function("muon::momentum"):
688
- active_params = [p for p in params if p.grad is not None]
689
- if active_params:
690
- # Ensure momentum buffers exist (avoid zeros_like when already present).
691
- for p in active_params:
692
- if "momentum_buffer" not in self.state[p]:
693
- self.state[p]["momentum_buffer"] = torch.zeros_like(
694
- p.grad)
695
-
696
- # Extract local tensors for compiled batch function.
697
- local_grads = [
698
- p.grad._local_tensor
699
- if isinstance(p.grad, DTensor) else p.grad
700
- for p in active_params
701
- ]
702
- local_bufs = [
703
- self.state[p]["momentum_buffer"]._local_tensor
704
- if isinstance(self.state[p]["momentum_buffer"], DTensor)
705
- else self.state[p]["momentum_buffer"]
706
- for p in active_params
707
- ]
708
-
709
- # Wrap momentum as tensor for torch.compile.
710
- batch_pre_ortho(local_grads, local_bufs,
711
- torch.tensor(momentum), group["nesterov"])
712
-
713
- # For non-nesterov, the result is the momentum buffer.
714
- if not group["nesterov"]:
715
- for p in active_params:
716
- p.grad = self.state[p]["momentum_buffer"]
717
-
718
- # Identify batched experts for deferred NS.
719
- # Detection is cheap (condition checks only); actual NS compute is
720
- # deferred so it can overlap with the first chunk's A2A gather.
721
- deferred_expert_work = []
722
- if self.expert_keys:
723
- batched_expert_indices = []
724
- for i, (n, p) in enumerate(zip(names, params)):
725
- if not (is_expert_param(n, self.expert_keys)
726
- and p.grad is not None):
727
- continue
728
- # Eligible: plain tensor, or DTensor with no non-dim-0 shards.
729
- if isinstance(p.data, DTensor):
730
- has_tp = any(
731
- _is_shard(pl) and pl.dim != 0 for pl in p.placements)
732
- if has_tp:
733
- continue
734
- batched_expert_indices.append(i)
735
-
736
- if batched_expert_indices:
737
- # Save refs for deferred NS; free grads from param list.
738
- for i in batched_expert_indices:
739
- p = params[i]
740
- g = p.grad
741
- local_g = (g._local_tensor
742
- if isinstance(g, DTensor) else g)
743
- local_data = (p.data._local_tensor if isinstance(
744
- p.data, DTensor) else p.data)
745
- deferred_expert_work.append((local_data, local_g))
746
- p.grad = None
747
-
748
- # Remove batched experts from lists before expansion.
749
- keep = sorted(
750
- set(range(len(params))) - set(batched_expert_indices))
751
- names = [names[i] for i in keep]
752
- params = [params[i] for i in keep]
753
-
754
- def _run_deferred_expert_ns():
755
- """Execute deferred batched expert NS."""
756
- if not deferred_expert_work:
757
- return
758
- with record_function("muon::batched_expert_ns"):
759
- ns_steps = group["ns_steps"]
760
- for local_data, local_g in deferred_expert_work:
761
- u = zeropower_via_newtonschulz5_batched(
762
- local_g.to(COMM_DTYPE), steps=ns_steps)
763
- adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
764
- local_data.mul_(1 - lr * weight_decay)
765
- local_data.add_(u, alpha=-adjusted_lr)
766
-
767
- # Expand expert params by splitting on dim 0.
768
- logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
769
- len(params), self.expert_keys)
770
- if self.expert_keys:
771
- cache_key = tuple(id(p) for p in params)
772
- cache = self._expert_expand_cache.get(cache_key)
773
-
774
- if cache is None:
775
- # Cold path: full expansion + build cache metadata.
776
- exp_names, exp_params = _expand_expert_params(
777
- names, params, self.expert_keys)
778
-
779
- # Build per-expert-group info for hot-path grad updates.
780
- grad_info = []
781
- exp_idx = 0
782
- for orig_idx, (n, p) in enumerate(zip(names, params)):
783
- if not is_expert_param(n, self.expert_keys):
784
- exp_idx += 1
785
- continue
786
-
787
- is_dt = isinstance(p.data, DTensor)
788
- num_experts = (p.to_local() if is_dt else p.data).shape[0]
789
-
790
- # Detect TP mesh from the first expanded expert param.
791
- tp_mesh = None
792
- tp_pls = None
793
- sample = exp_params[exp_idx]
794
- if isinstance(sample.data, DTensor):
795
- tp_mesh = sample.data.device_mesh
796
- tp_pls = list(sample.data.placements)
797
-
798
- grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
799
- tp_mesh, tp_pls))
800
- exp_idx += num_experts
801
-
802
- self._expert_expand_cache[cache_key] = {
803
- 'names': exp_names,
804
- 'params': exp_params,
805
- 'grad_info': grad_info,
806
- }
807
- names, params = exp_names, exp_params
808
- else:
809
- # Hot path: reuse cached params, only update expert grads.
810
- for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
811
- tp_pls) in cache['grad_info']:
812
- p = params[orig_idx]
813
- g = p.grad
814
- local_grad = (g.to_local()
815
- if is_dt and isinstance(g, DTensor) else g)
816
- for i in range(num_experts):
817
- expert_p = cache['params'][exp_start + i]
818
- sg = local_grad[i]
819
- if tp_mesh is not None:
820
- expert_p.grad = DTensor.from_local(
821
- sg, device_mesh=tp_mesh, placements=tp_pls)
822
- else:
823
- expert_p.grad = sg
824
- p.grad = None
825
-
826
- names = cache['names']
827
- params = cache['params']
828
- else:
829
- names, params = _expand_expert_params(names, params,
830
- self.expert_keys)
831
- logger.debug("[_step_muon] after expand: %d params", len(params))
832
-
833
- param_dtensors = []
834
- name_dtensors = []
835
-
836
- param_tensors = []
837
- name_tensors = []
838
-
839
- # distributed_muon is a reference implementation for testing only.
840
- # The parallel pipeline (all2all) path below is the production path.
841
- if self.use_distributed_muon:
842
- _run_deferred_expert_ns()
843
- self.distributed_muon(names=names,
844
- params=params,
845
- group=group,
846
- lr=lr,
847
- weight_decay=weight_decay,
848
- qk_logits=qk_logits)
849
- return
850
-
851
- for n, p in zip(names, params):
852
- if p is None or p.grad is None:
853
- continue
854
- if isinstance(p.data, DTensor):
855
- if all(
856
- isinstance(placement, Replicate)
857
- for placement in p.placements):
858
- logger.debug(
859
- "[route] %s → base (DTensor all-Replicate), "
860
- "shape=%s, placements=%s", n, p.shape, p.placements)
861
- param_tensors.append(p)
862
- name_tensors.append(n)
863
- else:
864
- logger.debug(
865
- "[route] %s → parallel (DTensor), shape=%s, "
866
- "placements=%s, mesh=%s", n, p.shape, p.placements,
867
- p.device_mesh.mesh_dim_names)
868
- param_dtensors.append(p)
869
- name_dtensors.append(n)
870
- elif isinstance(p.data, torch.Tensor):
871
- logger.debug("[route] %s → base (plain tensor), shape=%s", n,
872
- p.data.shape)
873
- param_tensors.append(p)
874
- name_tensors.append(n)
875
- else:
876
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
877
-
878
- logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, "
879
- f"{len(param_tensors)} Tensors → base")
880
-
881
- def group_dtensors(dtensors, names):
882
- # To support different placements, we group parameters by placements
883
- # and run parallel Muon on each group.
884
-
885
- placement_to_params = defaultdict(lambda: ([], []))
886
-
887
- assert len(dtensors) == len(names)
888
- for p, n in zip(dtensors, names):
889
- placement_to_params[tuple([p.placements,
890
- p.device_mesh])][0].append(n)
891
- placement_to_params[tuple([p.placements,
892
- p.device_mesh])][1].append(p)
893
- return placement_to_params
894
-
895
- if len(param_dtensors) > 0:
896
- if not dist.is_initialized():
897
- raise RuntimeError(
898
- "Parallel Muon requires torch.distributed to be initialized."
899
- )
900
-
901
- dtensor_group = group_dtensors(param_dtensors, name_dtensors)
902
-
903
- # Pre-launch the first chunk's A2A gather so that the NCCL
904
- # communication overlaps with the (deferred) batched expert NS
905
- # compute on the default CUDA stream.
906
- prelaunch = None
907
- if deferred_expert_work:
908
- first_names, first_params = next(iter(dtensor_group.values()))
909
- ordered, pts, rnk, csz = self._setup_parallel(
910
- first_names, first_params, group, qk_logits)
911
- first_chunk = ordered[:csz]
912
- if first_chunk:
913
- prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
914
- group["none_grad"])
915
-
916
- _run_deferred_expert_ns()
917
-
918
- first_group = True
919
- for _, (names, params) in dtensor_group.items():
920
- pg = prelaunch if first_group else None
921
- first_group = False
922
- self.parallel(
923
- names,
924
- params,
925
- group,
926
- lr=lr,
927
- weight_decay=weight_decay,
928
- qk_logits=qk_logits,
929
- prelaunch_gather=pg,
930
- )
931
- else:
932
- _run_deferred_expert_ns()
933
-
934
- if len(param_tensors) > 0:
935
- self.base(
936
- name_tensors,
937
- param_tensors,
938
- group,
939
- lr=lr,
940
- weight_decay=weight_decay,
941
- qk_logits=qk_logits,
942
- )
943
-
944
- def _register_states_for_offload(self):
945
- """Register all optimizer state tensors with the CPU offload pool.
946
-
947
- Called once after the first step when states have been lazily created.
948
- Offloads all param states (momentum buffers for Muon, moment1/moment2
949
- for AdamW) to free GPU memory between steps.
950
- """
951
- pool = self._cpu_offload_pool
952
- tracked = 0
953
- for group in self.param_groups:
954
- for p in group["params"]:
955
- if p not in self.state:
956
- continue
957
- state = self.state[p]
958
- if group.get("use_muon", False):
959
- if "momentum_buffer" in state:
960
- pool.track(state["momentum_buffer"])
961
- tracked += 1
962
- else:
963
- if "moment1" in state:
964
- pool.track(state["moment1"])
965
- if "moment2" in state:
966
- pool.track(state["moment2"])
967
- tracked += 1
968
- logger.info("[CPUOffload] Registered %d param states for offload",
969
- tracked)
970
-
971
- @torch.no_grad
972
- def step(self, closure=None, qk_logits=None):
973
- """Perform a single optimization step.
974
-
975
- Args:
976
- closure (Callable, optional): A closure that reevaluates the model
977
- and returns the loss.
978
- qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
979
- to 1D tensors of shape (num_heads,), representing the maximum
980
- QK logits across all tokens, computed as
981
- (1 / sqrt(head_dim)) * (Q @ K^T).
982
- """
983
- loss = None
984
- if closure is not None:
985
- with torch.enable_grad():
986
- loss = closure()
987
-
988
- # H2D: reload optimizer states from CPU before computation.
989
- if self.cpu_offload and self._offload_initialized:
990
- self._cpu_offload_pool.reload()
991
-
992
- logger.debug("[Muon.step] expert_keys=%s, %d param groups",
993
- self.expert_keys, len(self.param_groups))
994
-
995
- for i, group in enumerate(self.param_groups):
996
- if group["use_muon"]:
997
- logger.debug("[Muon.step] group %d: use_muon=True, %d params",
998
- i, len(group["params"]))
999
- self._step_muon(group, qk_logits=qk_logits)
1000
- else:
1001
- logger.debug(
1002
- "[Muon.step] group %d: use_muon=False (AdamW), %d params",
1003
- i, len(group["params"]))
1004
- step_adamw(self.state, group)
1005
-
1006
- # D2H: offload optimizer states to CPU after computation.
1007
- if self.cpu_offload:
1008
- if not self._offload_initialized:
1009
- if self._cpu_offload_pool is None:
1010
- self._cpu_offload_pool = CPUOffloadPool()
1011
- self._register_states_for_offload()
1012
- self._offload_initialized = True
1013
- self._cpu_offload_pool.offload()
1014
-
1015
- return loss
1016
-
1017
- # ------------------------------------------------------------------
1018
- # CPU offload public helpers
1019
- # ------------------------------------------------------------------
1020
-
1021
- def turn_on_cpu_offload(self):
1022
- """Enable CPU offload for optimizer states."""
1023
- if self.cpu_offload:
1024
- return
1025
- logger.info("[Muon] turn_on_cpu_offload")
1026
- self.cpu_offload = True
1027
- if not self.state:
1028
- return
1029
- self._cpu_offload_pool = CPUOffloadPool()
1030
- self._offload_initialized = False
1031
- self._register_states_for_offload()
1032
- self._offload_initialized = True
1033
- self._cpu_offload_pool.offload()
1034
-
1035
- def turn_off_cpu_offload(self):
1036
- """Disable CPU offload and keep optimizer states resident on GPU."""
1037
- if not self.cpu_offload:
1038
- return
1039
- logger.info("[Muon] turn_off_cpu_offload")
1040
- if self._offload_initialized:
1041
- self._cpu_offload_pool.reload()
1042
- torch.cuda.current_stream().synchronize()
1043
- self._cpu_offload_pool = None
1044
- self._offload_initialized = False
1045
- self.cpu_offload = False
1046
-
1047
- # ------------------------------------------------------------------
1048
- # Checkpoint support for cpu_offload
1049
- # ------------------------------------------------------------------
1050
-
1051
- def state_dict(self) -> dict:
1052
- if self.cpu_offload:
1053
- raise RuntimeError(
1054
- "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
1055
- )
1056
- return super().state_dict()
1057
-
1058
- def load_state_dict(self, state_dict: dict) -> None:
1059
- if self.cpu_offload:
1060
- raise RuntimeError(
1061
- "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
1062
- )
1063
- super().load_state_dict(state_dict)
1064
-
1065
- # Invalidate adamw.py's module-level tensor caches so that
1066
- # the next step rebuilds them with the newly loaded state tensors.
1067
- _placement_cache.clear()
1068
- _tensor_cache.clear()
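The constructor's error message above already prescribes the intended entry point; expanded into a runnable sketch (`get_kernel` is the Hugging Face kernels loader; the hyperparameters are illustrative, not recommendations):

import torch
from kernels import get_kernel

optimizer_mod = get_kernel("motif-technologies/optimizer")

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8)).cuda()

# Buckets 2D weights into a Muon group and {0, 1}-D params into an AdamW group,
# setting the required "use_muon" key on each.
params = optimizer_mod.muon.get_default_muon_param_groups(model)
optim = optimizer_mod.Muon(params, lr=0.02, momentum=0.95, weight_decay=0.1)

loss = model(torch.randn(16, 64, device="cuda")).square().mean()
loss.backward()
optim.step()
optim.zero_grad()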
 
 
 
 
build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py DELETED
@@ -1,240 +0,0 @@
1
- from itertools import repeat
2
- from math import inf, sqrt
3
-
4
- import numpy as np
5
- import torch
6
-
7
- from .matmul_transpose_triton import matmul_transpose_assign
8
-
9
- COMM_DTYPE = torch.bfloat16
10
- DEFAULT_CHUNK_SIZE_RATIO = 4
11
-
12
-
13
- def _optimal_quintic(l, u, max_iter=1000):
14
- """
15
- Use the simplified Remez algorithm to find the optimal odd quintic approximant
16
- to the constant function x -> 1 over the interval [l, u].
17
-
18
- Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
19
- approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
20
- two interior equioscillation nodes q, r until convergence. Returns the
21
- closed-form equioscillating solution when l ≈ u.
22
-
23
- Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
24
- (NaN or inf). Raises RuntimeError if convergence is not reached within
25
- max_iter iterations.
26
- """
27
- assert 0 <= l <= u
28
- if 1 - 5e-6 <= l / u:
29
- return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
30
- q = (3 * l + u) / 4
31
- r = (l + 3 * u) / 4
32
- E = inf
33
- for _ in range(max_iter):
34
- old_E = E
35
- LHS = np.array(
36
- [
37
- [l, l**3, l**5, 1],
38
- [q, q**3, q**5, -1],
39
- [r, r**3, r**5, 1],
40
- [u, u**3, u**5, -1],
41
- ]
42
- )
43
- a, b, c, E = np.linalg.solve(LHS, np.ones(4))
44
- if not np.all(np.isfinite([a, b, c, E])):
45
- raise ValueError(
46
- f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
47
- )
48
- q, r = np.sqrt(
49
- (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
50
- )
51
- if not np.all(np.isfinite([q, r])):
52
- raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
53
- if abs(old_E - E) <= 1e-15:
54
- break
55
- else:
56
- raise RuntimeError(
57
- f"_optimal_quintic: did not converge after {max_iter} iterations"
58
- )
59
- return float(a), float(b), float(c)
60
-
61
-
62
- def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
63
- """
64
- Compute the Polar Express coefficient series for `num_iters` quintic iterations.
65
-
66
- Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that
67
- compose to map singular values from [l, 1] toward 1. At each step:
68
- 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
69
- prevents near-zero singular values from stalling by raising the effective
70
- lower bound; if it is active (cushion*u > l), the coefficients are
71
- rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u].
72
- 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the
73
- last iteration, providing numerical headroom at the cost of a slightly slower
74
- final convergence step.
75
- 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1).
76
-
77
- Returns a list of (a, b, c) tuples, one per iteration.
78
-
79
- Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
80
- Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
81
- """
82
- u = 1
83
- assert 0 <= l <= u
84
- safety_factor = 1 + safety_factor_eps
85
- coefficients = []
86
- for iter in range(num_iters):
87
- a, b, c = _optimal_quintic(max(l, cushion * u), u)
88
- if cushion * u > l:
89
- pl = a * l + b * l**3 + c * l**5
90
- pu = a * u + b * u**3 + c * u**5
91
- rescaler = 2 / (pl + pu)
92
- a *= rescaler
93
- b *= rescaler
94
- c *= rescaler
95
- if iter < num_iters - 1:
96
- a /= safety_factor
97
- b /= safety_factor**3
98
- c /= safety_factor**5
99
- coefficients.append((a, b, c))
100
- l = a * l + b * l**3 + c * l**5
101
- u = 2 - l
102
- return coefficients
103
-
104
-
105
- # Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
106
- # iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic
107
- # approximant to x->1 over the current singular-value interval, computed once at
108
- # import time and reused across all optimizer steps.
109
- #
110
- # Contrast with the former hardcoded NS coefficients (5 fixed tuples):
111
- # - Former: empirically tuned to maximize slope at zero; did not converge
112
- # singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
113
- # of the true polar factor UV^T.
114
- # - Polar Express: analytically optimal per step, adapting to the shrinking
115
- # singular-value interval [l, u] as iterations progress; converges all
116
- # singular values to 1, producing the exact polar factor UV^T.
117
- _coeffs_list = _optimal_composition(
118
- l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
119
- )
120
-
121
-
122
- # This code is adapted from:
123
- # KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
124
- # NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
125
- # matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
126
- @torch.no_grad()
127
- def _zeropower_via_newtonschulz5(G, steps):
128
- """
129
- Compute the polar factor of G via the Polar Express method.
130
-
131
- Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c)
132
- are the Polar Express coefficients from `_coeffs_list`. Each step is the
133
- optimal odd quintic approximant to x -> 1 over the current singular-value
134
- interval, minimizing the maximum approximation error (Remez / minimax criterion).
135
- The composition maps singular values from [l, 1] to near 1, producing the
136
- polar factor (orthogonal factor in the polar decomposition G = UP).
137
-
138
- `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2,
139
- cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated.
140
-
141
- Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
142
- Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
143
- """
144
- assert len(G.shape) == 2
145
- assert G.dtype == COMM_DTYPE
146
- X = G # no manual typecast
147
-
148
- if G.size(0) > G.size(1):
149
- X = X.T
150
-
151
- X = X / (X.norm() + 1e-7)
152
- hs = _coeffs_list[:steps] + list(
153
- repeat(_coeffs_list[-1], steps - len(_coeffs_list))
154
- )
155
- buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
156
- buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
157
- # Perform the NS iterations
158
- for a, b, c in hs:
159
- matmul_transpose_assign(X, buf1)
160
- matmul_transpose_assign(buf1, buf2)
161
- buf1.mul_(b).add_(buf2, alpha=c)
162
- X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
163
-
164
- if G.size(0) > G.size(1):
165
- X = X.T
166
-
167
- return X
168
-
169
-
170
- @torch.no_grad()
171
- def _zeropower_via_newtonschulz5_batched(G, steps):
172
- """Batched polar factor computation for 3D (E, out, in) tensors.
173
-
174
- Same algorithm as ``_zeropower_via_newtonschulz5`` but uses
175
- ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel,
176
- processing all E expert matrices in a single batched call.
177
- """
178
- assert len(G.shape) == 3
179
- assert G.dtype == COMM_DTYPE
180
- X = G
181
-
182
- if G.size(1) > G.size(2):
183
- X = X.transpose(-2, -1)
184
-
185
- # Per-expert Frobenius norm.
186
- X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
187
-
188
- hs = _coeffs_list[:steps] + list(
189
- repeat(_coeffs_list[-1], steps - len(_coeffs_list))
190
- )
191
- for a, b, c in hs:
192
- buf1 = torch.bmm(X, X.transpose(-2, -1))
193
- buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
194
- buf1.mul_(b).add_(buf2, alpha=c)
195
- X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a)
196
-
197
- if G.size(1) > G.size(2):
198
- X = X.transpose(-2, -1)
199
-
200
- return X
201
-
202
-
203
- _ns_per_shape: dict[tuple[int, ...], callable] = {}
204
- _use_compile = True
205
-
206
-
207
- def set_ns_compile(enabled: bool):
208
- """Toggle torch.compile for Newton-Schulz iteration."""
209
- global _use_compile
210
- _use_compile = enabled
211
-
212
-
213
- def zeropower_via_newtonschulz5(G, steps=5):
214
- if not _use_compile:
215
- return _zeropower_via_newtonschulz5(G, steps)
216
- key = G.shape
217
- if key not in _ns_per_shape:
218
- _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5,
219
- options={
220
- "triton.cudagraphs": True,
221
- "shape_padding": False
222
- })
223
- torch.compiler.cudagraph_mark_step_begin()
224
- return _ns_per_shape[key](G, steps).clone()
225
-
226
-
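
The `.clone()` and `cudagraph_mark_step_begin()` pair in the dispatcher above is load-bearing. A minimal sketch of why, using a stand-in function (the lambda below is illustrative, not part of this module):

    # With triton.cudagraphs, the compiled callable returns a tensor backed
    # by a CUDA-graph-owned static buffer that is overwritten on the next
    # replay. clone() copies the result out of that buffer so it can outlive
    # the step; cudagraph_mark_step_begin() tells the runtime a new
    # iteration (and thus a safe overwrite point) has begun.
    import torch

    fn = torch.compile(lambda x: x * 2, options={"triton.cudagraphs": True})
    torch.compiler.cudagraph_mark_step_begin()
    out = fn(torch.ones(4, device="cuda")).clone()  # safe to keep afterwards
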
227
- def zeropower_via_newtonschulz5_batched(G, steps=5):
228
- """Compile-cached batched Newton-Schulz for 3D expert tensors."""
229
- if not _use_compile:
230
- return _zeropower_via_newtonschulz5_batched(G, steps)
231
- key = G.shape
232
- if key not in _ns_per_shape:
233
- _ns_per_shape[key] = torch.compile(
234
- _zeropower_via_newtonschulz5_batched,
235
- options={
236
- "triton.cudagraphs": True,
237
- "shape_padding": False
238
- })
239
- torch.compiler.cudagraph_mark_step_begin()
240
- return _ns_per_shape[key](G, steps).clone()
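
The quintic update in `_zeropower_via_newtonschulz5` is easier to audit against a plain-matmul version. A minimal sketch, assuming a wide input (rows <= cols) and using the classic fixed Muon coefficient tuple rather than the Polar Express `_coeffs_list`:

    import torch

    def ns_reference(G: torch.Tensor, coeffs) -> torch.Tensor:
        # Plain-matmul stand-in for the Triton matmul_transpose_assign path.
        X = G / (G.norm() + 1e-7)        # bring singular values into (0, 1]
        for a, b, c in coeffs:
            A = X @ X.T                  # singular values become s^2
            X = a * X + (b * A + c * (A @ A)) @ X  # maps s -> a*s + b*s^3 + c*s^5
        return X

    G = torch.randn(64, 128)
    X = ns_reference(G, [(3.4445, -4.7750, 2.0315)] * 5)  # legacy fixed tuples
    print(torch.linalg.svdvals(X)[:4])   # should sit in a band around 1
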
build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- import ctypes
2
- import sys
3
-
4
- import importlib.util
5
- from pathlib import Path
6
- from types import ModuleType
7
-
8
- def _import_from_path(file_path: Path) -> ModuleType:
9
- # We cannot use the module name as-is: once it is added to `sys.modules`,
10
- # it would also be picked up by other imports. So we build a module name
11
- # that is unique per file, by hex-encoding the hash of the absolute path
12
- # and registering the module under that name.
13
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
- module_name = path_hash
15
- spec = importlib.util.spec_from_file_location(module_name, file_path)
16
- if spec is None:
17
- raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
- module = importlib.util.module_from_spec(spec)
19
- if module is None:
20
- raise ImportError(f"Cannot load module {module_name} from spec")
21
- sys.modules[module_name] = module
22
- spec.loader.exec_module(module) # type: ignore
23
- return module
24
-
25
-
26
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
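
Usage sketch of the shim above, with a placeholder path (not a file in this repo); the point is that the path-hash name keeps a reusable name like `optimizer` out of `sys.modules`:

    import importlib.util
    import sys
    from pathlib import Path

    def load_by_path(file_path: Path):
        # Same pattern as _import_from_path: the module is registered under a
        # path-derived name, so `import optimizer` elsewhere is unaffected.
        name = format(hash(file_path.resolve()) & (2**64 - 1), "x")
        spec = importlib.util.spec_from_file_location(name, file_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[name] = module
        spec.loader.exec_module(module)
        return module

    # mod = load_by_path(Path("/tmp/example_pkg/__init__.py"))  # placeholder
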
build/torch210-cxx11-cu126-x86_64-linux/pipeline.py DELETED
@@ -1,468 +0,0 @@
1
- import logging
2
- from typing import Generator
3
-
4
- import torch
5
- import torch.distributed as dist
6
- from torch.distributed.tensor import DTensor
7
- from torch.profiler import record_function
8
-
9
- from .core import _muon_state, adjust_lr_for_muon
10
- from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5
11
- from .qk_clip import compute_scales
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
- # ======================================================================
16
- # Stage helpers
17
- # ======================================================================
18
-
19
-
20
- def _launch_gather(
21
- params: list[DTensor],
22
- owned_params: list[DTensor],
23
- param_to_state: dict[int, _muon_state],
24
- rank: int,
25
- num_ranks: int,
26
- process_group: dist.ProcessGroup,
27
- ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
28
- """Allocate gather buffers, build send/recv, and launch async all-to-all.
29
-
30
- Returns:
31
- work: Async operation handle.
32
- recv_buf: Flat receive buffer (needed by ``_complete_gather``).
33
- gathered_grads: ``{id(p): empty_tensor}`` for owned params,
34
- ``None`` for non-owned.
35
- recv_counts: Per-source-rank element counts.
36
- """
37
- # Allocate gathered-grad buffers
38
- gathered_grads: dict[int, torch.Tensor | None] = {}
39
- for p in params:
40
- state = param_to_state[id(p)]
41
- if rank == state.worker_rank:
42
- gathered_grads[id(p)] = torch.empty(p.shape,
43
- dtype=COMM_DTYPE,
44
- device="cuda")
45
- else:
46
- gathered_grads[id(p)] = None
47
-
48
- # Build send buffer – batch grad copies via torch.cat
49
- # (1-2 fused kernels vs N individual narrow().copy_() calls).
50
- send_counts = [0] * num_ranks
51
- for p in params:
52
- state = param_to_state[id(p)]
53
- send_counts[state.worker_rank] += state.rank_numels[rank]
54
-
55
- total_send = sum(send_counts)
56
- if total_send > 0:
57
- # Group grad slices by destination rank in a single pass.
58
- dst_to_grads = [[] for _ in range(num_ranks)]
59
- for p in params:
60
- state = param_to_state[id(p)]
61
- n = state.rank_numels[rank]
62
- if n > 0:
63
- g = p.grad.to_local()
64
- dst_to_grads[state.worker_rank].append(g.reshape(-1))
65
-
66
- # Flatten in dst order and cat once.
67
- all_slices = []
68
- for dst in range(num_ranks):
69
- all_slices.extend(dst_to_grads[dst])
70
- send_buf = torch.cat(all_slices)
71
- if send_buf.dtype != COMM_DTYPE:
72
- send_buf = send_buf.to(COMM_DTYPE)
73
- else:
74
- send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
75
-
76
- # Build recv buffer
77
- recv_counts = [0] * num_ranks
78
- for src in range(num_ranks):
79
- total = 0
80
- for p in owned_params:
81
- state = param_to_state[id(p)]
82
- assert state.worker_rank == rank
83
- total += state.rank_numels[src]
84
- recv_counts[src] = total
85
-
86
- recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
87
-
88
- # Launch async all-to-all
89
- logger.debug(f"send_buf size: {send_buf.numel()}, "
90
- f"recv_buf size: {recv_buf.numel()}, "
91
- f"recv_counts: {recv_counts}, "
92
- f"send_counts: {send_counts}, "
93
- f"process_group: {str(process_group)}")
94
- work = dist.all_to_all_single(
95
- recv_buf,
96
- send_buf,
97
- output_split_sizes=recv_counts,
98
- input_split_sizes=send_counts,
99
- group=process_group,
100
- async_op=True,
101
- )
102
-
103
- return work, recv_buf, gathered_grads, recv_counts
104
-
105
-
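
The split-size bookkeeping above follows `all_to_all_single` semantics: `input_split_sizes[d]` elements of `send_buf` go to rank `d`, and `output_split_sizes[s]` elements of `recv_buf` arrive from rank `s`. A hedged two-rank sketch (requires an initialized process group, so illustrative only):

    import torch
    import torch.distributed as dist

    def demo(pg: dist.ProcessGroup) -> torch.Tensor:
        # Every rank sends 2 elements to rank 0 and 1 element to rank 1.
        rank = dist.get_rank(pg)
        send = torch.arange(3.0, device="cuda")
        out_splits = [2, 2] if rank == 0 else [1, 1]  # per-source counts
        recv = torch.empty(sum(out_splits), device="cuda")
        dist.all_to_all_single(recv, send,
                               output_split_sizes=out_splits,
                               input_split_sizes=[2, 1],
                               group=pg)
        return recv
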
106
- def _complete_gather(
107
- recv_buf: torch.Tensor,
108
- recv_counts: list[int],
109
- owned_params: list[DTensor],
110
- gathered_grads: dict[int, torch.Tensor | None],
111
- param_to_state: dict[int, _muon_state],
112
- rank: int,
113
- ) -> None:
114
- """Reconstruct gathered grads from the recv buffer (in-place)."""
115
- off = 0
116
- for src in range(len(recv_counts)):
117
- if recv_counts[src] == 0:
118
- continue
119
-
120
- block = recv_counts[src]
121
- inner_off = 0
122
- for p in owned_params:
123
- state = param_to_state[id(p)]
124
- assert state.worker_rank == rank
125
-
126
- indices = state.rank_indices[src]
127
-
128
- shard_view = gathered_grads[id(p)][indices]
129
- n = shard_view.numel()
130
- if n == 0:
131
- continue
132
-
133
- sg = recv_buf.narrow(0, off + inner_off, n)
134
- sg = sg.reshape(shard_view.shape)
135
- gathered_grads[id(p)][indices] = sg
136
-
137
- inner_off += n
138
- assert inner_off == block
139
- off += block
140
-
141
-
142
- def _compute_ns(
143
- owned_params: list[DTensor],
144
- gathered_grads: dict[int, torch.Tensor | None],
145
- ns_steps: int,
146
- ) -> dict[int, torch.Tensor | None]:
147
- """Run Newton-Schulz orthogonalization on owned parameters.
148
-
149
- Returns:
150
- computed_us: ``{id(p): orthogonalized_update}`` for owned params.
151
- """
152
- computed_us: dict[int, torch.Tensor | None] = {}
153
- for p in owned_params:
154
- u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
155
- gathered_grads[id(p)] = None # free gathered grad
156
- computed_us[id(p)] = u
157
- return computed_us
158
-
159
-
160
- def _launch_scatter(
161
- params: list[DTensor],
162
- owned_params: list[DTensor],
163
- param_to_state: dict[int, _muon_state],
164
- rank: int,
165
- num_ranks: int,
166
- process_group: dist.ProcessGroup,
167
- computed_us: dict[int, torch.Tensor | None],
168
- ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
169
- """Allocate scatter buffers, build send/recv, and launch async all-to-all.
170
-
171
- Returns:
172
- work: Async operation handle.
173
- recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
174
- scattered_us: Empty dict, populated by ``_complete_scatter`` with
175
- zero-copy views into ``recv_buf``.
176
- recv_counts: Per-source-rank element counts.
177
- """
178
- # scattered_us is populated by _complete_scatter with zero-copy views
179
- # into recv_buf, avoiding N empty_like allocations + N copy_ calls.
180
- # Pre-seed entries for params whose local shard is empty (rank_numels == 0)
181
- # so _update_params can iterate all params without KeyError.
182
- scattered_us: dict[int, torch.Tensor] = {}
183
- for p in params:
184
- if param_to_state[id(p)].rank_numels[rank] == 0:
185
- scattered_us[id(p)] = torch.empty_like(p.to_local(),
186
- dtype=COMM_DTYPE)
187
-
188
- # Build send buffer – batch via torch.cat
189
- # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls).
190
- send_counts = [0] * num_ranks
191
- if owned_params:
192
- for p in owned_params:
193
- state = param_to_state[id(p)]
194
- for dst_rank in range(num_ranks):
195
- send_counts[dst_rank] += state.rank_numels[dst_rank]
196
-
197
- total_send = sum(send_counts)
198
- if total_send > 0:
199
- # Cache u_full conversions to avoid redundant .to() per dst_rank.
200
- u_fulls = {}
201
- for p in owned_params:
202
- u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous()
203
-
204
- # Collect slices in dst order (matches all-to-all send layout).
205
- all_slices = []
206
- for dst_rank in range(num_ranks):
207
- for p in owned_params:
208
- state = param_to_state[id(p)]
209
- su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten()
210
- if su.numel() > 0:
211
- all_slices.append(su)
212
-
213
- send_buf = torch.cat(all_slices) if all_slices else torch.empty(
214
- 0, dtype=COMM_DTYPE, device="cuda")
215
- else:
216
- send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
217
-
218
- # Build recv buffer
219
- recv_counts = [0] * num_ranks
220
- for src in range(num_ranks):
221
- total = 0
222
- for p in params:
223
- state = param_to_state[id(p)]
224
- if state.worker_rank != src:
225
- continue
226
- total += state.rank_numels[rank]
227
- recv_counts[src] = total
228
-
229
- recv_total = sum(recv_counts)
230
- recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
231
-
232
- # Launch async all-to-all
233
- work = dist.all_to_all_single(
234
- recv_buf,
235
- send_buf,
236
- output_split_sizes=recv_counts,
237
- input_split_sizes=send_counts,
238
- group=process_group,
239
- async_op=True,
240
- )
241
-
242
- return work, recv_buf, scattered_us, recv_counts
243
-
244
-
245
- def _complete_scatter(
246
- recv_buf: torch.Tensor,
247
- recv_counts: list[int],
248
- params: list[DTensor],
249
- param_to_state: dict[int, _muon_state],
250
- rank: int,
251
- scattered_us: dict[int, torch.Tensor],
252
- ) -> None:
253
- """Populate scattered_us with zero-copy views into recv_buf.
254
-
255
- Instead of pre-allocating tensors and copying, we assign views directly
256
- from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls.
257
- The underlying storage of ``recv_buf`` is kept alive through the views
258
- until ``scattered_us`` is cleared after ``_update_params``.
259
- """
260
- off = 0
261
- for src in range(len(recv_counts)):
262
- block = recv_counts[src]
263
- if block == 0:
264
- continue
265
-
266
- inner_off = 0
267
- for p in params:
268
- state = param_to_state[id(p)]
269
- if state.worker_rank != src:
270
- continue
271
- n = state.rank_numels[rank]
272
- if n == 0:
273
- continue
274
-
275
- scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off,
276
- n).view_as(p.to_local())
277
-
278
- inner_off += n
279
-
280
- assert inner_off == block
281
- off += block
282
-
283
-
284
- def _update_params(
285
- params: list[DTensor],
286
- param_to_state: dict[int, _muon_state],
287
- rank: int,
288
- scattered_us: dict[int, torch.Tensor],
289
- lr: float,
290
- weight_decay: float,
291
- ) -> None:
292
- """Apply weight decay, Muon update, and optional QK clipping.
293
-
294
- Uses batched ``_foreach_mul_`` for weight decay and batched
295
- ``_foreach_add_`` for the Muon update, grouping parameters by
296
- adjusted_lr to minimize kernel launches while preserving float32
297
- precision for the alpha scaling.
298
- """
299
- if not params:
300
- return
301
-
302
- # Batched weight decay: p *= (1 - lr * wd) — single fused kernel.
303
- p_locals = [p._local_tensor for p in params]
304
- torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay)
305
-
306
- # Group params by adjusted_lr so _foreach_add_ can use a single
307
- # alpha per group (preserves float32 precision for alpha scaling).
308
- lr_groups: dict[float, tuple[list, list]] = {}
309
- for p in params:
310
- adjusted_lr = adjust_lr_for_muon(lr, p.shape)
311
- if adjusted_lr not in lr_groups:
312
- lr_groups[adjusted_lr] = ([], [])
313
- lr_groups[adjusted_lr][0].append(p._local_tensor)
314
- lr_groups[adjusted_lr][1].append(scattered_us[id(p)])
315
-
316
- for adjusted_lr, (p_group, u_group) in lr_groups.items():
317
- torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr)
318
-
319
- # QK clipping – applied directly on the local tensor to
320
- # avoid DTensor sharding-propagation issues with _StridedShard.
321
- for p in params:
322
- state = param_to_state[id(p)]
323
- if state.qk_clip_state is None:
324
- continue
325
- scales_full = compute_scales(p, state.qk_clip_state)
326
- if scales_full is not None:
327
- ratio = p.shape[0] // scales_full.shape[0]
328
- idx0 = state.rank_indices[rank][0]
329
- if isinstance(idx0, slice):
330
- start = idx0.start or 0
331
- idx0 = torch.arange(start,
332
- idx0.stop,
333
- device=scales_full.device)
334
- row_scales = scales_full[idx0 // ratio]
335
- p._local_tensor.mul_(row_scales.view(-1, 1))
336
-
337
-
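
The adjusted-lr grouping in `_update_params` exists because `_foreach_add_` takes a single scalar `alpha` per call. A minimal sketch of the same grouping on plain tensors (shapes and lr values are illustrative):

    import math
    import torch

    params = [torch.randn(512, 512) for _ in range(4)]
    updates = [torch.randn_like(p) for p in params]
    lr = 0.02

    groups: dict[float, tuple[list, list]] = {}
    for p, u in zip(params, updates):
        adjusted = lr * 0.2 * math.sqrt(max(p.shape[:2]))  # adjust_lr_for_muon
        ps, us = groups.setdefault(adjusted, ([], []))
        ps.append(p)
        us.append(u)

    for adjusted, (ps, us) in groups.items():
        torch._foreach_add_(ps, us, alpha=-adjusted)  # one fused kernel per group
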
338
- # ======================================================================
339
- # Pre-launch helper for overlapping first chunk's gather with other work.
340
- # ======================================================================
341
-
342
-
343
- @torch.no_grad()
344
- def prelaunch_first_gather(
345
- params: list[DTensor],
346
- param_to_state: dict[int, _muon_state],
347
- rank: int,
348
- none_grad: bool,
349
- ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
350
- """Launch the first chunk's A2A gather early for overlap with other compute.
351
-
352
- Call this *before* expensive GPU work (e.g. batched expert NS) so that
353
- the NCCL all-to-all runs concurrently on the NCCL stream while the
354
- default stream executes compute.
355
-
356
- Returns the same 4-tuple that ``_launch_gather`` produces, which should
357
- be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`.
358
- """
359
- process_group = param_to_state[id(params[0])].process_group
360
- num_ranks = dist.get_world_size(group=process_group)
361
- owned_params = [
362
- p for p in params if param_to_state[id(p)].worker_rank == rank
363
- ]
364
-
365
- with record_function("muon::prelaunch_gather"):
366
- work, recv_buf, gathered_grads, recv_counts = _launch_gather(
367
- params, owned_params, param_to_state, rank, num_ranks,
368
- process_group)
369
-
370
- if none_grad:
371
- for p in params:
372
- p.grad = None
373
-
374
- return work, recv_buf, gathered_grads, recv_counts
375
-
376
-
377
- # ======================================================================
378
- # Main generator – thin orchestrator that wires stages together.
379
- # ======================================================================
380
-
381
-
382
- @torch.no_grad()
383
- def muon_chunk_pipeline(
384
- params: list[DTensor],
385
- param_to_state: dict[int, _muon_state],
386
- rank: int,
387
- ns_steps: int,
388
- lr: float,
389
- weight_decay: float,
390
- none_grad: bool,
391
- prelaunch_gather: tuple | None = None,
392
- ) -> Generator[None, None, None]:
393
- """Process one chunk of parameters through the full Muon pipeline.
394
-
395
- Stages: gather -> compute (Newton-Schulz) -> scatter -> update.
396
-
397
- Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
398
- that communication and computation overlap across chunks. Async
399
- communication is launched via ``async_op=True`` and completed after
400
- the yield with ``work.wait()``.
401
-
402
- Overlap happens because :func:`run_pipeline` admits one new chunk
403
- per iteration (staggered admission). While chunk *N* does NS
404
- compute on the default CUDA stream, chunk *N+1*'s async all-to-all
405
- runs concurrently on the NCCL stream — no separate ``comm_stream``
406
- is required.
407
-
408
- If ``prelaunch_gather`` is provided, the gather was already launched
409
- by :func:`prelaunch_first_gather` and we skip launching it again.
410
-
411
- Yields exactly **2** times:
412
-
413
- 1. After launching async all-to-all gather (or immediately if pre-launched).
414
- 2. After launching async all-to-all scatter.
415
- """
416
- process_group = param_to_state[id(params[0])].process_group
417
- num_ranks = dist.get_world_size(group=process_group)
418
- owned_params = [
419
- p for p in params if param_to_state[id(p)].worker_rank == rank
420
- ]
421
-
422
- if prelaunch_gather is not None:
423
- # Gather was pre-launched; none_grad already handled by caller.
424
- work, recv_buf, gathered_grads, recv_counts = prelaunch_gather
425
- else:
426
- # Normal path: launch async gather.
427
- with record_function("muon::launch_gather"):
428
- work, recv_buf, gathered_grads, recv_counts = _launch_gather(
429
- params, owned_params, param_to_state, rank, num_ranks,
430
- process_group)
431
-
432
- if none_grad:
433
- for p in params:
434
- p.grad = None
435
-
436
- yield # --- YIELD 1: other chunks can launch their gather ---
437
-
438
- with record_function("muon::wait_gather"):
439
- work.wait()
440
- _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
441
- param_to_state, rank)
442
- del recv_buf
443
-
444
- # Compute stage: Newton-Schulz orthogonalization.
445
- with record_function("muon::newton_schulz"):
446
- computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
447
- gathered_grads.clear()
448
-
449
- # Scatter stage: launch the async all-to-all scatter.
450
- with record_function("muon::launch_scatter"):
451
- work, recv_buf, scattered_us, recv_counts = _launch_scatter(
452
- params, owned_params, param_to_state, rank, num_ranks,
453
- process_group, computed_us)
454
- computed_us.clear()
455
-
456
- yield # --- YIELD 2: other chunks can launch their scatter ---
457
-
458
- with record_function("muon::wait_scatter"):
459
- work.wait()
460
- _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
461
- scattered_us)
462
- del recv_buf
463
-
464
- # Update stage: apply parameter updates.
465
- with record_function("muon::update_params"):
466
- _update_params(params, param_to_state, rank, scattered_us, lr,
467
- weight_decay)
468
- scattered_us.clear()
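
The two-yield protocol is the whole contract between `muon_chunk_pipeline` and its scheduler. A toy sketch, with print statements standing in for the NCCL and compute stages:

    from typing import Generator

    def toy_chunk(idx: int) -> Generator[None, None, None]:
        print(f"chunk {idx}: launch async gather")
        yield  # scheduler may admit/advance other chunks here
        print(f"chunk {idx}: wait gather -> Newton-Schulz -> launch scatter")
        yield
        print(f"chunk {idx}: wait scatter -> apply updates")

    # Driven by run_pipeline (see async_utils.py elsewhere in this diff),
    # chunk N+1's gather launches while chunk N is in its compute stage:
    # run_pipeline((toy_chunk(i) for i in range(3)), max_concurrent=2)
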
build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py DELETED
@@ -1,198 +0,0 @@
1
- import logging
2
- import math
3
- from dataclasses import dataclass
4
-
5
- import torch
6
- from torch.distributed.tensor import DTensor
7
-
8
- from .core import normalize_fqn
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- def parse_qk_layer(name: str) -> tuple[str | None, int]:
14
- """
15
- Parse a parameter name to check if it is a query/key projection layer
16
- and return (kind, layer_index).
17
-
18
- Supported kinds:
19
- MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
20
- MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)
21
-
22
- Returns:
23
- (kind, layer_idx) or (None, -1) if not matched.
24
-
25
- Example:
26
- 'model.3.attn.wq.weight' -> ('wq', 3)
27
- 'model.5.attn.wk.weight' -> ('wk', 5)
28
- 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
29
- 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
30
- 'model.1.attn.wq_b.weight' -> ('wq_b', 1)
31
- 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
32
- 'model.4.attn.v_proj.weight' -> (None, -1)
33
- """
34
- parts = normalize_fqn(name).split('.')
35
- if len(parts) < 3:
36
- return None, -1
37
-
38
- kind = parts[-2]
39
-
40
- layer_idx = -1
41
- for part in reversed(parts):
42
- if part.isdigit():
43
- layer_idx = int(part)
44
- break
45
-
46
- if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
47
- return kind, layer_idx
48
-
49
- return None, -1
50
-
51
-
52
- @dataclass
53
- class QKClipInfo:
54
- """Per-parameter dynamic info computed from config + runtime logits."""
55
- kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
56
- indices: list[int] # which heads to consider for clipping
57
- head_dim: int # from config (qk_head_dim for MLA wq_b)
58
- threshold: float # from config
59
- logit: torch.Tensor | None
60
-
61
- # MLA-specific fields
62
- is_mla: bool = False
63
- qk_nope_head_dim: int = 0
64
- qk_rope_head_dim: int = 0
65
- v_head_dim: int = 0
66
-
67
-
68
- def get_qk_clip_info(clip_config, n, qk_logits):
69
- """Extract QK clipping info for a named parameter.
70
-
71
- Args:
72
- clip_config: QK clipping configuration dict (or None).
73
- MHA/GQA keys: head_dim, threshold, q_indices, k_indices
74
- MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
75
- n: Parameter name string.
76
- qk_logits: Dict mapping layer indices to logit tensors (or None).
77
-
78
- Returns:
79
- QKClipInfo instance with clipping configuration for this parameter.
80
- """
81
- if clip_config is None:
82
- return None
83
-
84
- head_dim = clip_config.get('head_dim')
85
- threshold = clip_config.get('threshold')
86
- kind, layer_idx = parse_qk_layer(n)
87
- is_mla = clip_config.get('is_mla', False)
88
-
89
- logit, indices = None, []
90
- if qk_logits is not None and kind is not None:
91
- logit = qk_logits[layer_idx]
92
- if isinstance(logit, DTensor):
93
- # In TP settings, qk_logits may be DTensor
94
- # We convert it to full tensor here for simplicity
95
- logit = logit.full_tensor()
96
-
97
- if kind in ('wq_b', 'wq', 'q_proj'):
98
- indices = clip_config.get('q_indices', []) or []
99
- elif kind in ('wkv_b', 'wk', 'k_proj'):
100
- indices = clip_config.get('k_indices', []) or []
101
-
102
- if is_mla:
103
- return QKClipInfo(
104
- kind=kind,
105
- indices=indices,
106
- head_dim=head_dim,
107
- threshold=threshold,
108
- logit=logit,
109
- is_mla=True,
110
- qk_nope_head_dim=clip_config['qk_nope_head_dim'],
111
- qk_rope_head_dim=clip_config['qk_rope_head_dim'],
112
- v_head_dim=clip_config['v_head_dim'],
113
- )
114
- else:
115
- return QKClipInfo(
116
- kind=kind,
117
- indices=indices,
118
- head_dim=head_dim,
119
- threshold=threshold,
120
- logit=logit,
121
- )
122
-
123
-
124
- def compute_scales(p, qk_clip_state):
125
- """Compute per-head scaling factors for QK clipping.
126
-
127
- Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
128
- For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
129
- """
130
- kind = qk_clip_state.kind
131
- indices = qk_clip_state.indices
132
- head_dim = qk_clip_state.head_dim
133
- threshold = qk_clip_state.threshold
134
- logit = qk_clip_state.logit
135
-
136
- # Check if any head exceeds threshold before allocating.
137
- head_scales = {}
138
- for logit_idx, head_idx in enumerate(indices):
139
- v_ele = float(logit[logit_idx])
140
- if v_ele > threshold:
141
- new_scale = math.sqrt(threshold / v_ele)
142
- if head_idx not in head_scales or new_scale < head_scales[head_idx]:
143
- head_scales[head_idx] = new_scale
144
- logger.info(
145
- f"[{kind}] Head {head_idx} exceeded threshold "
146
- f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
147
- )
148
-
149
- if not head_scales:
150
- return None
151
-
152
- # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
153
- if qk_clip_state.is_mla and kind == 'wkv_b':
154
- effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
155
- else:
156
- effective_head_dim = head_dim
157
-
158
- H_global = p.shape[0] // effective_head_dim
159
- scales_full = torch.ones(H_global, device=p.data.device)
160
- for head_idx, scale in head_scales.items():
161
- scales_full[head_idx] = scale
162
- return scales_full
163
-
164
-
165
- def qk_clip(p, scales, info):
166
- """Apply per-head scaling to a Q/K projection weight matrix.
167
-
168
- Args:
169
- p: Parameter (nn.Parameter or raw tensor).
170
- scales: [n_heads] tensor, each element = √γ_h.
171
- info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
172
-
173
- MLA sub-region scaling per Algorithm 1 (MuonClip):
174
- wq_b: q_nope rows → √γ, q_pe rows → γ
175
- wkv_b: k_nope rows → √γ, v rows → unchanged
176
- """
177
- W = p.data if isinstance(p, torch.nn.Parameter) else p
178
-
179
- if not info.is_mla:
180
- # MHA/GQA: uniform √γ applied to all rows in each head
181
- W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
182
- return
183
-
184
- # MLA: vectorized sub-region scaling within each head
185
- if info.kind == 'wq_b':
186
- qk_nope = info.qk_nope_head_dim
187
- qk_head_dim = qk_nope + info.qk_rope_head_dim
188
- W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim]
189
- W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ
190
- W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1,
191
- 1)) # q_pe → γ
192
-
193
- elif info.kind == 'wkv_b':
194
- qk_nope = info.qk_nope_head_dim
195
- kv_stride = qk_nope + info.v_head_dim
196
- W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim]
197
- W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ
198
- # v rows: not touched (k_R shared rotary unchanged)
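
Worked numeric sketch of the MHA/GQA path above (the values are made up): with threshold 100 and one head's max logit at 144, that head's rows are scaled by √(100/144) = 5/6, so after both Q and K are scaled the post-clip logit is 144·(5/6)² = 100.

    import math
    import torch

    threshold, head_dim, n_heads = 100.0, 2, 4
    logits = [50.0, 144.0, 80.0, 99.0]            # per-head max QK logits

    scales = torch.ones(n_heads)
    for h, v in enumerate(logits):
        if v > threshold:
            scales[h] = math.sqrt(threshold / v)   # sqrt(gamma_h)

    W = torch.randn(n_heads * head_dim, 16)        # [n_heads*head_dim, in_dim]
    W.view(-1, head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
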
build/torch210-cxx11-cu128-x86_64-linux/adamw.py DELETED
@@ -1,271 +0,0 @@
1
- import logging
2
- from collections import defaultdict
3
- from typing import cast
4
-
5
- import torch
6
- from torch.distributed.tensor import DTensor
7
- from torch.profiler import record_function
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
-
12
- def fused_adamw(
13
- params: list[torch.Tensor],
14
- grads: list[torch.Tensor],
15
- exp_avgs: list[torch.Tensor],
16
- exp_avg_sqs: list[torch.Tensor],
17
- max_exp_avg_sqs: list[torch.Tensor],
18
- state_steps: list[torch.Tensor],
19
- amsgrad: bool,
20
- beta1: float,
21
- beta2: float,
22
- lr: float | torch.Tensor,
23
- weight_decay: float,
24
- eps: float,
25
- maximize: bool,
26
- ) -> None:
27
- if not params:
28
- return
29
-
30
- # We only move the lr across devices when it is a Tensor on a non-CPU
31
- # device; otherwise we treat it as a scalar.
32
- lr_dict: dict | None = ({
33
- lr.device: lr
34
- } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None)
35
- grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
36
- [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
37
- state_steps] # type: ignore[list-item]
38
- )
39
- for (device, _), (
40
- (
41
- device_params_,
42
- device_grads_,
43
- device_exp_avgs_,
44
- device_exp_avg_sqs_,
45
- device_max_exp_avg_sqs,
46
- device_state_steps_,
47
- ),
48
- _,
49
- ) in grouped_tensors.items():
50
- device_params = cast(list[torch.Tensor], device_params_)
51
- device_grads = cast(list[torch.Tensor], device_grads_)
52
- device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
53
- device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
54
- device_state_steps = cast(list[torch.Tensor], device_state_steps_)
55
-
56
- if lr_dict is not None and device not in lr_dict:
57
- lr_dict[device] = lr.to(
58
- device=device, non_blocking=True) # type: ignore[union-attr]
59
- lr = lr_dict[device]
60
- torch._foreach_add_(device_state_steps, 1)
61
- func = torch._fused_adamw_
62
- func(
63
- device_params,
64
- device_grads,
65
- device_exp_avgs,
66
- device_exp_avg_sqs,
67
- device_max_exp_avg_sqs, # type: ignore[arg-type]
68
- device_state_steps,
69
- amsgrad=amsgrad,
70
- lr=lr, # type: ignore[arg-type]
71
- beta1=beta1,
72
- beta2=beta2,
73
- weight_decay=weight_decay,
74
- eps=eps,
75
- maximize=maximize,
76
- )
77
-
78
-
79
- def _to_local(t):
80
- """Unwrap DTensor to local tensor for fused ops."""
81
- return t._local_tensor if isinstance(t, DTensor) else t
82
-
83
-
84
- # ---------------------------------------------------------------------------
85
- # Caches for eliminating per-step Python overhead.
86
- #
87
- # Placement grouping and tensor list assembly are identical every step
88
- # (params don't change placement, moment/step tensors are the same objects
89
- # after initialisation). We cache them keyed by id() of the param list
90
- # stored in param_groups (stable across steps).
91
- #
92
- # Only gradients change each step and must be collected fresh.
93
- # ---------------------------------------------------------------------------
94
-
95
- # id(group["params"]) → dict[placement_key, list[param]]
96
- _placement_cache: dict[int, dict[tuple, list]] = {}
97
-
98
- # id(placement_group_list) → (params_local, moment1, moment2, state_steps)
99
- _tensor_cache: dict[int, tuple[list, list, list, list]] = {}
100
-
101
-
102
- def _step_adamw_params_slow(optimizer_state, params, group):
103
- """Uncached fallback for the rare case where some params lack grads."""
104
- params_with_grads = []
105
- grads = []
106
- moment1 = []
107
- moment2 = []
108
- state_steps = []
109
-
110
- for p in params:
111
- g = p.grad
112
- if g is None:
113
- continue
114
- state = optimizer_state[p]
115
- params_with_grads.append(_to_local(p))
116
- grads.append(_to_local(g))
117
- if "step" not in state:
118
- state["step"] = torch.zeros((),
119
- dtype=torch.float32,
120
- device=p.device)
121
- state["moment1"] = torch.zeros_like(g)
122
- state["moment2"] = torch.zeros_like(g)
123
- moment1.append(_to_local(state["moment1"]))
124
- moment2.append(_to_local(state["moment2"]))
125
- if not isinstance(state["step"], torch.Tensor):
126
- state["step"] = torch.tensor(state["step"],
127
- dtype=torch.float32,
128
- device=p.device)
129
- state_steps.append(state["step"])
130
-
131
- if not params_with_grads:
132
- return
133
-
134
- lr = group["lr"]
135
- beta1, beta2 = group["adamw_betas"]
136
- eps = group["adamw_eps"]
137
- weight_decay = group["weight_decay"]
138
-
139
- fused_adamw(
140
- params_with_grads,
141
- grads,
142
- moment1,
143
- moment2,
144
- [],
145
- state_steps,
146
- amsgrad=False,
147
- beta1=beta1,
148
- beta2=beta2,
149
- lr=lr,
150
- weight_decay=weight_decay,
151
- eps=eps,
152
- maximize=False,
153
- )
154
-
155
-
156
- def step_adamw_params(optimizer_state, params, group):
157
- """Run fused AdamW on a list of parameters sharing the same placement.
158
-
159
- After the first call, cached tensor lists (params_local, moment1,
160
- moment2, state_steps) are reused — only gradients are collected fresh.
161
-
162
- Args:
163
- optimizer_state: The optimizer's state dict (self.state in Muon).
164
- params: List of parameters to update.
165
- group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
166
- """
167
- # Collect grads — the only thing that changes each step.
168
- with record_function("adamw::collect_grads"):
169
- grads = []
170
- for p in params:
171
- g = p.grad
172
- if g is None:
173
- # Rare: fall back to slow path that filters per-param.
174
- _step_adamw_params_slow(optimizer_state, params, group)
175
- return
176
- grads.append(_to_local(g))
177
-
178
- tensor_key = id(params)
179
- if tensor_key not in _tensor_cache:
180
- with record_function("adamw::init_tensor_cache"):
181
- params_local = []
182
- moment1 = []
183
- moment2 = []
184
- state_steps = []
185
-
186
- for p in params:
187
- state = optimizer_state[p]
188
- params_local.append(_to_local(p))
189
- if "step" not in state:
190
- state["step"] = torch.zeros((),
191
- dtype=torch.float32,
192
- device=p.device)
193
- state["moment1"] = torch.zeros_like(p.grad)
194
- state["moment2"] = torch.zeros_like(p.grad)
195
- moment1.append(_to_local(state["moment1"]))
196
- moment2.append(_to_local(state["moment2"]))
197
- if not isinstance(state["step"], torch.Tensor):
198
- state["step"] = torch.tensor(state["step"],
199
- dtype=torch.float32,
200
- device=p.device)
201
- state_steps.append(state["step"])
202
-
203
- _tensor_cache[tensor_key] = (params_local, moment1, moment2,
204
- state_steps)
205
-
206
- params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key]
207
-
208
- lr = group["lr"]
209
- beta1, beta2 = group["adamw_betas"]
210
- eps = group["adamw_eps"]
211
- weight_decay = group["weight_decay"]
212
-
213
- with record_function("adamw::fused_adamw"):
214
- fused_adamw(
215
- params_local,
216
- grads,
217
- moment1,
218
- moment2,
219
- [],
220
- state_steps,
221
- amsgrad=False,
222
- beta1=beta1,
223
- beta2=beta2,
224
- lr=lr,
225
- weight_decay=weight_decay,
226
- eps=eps,
227
- maximize=False,
228
- )
229
-
230
-
231
- def step_adamw(optimizer_state, group):
232
- """Dispatch AdamW step, grouping parameters by type and placement.
233
-
234
- Placement grouping is cached after the first call since params never
235
- change their placement between steps.
236
-
237
- Args:
238
- optimizer_state: The optimizer's state dict (self.state in Muon).
239
- group: Parameter group dict.
240
- """
241
- params = group["params"]
242
- placement_key = id(params)
243
-
244
- if placement_key not in _placement_cache:
245
- with record_function("adamw::group_by_placement"):
246
- placement_to_params: dict[tuple,
247
- list[torch.Tensor]] = defaultdict(list)
248
- for p in params:
249
- match p:
250
- case DTensor():
251
- logger.debug(
252
- "[AdamW] DTensor param: shape=%s, placements=%s, "
253
- "mesh=%s, grad=%s", p.shape, p.placements,
254
- p.device_mesh.mesh_dim_names,
255
- p.grad.shape if p.grad is not None else None)
256
- placement_to_params[tuple(
257
- [p.placements, p.device_mesh])].append(p)
258
- case torch.Tensor():
259
- logger.debug(
260
- "[AdamW] plain param: shape=%s, grad=%s", p.shape,
261
- p.grad.shape if p.grad is not None else None)
262
- placement_to_params[tuple([torch.Tensor,
263
- None])].append(p)
264
-
265
- logger.debug("[AdamW] %d placement groups, %d total params",
266
- len(placement_to_params), len(params))
267
-
268
- _placement_cache[placement_key] = dict(placement_to_params)
269
-
270
- for group_params in _placement_cache[placement_key].values():
271
- step_adamw_params(optimizer_state, group_params, group)
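
The caching strategy above leans on one invariant: the list object stored in `group["params"]` is stable across steps, so `id()` is a usable cache key and only gradients need re-collecting. A stripped-down sketch of the same idea (plain momentum SGD stands in for the fused AdamW call to keep it short):

    import torch

    _buf_cache: dict[int, list[torch.Tensor]] = {}

    @torch.no_grad()
    def cached_step(params: list[torch.Tensor], lr: float = 1e-3) -> None:
        key = id(params)                   # stable: same list object each step
        if key not in _buf_cache:
            # One-time: materialize momentum buffers in lockstep with params.
            _buf_cache[key] = [torch.zeros_like(p) for p in params]
        bufs = _buf_cache[key]
        grads = [p.grad for p in params]   # the only per-step collection
        torch._foreach_mul_(bufs, 0.9)
        torch._foreach_add_(bufs, grads)
        torch._foreach_add_(params, bufs, alpha=-lr)
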
build/torch210-cxx11-cu128-x86_64-linux/async_utils.py DELETED
@@ -1,77 +0,0 @@
1
- import logging
2
- from typing import Generator
3
-
4
- logger = logging.getLogger(__name__)
5
-
6
-
7
- class _Task:
8
- """Internal: wraps a generator, advances one yield at a time."""
9
-
10
- def __init__(self, generator: Generator[None, None, None], index: int):
11
- self._generator = generator
12
- self._index = index
13
- self._steps_completed = 0
14
- self.step() # run to first yield
15
-
16
- def step(self) -> bool:
17
- try:
18
- next(self._generator)
19
- self._steps_completed += 1
20
- logger.debug("pipeline[%d] completed stage %d", self._index,
21
- self._steps_completed)
22
- return True
23
- except StopIteration:
24
- logger.debug("pipeline[%d] finished after %d stages", self._index,
25
- self._steps_completed)
26
- return False
27
-
28
- def close(self):
29
- self._generator.close()
30
-
31
-
32
- def run_pipeline(
33
- pipelines: Generator[Generator[None, None, None], None, None],
34
- max_concurrent: int,
35
- ) -> None:
36
- """Run generator-based pipelines with bounded concurrency.
37
-
38
- Each pipeline is a generator that yields at stage boundaries.
39
- The runtime interleaves pipelines so communication and computation
40
- overlap across chunks.
41
- """
42
- if max_concurrent <= 0:
43
- raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}")
44
-
45
- have_new = True
46
- task_index = 0
47
- previous_tasks: list[_Task] = []
48
-
49
- try:
50
- while have_new or previous_tasks:
51
- running_tasks: list[_Task] = []
52
-
53
- # Admit one new pipeline per iteration (staggered admission).
54
- # Admitting one at a time ensures that while chunk N does NS
55
- # compute on the default stream, chunk N+1's NCCL all-to-all
56
- # runs concurrently on the NCCL stream — creating real
57
- # communication/computation overlap on the GPU.
58
- if have_new and len(previous_tasks) < max_concurrent:
59
- try:
60
- gen = next(pipelines)
61
- task = _Task(gen, task_index)
62
- task_index += 1
63
- running_tasks.append(task)
64
- except StopIteration:
65
- have_new = False
66
-
67
- # Advance every previously-yielded task by one step.
68
- for task in previous_tasks:
69
- if task.step():
70
- running_tasks.append(task)
71
-
72
- previous_tasks = running_tasks
73
- except BaseException:
74
- # Clean up all in-flight generators to release GPU resources.
75
- for task in previous_tasks:
76
- task.close()
77
- raise
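
Usage sketch: three 3-stage pipelines under `max_concurrent=2`. Admission runs each new pipeline to its first yield, so pipeline N+1's first stage overlaps pipeline N's second:

    def stage_gen(name: str):
        print(name, "stage 1")
        yield
        print(name, "stage 2")
        yield
        print(name, "stage 3")

    run_pipeline((stage_gen(f"p{i}") for i in range(3)), max_concurrent=2)
    # Illustrative interleaving: p0 s1 | p1 s1, p0 s2 | p1 s2, p0 s3 | ...
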
build/torch210-cxx11-cu128-x86_64-linux/core.py DELETED
@@ -1,219 +0,0 @@
1
- import logging
2
- import math
3
- from dataclasses import dataclass
4
- from typing import List
5
-
6
- import torch
7
- from torch.distributed import ProcessGroup
8
- from torch.distributed.tensor import DTensor
9
-
10
- # torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
11
- # parameter FQNs. Activation checkpointing similarly inserts
12
- # "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
13
- # expert_keys, QK layer parsing) works regardless of wrapper nesting.
14
- _WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})
15
-
16
- logger = logging.getLogger(__name__)
17
-
18
-
19
- def normalize_fqn(name: str) -> str:
20
- """Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
21
- return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS)
22
-
23
-
24
- @dataclass
25
- class _muon_state:
26
- worker_rank: int
27
- process_group: ProcessGroup
28
- rank_indices: dict[int, tuple] # local_rank -> per-dim indices
29
- rank_numels: dict[int, int] # local_rank -> numel
30
- name: str
31
- qk_clip_state: object | None = None  # QKClipInfo (see qk_clip.py); loose type avoids a circular import
32
-
33
-
34
- def _batch_momentum(
35
- grads: List[torch.Tensor],
36
- momentum_bufs: List[torch.Tensor],
37
- momentum: torch.Tensor,
38
- ) -> None:
39
- """Batched momentum update (no nesterov)."""
40
- torch._foreach_mul_(momentum_bufs, momentum)
41
- torch._foreach_add_(momentum_bufs, grads)
42
-
43
-
44
- def _batch_momentum_nesterov(
45
- grads: List[torch.Tensor],
46
- momentum_bufs: List[torch.Tensor],
47
- momentum: torch.Tensor,
48
- ) -> None:
49
- """Batched momentum update with nesterov correction."""
50
- torch._foreach_mul_(momentum_bufs, momentum)
51
- torch._foreach_add_(momentum_bufs, grads)
52
- nesterov_terms = torch._foreach_mul(momentum_bufs, momentum)
53
- torch._foreach_add_(grads, nesterov_terms)
54
-
55
-
56
- _compiled_momentum: dict[bool, callable] = {}
57
- _use_momentum_compile = True
58
-
59
-
60
- def set_momentum_compile(enabled: bool):
61
- """Toggle torch.compile for batched momentum."""
62
- global _use_momentum_compile
63
- _use_momentum_compile = enabled
64
-
65
-
66
- def batch_pre_ortho(
67
- grads: List[torch.Tensor],
68
- momentum_bufs: List[torch.Tensor],
69
- momentum: torch.Tensor,
70
- nesterov: bool,
71
- ) -> None:
72
- """Batched momentum update on lists of plain tensors.
73
-
74
- Mirrors dion's ``muon_update_pre_orthogonalize``.
75
- Inputs must be plain CUDA tensors (not DTensor).
76
- Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place.
77
-
78
- When compile is enabled, uses separately compiled functions for
79
- nesterov=True/False to avoid graph breaks from the branch.
80
- """
81
- fn = _batch_momentum_nesterov if nesterov else _batch_momentum
82
- if _use_momentum_compile:
83
- if nesterov not in _compiled_momentum:
84
- _compiled_momentum[nesterov] = torch.compile(fn)
85
- fn = _compiled_momentum[nesterov]
86
- fn(grads, momentum_bufs, momentum)
87
-
88
-
89
- def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay):
90
- """Weight-decay + update on plain tensors.
91
-
92
- Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache
93
- lookup per call × 256+ params = massive overhead. The pipeline path uses
94
- batched _foreach_* ops instead; this function remains for base() and
95
- distributed_muon().
96
- """
97
- p_data.mul_(1 - lr * weight_decay)
98
- p_data.add_(u_data, alpha=-adjusted_lr)
99
-
100
-
101
- def update_p(p, u, lr, adjusted_lr, weight_decay):
102
- """Apply weight decay and orthogonalized update to parameter.
103
-
104
- Args:
105
- p: Parameter (torch.nn.Parameter or DTensor).
106
- u: Orthogonalized update tensor.
107
- lr: Base learning rate.
108
- adjusted_lr: Size-adjusted learning rate.
109
- weight_decay: Weight decay coefficient.
110
- """
111
- # Unwrap Parameter -> underlying data tensor.
112
- p_data = p.data if isinstance(p, torch.nn.Parameter) else p
113
- # Unwrap DTensor -> local CUDA tensor for compiled kernel.
114
- if isinstance(p_data, DTensor):
115
- p_data = p_data._local_tensor
116
- u_data = u._local_tensor if isinstance(u, DTensor) else u
117
- _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay)
118
-
119
-
120
- def adjust_lr_for_muon(lr, param_shape):
121
- """Scale learning rate based on parameter matrix dimensions.
122
-
123
- Args:
124
- lr: Base learning rate.
125
- param_shape: Shape of the parameter tensor.
126
-
127
- Returns:
128
- Adjusted learning rate.
129
- """
130
- A, B = param_shape[:2]
131
- # Scale the learning rate by 0.2 * sqrt(max(A, B)), following the
132
- # matrix-size adjustment described in the Muon paper.
133
- adjusted_ratio = 0.2 * math.sqrt(max(A, B))
134
- adjusted_lr = lr * adjusted_ratio
135
- return adjusted_lr
136
-
137
-
138
- def _match_key(parts, key):
139
- """Check if key matches as contiguous components in parts.
140
-
141
- Single-component keys (e.g. "experts") match any single component.
142
- Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence.
143
- """
144
- key_parts = key.split(".")
145
- key_len = len(key_parts)
146
- if key_len == 1:
147
- return key in parts
148
- return any(parts[i:i + key_len] == key_parts
149
- for i in range(len(parts) - key_len + 1))
150
-
151
-
152
- def is_expert_param(name, expert_keys):
153
- """Check if a parameter name matches any expert key (component-level)."""
154
- if not expert_keys:
155
- return False
156
- parts = normalize_fqn(name).split(".")
157
- return any(_match_key(parts, key) for key in expert_keys)
158
-
159
-
160
- def default_is_muon(name, x, expert_keys=None):
161
- normalized = normalize_fqn(name)
162
- parts = normalized.split(".")
163
- skip_keys = [
164
- "embed_tokens",
165
- "lm_head",
166
- "tok_embeddings",
167
- "output",
168
- "mhc_attn",
169
- "mhc_ffn",
170
- "lambda_proj",
171
- ]
172
- if any(key in parts for key in skip_keys):
173
- logger.info(
174
- "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
175
- normalized, name, x.ndim)
176
- return False
177
- effective_ndim = x.ndim
178
- is_expert = is_expert_param(name, expert_keys)
179
- if is_expert:
180
- effective_ndim -= 1
181
- result = effective_ndim >= 2
182
- logger.info(
183
- "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s",
184
- normalized, name, x.ndim, is_expert, effective_ndim,
185
- "Muon" if result else "AdamW")
186
- return result
187
-
188
-
189
- def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
190
- if is_muon_func is None:
191
- is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
192
-
193
- muon_params, muon_names = [], []
194
- non_muon_params, non_muon_names = [], []
195
-
196
- for n, p in model.named_parameters():
197
- if not p.requires_grad:
198
- continue
199
- if is_muon_func(n, p):
200
- muon_params.append(p)
201
- muon_names.append(n)
202
- else:
203
- non_muon_params.append(p)
204
- non_muon_names.append(n)
205
-
206
- logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d",
207
- expert_keys, len(muon_names), len(non_muon_names))
208
-
209
- return [
210
- {
211
- "params": muon_params,
212
- "names": muon_names,
213
- "use_muon": True,
214
- },
215
- {
216
- "params": non_muon_params,
217
- "use_muon": False,
218
- },
219
- ]
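
Two quick checks against the helpers above (assuming the module is importable). `_match_key` requires multi-component keys to appear as a contiguous run, and `adjust_lr_for_muon` for a 4096×1024 matrix gives 0.2·√4096 = 12.8× the base lr:

    import math

    parts = "model.layers.3.experts.w1.weight".split(".")
    assert _match_key(parts, "experts")             # single component: membership
    assert _match_key(parts, "experts.w1")          # contiguous run: matches
    assert not _match_key(parts, "experts.weight")  # non-contiguous: rejected

    assert math.isclose(adjust_lr_for_muon(3e-4, (4096, 1024)), 3e-4 * 12.8)
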
build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py DELETED
@@ -1,206 +0,0 @@
1
- """CPU offloading for optimizer states.
2
-
3
- Manages a pinned CPU memory pool and async CUDA streams to offload
4
- optimizer state tensors (momentum buffers, Adam moments) to CPU between
5
- optimizer steps, freeing GPU memory.
6
-
7
- All tracked tensors are packed into a single flat pinned CPU buffer
8
- (per dtype). D2H and H2D copies are performed per-tensor directly
9
- between individual GPU tensors and their slice of the CPU flat buffer
10
- — no GPU staging buffer is allocated, so there is **no temporary GPU
11
- memory spike** during offload or reload.
12
-
13
- Individual tensor storages are freed after offload via
14
- ``untyped_storage().resize_(0)``, preserving tensor identity so
15
- downstream caches remain valid.
16
- """
17
-
18
- import logging
19
- from collections import defaultdict
20
-
21
- import torch
22
- from torch.distributed.tensor import DTensor
23
-
24
- logger = logging.getLogger(__name__)
25
-
26
-
27
- class CPUOffloadPool:
28
- """Pinned CPU memory pool for async optimizer state offloading.
29
-
30
- Tracked tensors are grouped by dtype. Each group gets a single flat
31
- pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of
32
- the flat buffer) to avoid allocating a GPU staging buffer.
33
- """
34
-
35
- def __init__(self):
36
- self._managed: list[torch.Tensor] = []
37
- self._storage_nbytes: dict[int, int] = {} # id(t) → bytes
38
-
39
- # Per-dtype group: populated on first offload.
40
- # dtype → dict with keys:
41
- # "indices" : list[int] managed-list indices
42
- # "offsets" : list[tuple[int,int]] (start, numel) in flat buf
43
- # "total" : int total numel
44
- # "cpu_flat" : Tensor pinned CPU buffer
45
- self._groups: dict[torch.dtype, dict] = {}
46
-
47
- self._offload_stream: torch.cuda.Stream | None = None
48
- self._device: torch.device | None = None
49
- self._initialized: bool = False
50
- self._logged: bool = False
51
-
52
- # ------------------------------------------------------------------
53
- @staticmethod
54
- def _local(t: torch.Tensor) -> torch.Tensor:
55
- """Unwrap DTensor to its local CUDA tensor."""
56
- return t._local_tensor if isinstance(t, DTensor) else t
57
-
58
- def _ensure_stream(self):
59
- if self._offload_stream is None:
60
- self._offload_stream = torch.cuda.Stream(device=self._device)
61
-
62
- # ------------------------------------------------------------------
63
- def track(self, tensor: torch.Tensor):
64
- """Register a GPU tensor for CPU offloading. Idempotent."""
65
- tid = id(tensor)
66
- if tid in self._storage_nbytes:
67
- return
68
- local = self._local(tensor)
69
- if self._device is None:
70
- self._device = local.device
71
- storage = local.untyped_storage()
72
- # Skip tensors with empty storage (e.g. empty FSDP shards)
73
- if storage.size() == 0:
74
- return
75
- self._storage_nbytes[tid] = storage.size()
76
- self._managed.append(tensor)
77
-
78
- # ------------------------------------------------------------------
79
- def _init_buffers(self):
80
- """Build per-dtype flat buffers on first offload."""
81
- # Group managed tensors by dtype.
82
- dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list)
83
- for idx, t in enumerate(self._managed):
84
- local = self._local(t)
85
- dtype_map[local.dtype].append((idx, local.numel()))
86
-
87
- total_cpu_bytes = 0
88
- for dtype, entries in dtype_map.items():
89
- offsets: list[tuple[int, int]] = []
90
- indices: list[int] = []
91
- off = 0
92
- for idx, n in entries:
93
- indices.append(idx)
94
- offsets.append((off, n))
95
- off += n
96
- cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
97
- self._groups[dtype] = {
98
- "indices": indices,
99
- "offsets": offsets,
100
- "total": off,
101
- "cpu_flat": cpu_flat,
102
- }
103
- total_cpu_bytes += off * cpu_flat.element_size()
104
-
105
- self._initialized = True
106
- logger.info(
107
- "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), "
108
- "%.2f MB pinned CPU memory",
109
- len(self._managed),
110
- len(self._groups),
111
- total_cpu_bytes / (1024**2),
112
- )
113
-
114
- # ------------------------------------------------------------------
115
- def offload(self):
116
- """Per-tensor async D2H into CPU flat buffer, then free GPU storage."""
117
- if not self._managed:
118
- return
119
- if not self._initialized:
120
- self._init_buffers()
121
- self._ensure_stream()
122
-
123
- # Offload stream waits for compute to finish.
124
- compute_event = torch.cuda.current_stream(self._device).record_event()
125
- self._offload_stream.wait_event(compute_event)
126
-
127
- offloaded_bytes = 0
128
-
129
- # Per-tensor D2H copies directly into CPU flat buffer slices.
130
- # No GPU staging buffer → no temporary GPU memory spike.
131
- with torch.cuda.stream(self._offload_stream):
132
- for dtype, grp in self._groups.items():
133
- indices = grp["indices"]
134
- offsets = grp["offsets"]
135
- cpu_flat = grp["cpu_flat"]
136
-
137
- for i, mgd_idx in enumerate(indices):
138
- local = self._local(self._managed[mgd_idx])
139
- off, n = offsets[i]
140
- cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
141
-
142
- offloaded_bytes += grp["total"] * cpu_flat.element_size()
143
-
144
- # Wait for all D2H copies to land, then free GPU storage.
145
- self._offload_stream.synchronize()
146
- for t in self._managed:
147
- storage = self._local(t).untyped_storage()
148
- if storage.size() != 0:
149
- storage.resize_(0)
150
- else:
151
- raise RuntimeError(
152
- f"Tensor storage is already freed (size=0) before offload. "
153
- f"This indicates a double-free or external interference. "
154
- f"Tensor shape: {t.shape}, dtype: {t.dtype}"
155
- )
156
-
157
- if not self._logged:
158
- logger.info(
159
- "[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
160
- offloaded_bytes / (1024**2),
161
- )
162
-
163
- # ------------------------------------------------------------------
164
- def reload(self):
165
- """Per-tensor H2D from CPU flat buffer on the default stream.
166
-
167
- Runs on the current (default) CUDA stream to avoid stream
168
- interaction issues with the parallel Muon pipeline. Since
169
- pinned CPU memory is the source, the copies overlap with
170
- GPU idle time between steps.
171
- """
172
- if not self._managed or not self._initialized:
173
- return
174
-
175
- reloaded_bytes = 0
176
-
177
- # Re-allocate all GPU storages first.
178
- for t in self._managed:
179
- local = self._local(t)
180
- storage = local.untyped_storage()
181
- if storage.size() != 0:
182
- raise RuntimeError(
183
- f"Storage should have been freed (size=0) before reload, "
184
- f"but got size={storage.size()}. "
185
- f"Tensor shape: {t.shape}, dtype: {t.dtype}"
186
- )
187
- storage.resize_(self._storage_nbytes[id(t)])
188
-
189
- # Per-tensor H2D copies from CPU flat buffer slices.
190
- # non_blocking=True with pinned source allows DMA overlap.
191
- for dtype, grp in self._groups.items():
192
- indices = grp["indices"]
193
- offsets = grp["offsets"]
194
- cpu_flat = grp["cpu_flat"]
195
-
196
- for i, mgd_idx in enumerate(indices):
197
- local = self._local(self._managed[mgd_idx])
198
- off, n = offsets[i]
199
- local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
200
-
201
- reloaded_bytes += grp["total"] * cpu_flat.element_size()
202
-
203
- if not self._logged:
204
- logger.info(
205
- "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
206
- )
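
The identity-preserving free/reload that the pool relies on, in isolation (a hedged sketch; the real code also coordinates streams and events as above):

    import torch

    t = torch.randn(1024, 1024, device="cuda")
    nbytes = t.untyped_storage().size()

    # D2H into pinned memory, then free the GPU storage without dropping `t`.
    cpu_flat = torch.empty(t.numel(), dtype=t.dtype, pin_memory=True)
    cpu_flat.copy_(t.reshape(-1), non_blocking=True)
    torch.cuda.synchronize()
    t.untyped_storage().resize_(0)        # GPU bytes released; `t` still exists

    # Later: re-allocate and reload. Caches keyed on `t` never went stale.
    t.untyped_storage().resize_(nbytes)
    t.reshape(-1).copy_(cpu_flat, non_blocking=True)
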
build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py DELETED
@@ -1,232 +0,0 @@
- import torch
- import torch.distributed as dist
- from torch.distributed import ProcessGroup
- from torch.distributed.device_mesh import DeviceMesh
- from torch.distributed.tensor import DTensor
- from torch.distributed.tensor.placement_types import (Placement, Shard,
-                                                       _StridedShard)
-
-
- def _is_shard(placement: Placement) -> bool:
-     """Check if a placement is a shard type (Shard or _StridedShard).
-
-     In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so
-     ``placement.is_shard()`` returns False for _StridedShard. This helper
-     handles both old and new hierarchies.
-     """
-     return isinstance(placement, (Shard, _StridedShard))
-
-
- def get_slices_of_dtensor(
-     target: DTensor | torch.Tensor,
-     local_rank: int,
-     shard_mesh: DeviceMesh,
-     shard_placements: tuple[Placement],
- ) -> tuple[slice | torch.Tensor, ...]:
-     """
-     Get per-dimension indices for a given rank's shard of the target tensor.
-
-     Uses ``Shard.local_shard_size_and_offset`` and
-     ``_StridedShard.local_shard_size_and_offset`` for correct handling of
-     both contiguous and strided (non-contiguous) sharding.
-
-     Args:
-         target (DTensor | torch.Tensor): The target tensor (for its shape).
-         local_rank (int): The local rank within the shard group.
-         shard_mesh (DeviceMesh): The shard mesh (only shard dimensions).
-         shard_placements (tuple[Placement]): The shard placements.
-
-     Returns:
-         A tuple of indices (one per tensor dim). Each element is either:
-         - A ``slice`` (for contiguous or unsharded dims)
-         - A 1-D ``torch.LongTensor`` of indices (for strided sharding)
-     """
-
-     # find the global rank of the local rank in the shard mesh
-     rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
-
-     rank_coords = (shard_mesh.mesh == rank).nonzero()
-
-     assert len(rank_coords) == 1
-     rank_coords = tuple(rank_coords[0].tolist())
-
-     assert len(rank_coords) == len(shard_placements)
-
-     # Track per-shard-dim indices.
-     # None means "not yet sharded on this dim".
-     dim_indices: dict[int, torch.Tensor] = {}
-
-     # Caution: Assuming replicate-to-shard of the shard mesh goes with
-     # left-to-right sharding. This is ensured by the sorting logic of
-     # the construct_shard_mesh function.
-     for mesh_dim_idx, (rank_coord, placement) in enumerate(
-             zip(rank_coords, shard_placements)):
-         assert _is_shard(placement)
-
-         num_chunks = shard_mesh.mesh.shape[mesh_dim_idx]
-         shard_dim = placement.dim
-
-         # Current effective size on this dim (may already be sub-sharded)
-         if shard_dim in dim_indices:
-             curr_size = len(dim_indices[shard_dim])
-         else:
-             curr_size = target.size()[shard_dim]
-
-         # Compute indices for this level of sharding
-         if isinstance(placement, _StridedShard):
-             _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
-                 placement,
-                 curr_size,
-                 num_chunks,
-                 rank_coord,
-                 return_first_offset=False)
-             new_indices = torch.tensor(offsets, dtype=torch.long)
-         else:
-             shard_size, offset = Shard.local_shard_size_and_offset(
-                 curr_size, num_chunks, rank_coord)
-             new_indices = torch.arange(offset,
-                                        offset + shard_size,
-                                        dtype=torch.long)
-
-         # Compose with previous indices on this dim
-         if shard_dim in dim_indices:
-             dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]
-         else:
-             dim_indices[shard_dim] = new_indices
-
-     # Build result tuple
-     result: list[slice | torch.Tensor] = []
-     for d in range(len(target.size())):
-         if d not in dim_indices:
-             result.append(slice(None))
-         else:
-             indices = dim_indices[d]
-             # Convert contiguous indices to slice for efficiency
-             if len(indices) > 0:
-                 start = indices[0].item()
-                 expected = torch.arange(start,
-                                         start + len(indices),
-                                         dtype=torch.long)
-                 if torch.equal(indices, expected):
-                     result.append(slice(start, start + len(indices)))
-                 else:
-                     result.append(indices)
-             else:
-                 result.append(slice(0, 0))
-
-     return tuple(result)
-
-
- _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
-                                                   ProcessGroup]] = dict()
-
-
- def construct_shard_mesh(
-     placements: tuple[Placement],
-     mesh: DeviceMesh,
- ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
-     """Construct shard sub-mesh and ProcessGroup for all-to-all communication.
-
-     Given a DTensor's placements and device mesh, extracts the "shard group"
-     — the set of ranks that together hold all shards of the same replica —
-     and creates a ProcessGroup for all-to-all among them.
-
-     Steps:
-       1. Sort placements: Replicate first, then Shard by (dim, granularity).
-       2. Permute the mesh tensor to match the sorted order.
-       3. Collapse Replicate dims → list of shard sub-meshes (one per replica).
-       4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh.
-
-     Example — 8 GPUs, mesh shape (2, 2, 2),
-     placements ``[Shard(0), Replicate, _StridedShard(0)]``::
-
-         Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)]
-                  Permutation: [1, 2, 0]
-
-         Step 2 — Permute mesh dims by [1, 2, 0]:
-             Original:          Permuted:
-             [[[0,1],[2,3]],    [[[0,2],[1,3]],
-              [[4,5],[6,7]]]     [[4,6],[5,7]]]
-
-         Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
-             sub-mesh 0 = [[0,2],[1,3]]   (replica group 0)
-             sub-mesh 1 = [[4,6],[5,7]]   (replica group 1)
-             shard_placements = (_StridedShard(0), Shard(0))
-
-         Step 4 — Rank 0 → ProcessGroup([0,1,4,5])
-                  Rank 2 → ProcessGroup([2,3,6,7])
-
-     Returns:
-         ``(shard_mesh, process_group, shard_placements)``
-     """
-     my_rank = dist.get_rank()
-     assert mesh.mesh.device.type == 'cpu'
-
-     # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
-     # Reuses the mesh's existing ProcessGroup directly, avoiding the
-     # overhead of dist.new_group(). The standard path below also handles
-     # subset calls safely via use_local_synchronization=True, but this
-     # fast path is still beneficial for the common 1D shard case.
-     if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
-         key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
-         if key not in _ranks_to_dist_cache:
-             _ranks_to_dist_cache[key] = (mesh, mesh.get_group())
-         return (*_ranks_to_dist_cache[key], tuple(placements))
-
-     mesh_tensor = mesh.mesh.clone()
-
-     # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------
-     # _StridedShard comes BEFORE regular Shard on the same dim so that
-     # get_slices_of_dtensor applies the outer sharding first, matching
-     # DTensor's left-to-right (outer-to-inner) composition order.
-     def _sort_key(item):
-         index, placement = item
-         assert not placement.is_partial(), "Partial placement not supported"
-         if placement.is_replicate():
-             return (-1, 0, index)
-         assert _is_shard(placement), f"Unsupported: {type(placement)}"
-         split = (-1 / placement.split_factor if isinstance(
-             placement, _StridedShard) else 0)
-         return (placement.dim, split, index)
-
-     indexed = sorted(enumerate(placements), key=_sort_key)
-     perm, sorted_placements = zip(*indexed)
-
-     # -- Step 2: Permute mesh to match sorted placement order. --------------
-     sorted_mesh = mesh_tensor.permute(perm)
-
-     # -- Step 3: Collapse replicate dims → list of shard sub-meshes. --------
-     # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4)
-     num_rep = sum(1 for p in sorted_placements if p.is_replicate())
-     if num_rep > 0:
-         if num_rep > 1:
-             sorted_mesh = sorted_mesh.flatten(0, num_rep - 1)
-         shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
-     else:
-         shard_meshes = [sorted_mesh]
-     shard_placements = sorted_placements[num_rep:]
-     assert len(shard_placements) == len(set(shard_placements))
-
-     # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
-     # Each rank only creates the group it belongs to, using
-     # use_local_synchronization=True so that only group members need to
-     # coordinate. This avoids deadlocks when different PP stages call
-     # construct_shard_mesh for different parameters.
-     def _cache_key(t: torch.Tensor) -> tuple:
-         return (*t.shape, *t.flatten().tolist())
-
-     my_key = None
-     for sm in shard_meshes:
-         if (my_rank == sm).any().item():
-             key = _cache_key(sm)
-             assert my_key is None, "Rank appears in multiple shard groups"
-             my_key = key
-             if key not in _ranks_to_dist_cache:
-                 pg = dist.new_group(sm.flatten().tolist(),
-                                     use_local_synchronization=True)
-                 _ranks_to_dist_cache[key] = (
-                     DeviceMesh(device_type="cuda", mesh=sm),
-                     pg,
-                 )
-
-     return (*_ranks_to_dist_cache[my_key], shard_placements)
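To make the contract between the two helpers concrete, the sketch below checks that each rank's computed slice of the full tensor equals its local shard. It is a minimal sketch assuming a 1-D mesh of 4 GPUs launched under `torchrun`, and that this module is importable as `distributed.utils` (the setup code is illustrative, not part of this module)::

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor

    from distributed.utils import construct_shard_mesh, get_slices_of_dtensor

    torch.manual_seed(0)  # same global tensor on every rank
    mesh = init_device_mesh("cuda", (4,))
    weight = distribute_tensor(torch.randn(8, 6), mesh, [Shard(0)])

    # Fast path: a 1-D all-shard mesh reuses the mesh's existing ProcessGroup.
    shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
        weight.placements, weight.device_mesh)

    rank = dist.get_rank(shard_pg)
    slices = get_slices_of_dtensor(weight, rank, shard_mesh, shard_placements)
    # Each rank owns rows [2*rank, 2*rank + 2) of the 8-row tensor.
    assert torch.equal(weight.full_tensor()[slices], weight.to_local())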
build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py DELETED
@@ -1,122 +0,0 @@
- # MIT License
- #
- # Copyright (c) 2025 Tianyang Lin
- #
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
-
- import torch
- import triton
- import triton.language as tl
-
-
- def get_autotune_config():
-     return [
-         triton.Config(
-             {
-                 'BLOCK_SIZE_M': blk_m,
-                 'BLOCK_SIZE_K': blk_k,
-                 'GROUP_SIZE_M': grp_sz
-             },
-             num_stages=n_stages,
-             num_warps=n_warps) for blk_m in [32, 64, 128]
-         for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
-         for n_warps in [4, 8]
-     ]
-
-
- @triton.autotune(
-     configs=get_autotune_config(),
-     key=['M', 'K'],
-     restore_value=['y'],
- )
- @triton.jit
- def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
-                BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-                GROUP_SIZE_M: tl.constexpr):
-     """
-     Core kernel jit function of matmul_transpose that computes y = x @ x.T
-     The code is a simple adaptation from the triton `matmul` tutorial:
-     https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
-     """
-     pid = tl.program_id(axis=0)
-     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-     num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
-     num_pid_in_group = GROUP_SIZE_M * num_pid_n
-     group_id = pid // num_pid_in_group
-     first_pid_m = group_id * GROUP_SIZE_M
-     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-     pid_n = (pid % num_pid_in_group) // group_size_m
-     if pid_m > pid_n:
-         return
-
-     offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-     offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-     offs_k = tl.arange(0, BLOCK_SIZE_K)
-     # we use a & b ptrs to denote different rows of x.
-     a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
-     b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
-
-     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
-
-     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-         a = tl.load(a_ptrs,
-                     mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
-                     other=0.0)
-         b = tl.load(b_ptrs,
-                     mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
-                     other=0.0)
-         accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
-         a_ptrs += BLOCK_SIZE_K * stride_xk
-         b_ptrs += BLOCK_SIZE_K * stride_xk
-     # use dtype.element_ty to accommodate different input datatypes as in cpp templates
-     # https://github.com/triton-lang/triton/issues/2252
-     c = accumulator.to(x.dtype.element_ty)
-
-     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-     offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-     c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
-     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
-     tl.store(c_ptrs, c, mask=c_mask)
-
-     # transpose and copy
-     if pid_m < pid_n:
-         ct_ptrs = y + stride_ym * offs_cn[:, None] + stride_yn * offs_cm[None, :]
-         ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
-         tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
-
-
- @torch.library.custom_op("muon::matmul_transpose_assign",
-                          mutates_args=("d_out", ))
- def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
-     """Compute d_out = d_in @ d_in.T using an optimized Triton kernel."""
-     d_in = d_in.contiguous()
-     M, K = d_in.shape
-     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
-         M, META['BLOCK_SIZE_M']), )
-     with torch.cuda.device(d_in.device.index):
-         mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
-                          d_out.stride(0), d_out.stride(1))
-
-
- @matmul_transpose_assign.register_fake
- def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
-     """FakeTensor impl: d_out is already allocated, mutation is declared."""
-     pass
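Because the custom op mutates a preallocated output rather than returning one, callers allocate `d_out` themselves: shape `(M, M)` for an `(M, K)` input. Only block pairs with `pid_m <= pid_n` are computed and the transpose is mirrored into the lower triangle, which is what makes this cheaper than a general matmul. A minimal calling sketch, assuming the module has been imported so the op is registered (shapes, dtype, and tolerances are illustrative)::

    import torch

    from matmul_transpose_triton import matmul_transpose_assign  # import path is an assumption

    x = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
    out = torch.empty(1024, 1024, device="cuda", dtype=torch.bfloat16)

    matmul_transpose_assign(x, out)  # out <- x @ x.T (symmetric result)
    torch.testing.assert_close(out, x @ x.T, rtol=1.6e-2, atol=1e-2)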
build/torch210-cxx11-cu128-x86_64-linux/metadata.json DELETED
@@ -1,3 +0,0 @@
- {
-   "python-depends": []
- }
build/torch210-cxx11-cu128-x86_64-linux/muon.py DELETED
@@ -1,1068 +0,0 @@
- import logging
- import types
- from collections import defaultdict
- from typing import Any
-
- import torch
- import torch.distributed as dist
- from torch.distributed.tensor import DTensor, Replicate, Shard
- from torch.profiler import record_function
-
- from .adamw import _placement_cache, _tensor_cache, step_adamw
- from .async_utils import run_pipeline
- from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
-                    get_default_muon_param_groups, is_expert_param, update_p)
- from .cpu_offload import CPUOffloadPool
- from .distributed.utils import (_is_shard, construct_shard_mesh,
-                                 get_slices_of_dtensor)
- from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
-                             _zeropower_via_newtonschulz5,
-                             zeropower_via_newtonschulz5,
-                             zeropower_via_newtonschulz5_batched)
- from .pipeline import muon_chunk_pipeline, prelaunch_first_gather
- from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
-
- logger = logging.getLogger(__name__)
-
-
- def _expand_expert_params(names, params, expert_keys):
-     """Expand expert params by splitting on dim 0 (expert dimension).
-
-     Params whose name matches any key in ``expert_keys`` are treated as
-     expert-parallel tensors. Their outermost dimension is the expert
-     dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D
-     ``nn.Parameter`` views so that in-place updates propagate back to
-     the original storage.
-
-     Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` —
-     if they are expert params, their key must be added to ``expert_keys``.
-
-     The grad must already be set on each expert param (e.g. after momentum).
-
-     For DTensor expert params, placements that shard on dim 0 (expert dim)
-     are consumed by the split. Non-dim-0 shard placements (e.g. TP) are
-     preserved: each 2D slice is wrapped as a DTensor on the corresponding
-     submesh so the parallel pipeline handles the TP communication.
-     """
-     expanded_names = []
-     expanded_params = []
-
-     for n, p in zip(names, params):
-         is_expert = is_expert_param(n, expert_keys)
-         is_dtensor = isinstance(p.data, DTensor)
-
-         if is_expert:
-             if is_dtensor:
-                 logger.debug(
-                     "[expand_expert] %s: expert DTensor, shape=%s, "
-                     "placements=%s, mesh=%s, local_shape=%s", n, p.shape,
-                     p.placements, p.device_mesh.mesh_dim_names,
-                     p.to_local().shape)
-             else:
-                 logger.debug(
-                     "[expand_expert] %s: expert plain tensor, shape=%s", n,
-                     p.data.shape)
-
-         if not is_expert:
-             assert p.data.ndim <= 2, (
-                 f"Param {n} has ndim={p.data.ndim} but does not match "
-                 f"expert_keys={expert_keys}. If this is an expert param, "
-                 f"add its key to expert_keys.")
-             expanded_names.append(n)
-             expanded_params.append(p)
-             continue
-
-         g = p.grad
-         assert g is not None, (
-             f"Expert param {n} must have grad set before expansion")
-
-         tp_mesh = None
-         tp_placements_2d = None
-
-         if is_dtensor:
-             local_data = p.to_local()
-             local_grad = g.to_local() if isinstance(g, DTensor) else g
-
-             # Find non-dim-0 shard placements (e.g. TP sharding).
-             # After splitting on dim 0, Shard(k) becomes Shard(k-1).
-             tp_dim_indices = []
-             tp_placements_2d = []
-             for i, pl in enumerate(p.placements):
-                 if _is_shard(pl) and pl.dim != 0:
-                     tp_dim_indices.append(i)
-                     tp_placements_2d.append(Shard(pl.dim - 1))
-
-             if tp_dim_indices:
-                 tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i]
-                                      for i in tp_dim_indices)
-                 if len(tp_dim_names) == 1:
-                     tp_mesh = p.device_mesh[tp_dim_names[0]]
-                 else:
-                     tp_mesh = p.device_mesh[tp_dim_names]
-         else:
-             local_data = p.data
-             local_grad = g
-
-         # Expand: split dim 0, reshape each slice to 2D.
-         num_local_experts = local_data.shape[0]
-         for i in range(num_local_experts):
-             slice_data = local_data[i]
-             slice_grad = local_grad[i]
-
-             if tp_mesh is not None:
-                 # Wrap as DTensor on TP submesh so the pipeline handles
-                 # TP communication (gather/scatter across TP ranks).
-                 dt_data = DTensor.from_local(slice_data,
-                                              device_mesh=tp_mesh,
-                                              placements=tp_placements_2d)
-                 dt_grad = DTensor.from_local(slice_grad,
-                                              device_mesh=tp_mesh,
-                                              placements=tp_placements_2d)
-                 expert_param = torch.nn.Parameter(dt_data, requires_grad=False)
-                 expert_param.grad = dt_grad
-             else:
-                 expert_param = torch.nn.Parameter(slice_data,
-                                                   requires_grad=False)
-                 expert_param.grad = slice_grad
-
-             expanded_names.append(f"{n}[{i}]")
-             expanded_params.append(expert_param)
-
-         p.grad = None  # allow expert grad storage to be freed after pipeline
-
-     return expanded_names, expanded_params
-
-
- class Muon(torch.optim.Optimizer):
-     """
-     Muon - MomentUm Orthogonalized by Newton-schulz
-
-     Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
-     processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
-     matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
-     the advantage that it can be stably run in bfloat16 on the GPU.
-
-     Some warnings:
-     - We believe this optimizer is unlikely to work well for training with small batch size.
-     - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
-
-     Arguments:
-         params: Parameter groups to optimize. Each group must set the
-             "use_muon" key (see ``get_default_muon_param_groups``).
-         lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
-         momentum: The momentum used by the internal SGD. (0.95 is a good default)
-         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
-         ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
-         weight_decay: The weight decay for Muon and AdamW.
-             Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead.
-         adamw_betas: The betas for the internal AdamW.
-         adamw_eps: The epsilon for the internal AdamW.
-         none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
-         debug: Whether to print debug information.
-         clip_config: Configuration for QK clipping. Expected keys:
-             - "q_indices" (list[int]): Indices of query heads to consider.
-             - "k_indices" (list[int]): Indices of key heads to consider.
-             - "head_dim" (int): Dimensionality of each attention head.
-             - "threshold" (float): Threshold value; heads whose QK logits exceed
-               this value will be scaled down.
-             Default is:
-                 {
-                     "q_indices": [],
-                     "k_indices": [],
-                     "head_dim": 128,
-                     "threshold": 100
-                 }
-         warmup_step: How many all2all gather/compute operations are launched in
-             advance before the corresponding all2all scatter steps begin.
-             A higher warmup_step increases memory usage but can improve
-             performance by overlapping communication.
-             Parallel Muon only.
-         chunk_size: Batch size of parameters to process in each
-             all2all gather/compute/scatter step.
-             Uses shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
-         use_distributed_muon: Use distributed Muon by Liu et al. (2024).
-             For testing purposes only.
-         expert_keys: List of strings to identify expert-parallel parameters.
-             If any key appears in a parameter's name, its outermost
-             dimension is treated as the expert dimension and expanded
-             into per-expert 2D params for Muon. For example,
-             ``expert_keys=["experts"]`` matches any param whose name
-             contains "experts". 3D+ params not matched by any key
-             will raise an error.
-     """
-
-     def __init__(self,
-                  params,
-                  lr=1e-3,
-                  momentum=0.95,
-                  nesterov=True,
-                  ns_steps=5,
-                  weight_decay=0.1,
-                  adamw_betas=(0.9, 0.95),
-                  adamw_eps=1e-8,
-                  none_grad=True,
-                  debug=False,
-                  clip_config=None,
-                  warmup_step=5,
-                  chunk_size=-1,
-                  use_distributed_muon=False,
-                  expert_keys=None):
-         defaults = dict(
-             lr=lr,
-             weight_decay=weight_decay,
-             momentum=momentum,
-             nesterov=nesterov,
-             ns_steps=ns_steps,
-             adamw_betas=adamw_betas,
-             adamw_eps=adamw_eps,
-             none_grad=none_grad,
-             use_muon=True,
-         )
-         error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
-         instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
-
-         if isinstance(params, types.GeneratorType):
-             raise ValueError(error_message.format(idx=0) + instruction_code)
-         for _idx, param_group in enumerate(params):
-             if param_group.get("use_muon", None) is None:
-                 raise ValueError(
-                     error_message.format(idx=_idx) + instruction_code)
-         super().__init__(params, defaults)
-
-         self.debug = debug
-         self.clip_config = clip_config if clip_config is not None else {
-             "q_indices": [],
-             "k_indices": [],
-             "head_dim": 128,
-             "threshold": 100,
-         }
-         self.warmup_step = warmup_step
-         self.chunk_size = chunk_size
-         self.use_distributed_muon = use_distributed_muon
-         self.expert_keys = expert_keys
-         self.cpu_offload = False
-         self._cpu_offload_pool: CPUOffloadPool | None = None
-         self._offload_initialized = False
-         self._parallel_cache: dict[tuple[str, ...], dict] = {}
-         self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
-
-     def _calc_flops(self, G, steps):
-         assert len(G.shape) == 2
-         M, N = G.shape
-         if M > N:
-             M, N = N, M
-
-         return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
-
-     def get_shard_mesh(self, p):
-         """
-         Get the shard mesh for a parameter p on the given rank.
-         """
-         assert isinstance(
-             p, DTensor), "Parallel Muon only supports DTensor parameters."
-
-         shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
-             p.placements, p.device_mesh)
-
-         return shard_mesh, shard_pg, shard_placements
-
-     def init_state_and_assign_params(self, names, params, group, qk_logits):
-         param_to_state = {}
-         param_to_flops = {}
-
-         total_flops = 0
-         for p in params:
-             g = p.grad
-             if g is None:
-                 continue
-             assert g.ndim == 2, "Muon only supports 2D parameters."
-
-             flops = self._calc_flops(g, group["ns_steps"])
-             param_to_flops[id(p)] = flops
-             total_flops += flops
-
-         if self.debug:
-             logger.debug("Total TFLOPs for Muon: %.2f TFLOPs",
-                          total_flops / 1e12)
-
-         paired = list(zip(names, params))
-
-         paired_sorted = sorted(paired,
-                                key=lambda x: param_to_flops[id(x[1])],
-                                reverse=True)
-
-         names_sorted, params_sorted = zip(*paired_sorted)
-         ordered_names = list(names_sorted)
-         ordered_params = list(params_sorted)
-
-         round_robin = 0
-         mesh = ordered_params[0].device_mesh
-         placements = ordered_params[0].placements
-
-         shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
-             ordered_params[0])
-         shard_mesh_flattened = shard_mesh.mesh.flatten()
-         num_ranks = dist.get_world_size(group=shard_pg)
-
-         for n, p in zip(ordered_names, ordered_params):
-             if mesh != p.device_mesh:
-                 raise ValueError("All parameters must be on the same mesh.")
-             if placements != p.placements:
-                 raise ValueError("All parameters must have same placements.")
-
-             worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
-             round_robin = (round_robin + 1) % len(shard_mesh_flattened)
-             qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
-
-             # Precompute per-rank indices and numels for all-to-all.
-             rank_indices: dict[int, tuple] = {}
-             rank_numels: dict[int, int] = {}
-             for r in range(num_ranks):
-                 indices = get_slices_of_dtensor(p, r, shard_mesh,
-                                                 shard_placements)
-                 rank_indices[r] = indices
-                 numel = 1
-                 for idx, dim_size in zip(indices, p.shape):
-                     if isinstance(idx, slice):
-                         start, stop, step = idx.indices(dim_size)
-                         numel *= max(0, (stop - start + (step - 1)) // step)
-                     else:
-                         numel *= len(idx)
-                 rank_numels[r] = numel
-
-             param_to_state[id(p)] = _muon_state(
-                 worker_rank=worker_rank,
-                 process_group=shard_pg,
-                 rank_indices=rank_indices,
-                 rank_numels=rank_numels,
-                 name=n,
-                 qk_clip_state=qk_clip_state,
-             )
-
-         return param_to_state, ordered_params
-
-     def base(self, names, params, group, lr, weight_decay, qk_logits):
-         # Momentum is already applied by _step_muon before this method.
-         for n, p in zip(names, params):
-             g = p.grad
-             if g is None:
-                 continue
-
-             u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
-                                             steps=group["ns_steps"])
-
-             adjusted_lr = adjust_lr_for_muon(lr, p.shape)
-             update_p(p, u, lr, adjusted_lr, weight_decay)
-
-             qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
-
-             scales_full = compute_scales(
-                 p, qk_clip_state) if qk_clip_state is not None else None
-             if scales_full is not None:
-                 qk_clip(p, scales_full, qk_clip_state)
-
-     def distributed_muon(
-         self,
-         names: list[str],
-         params: list[torch.nn.Parameter],
-         group: dict[str, Any],
-         lr: float,
-         weight_decay: float,
-         qk_logits: list[torch.Tensor | DTensor] | None,
-     ):
-         """Batched Distributed Muon — for testing/correctness verification only.
-
-         Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
-         the full grad, then slices back to local shards. This is simpler but
-         slower than the parallel pipeline (all2all) path, so it serves as a
-         reference implementation for verifying correctness.
-         """
-         with record_function("distributed_muon"):
-             # Momentum is already applied by _step_muon before this method.
-             ns_steps = group["ns_steps"]
-
-             # Separate plain tensors (no communication) from DTensors.
-             plain_names, plain_params = [], []
-             dtensor_names, dtensor_params = [], []
-             for n, p in zip(names, params):
-                 if p.grad is None:
-                     continue
-                 if isinstance(p.data, DTensor):
-                     dtensor_names.append(n)
-                     dtensor_params.append(p)
-                 else:
-                     plain_names.append(n)
-                     plain_params.append(p)
-
-             # Process plain tensors per-param (no communication).
-             for n, p in zip(plain_names, plain_params):
-                 u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
-                                                  steps=ns_steps)
-                 adjusted_lr = adjust_lr_for_muon(lr, p.shape)
-                 update_p(p, u, lr, adjusted_lr, weight_decay)
-
-                 qk_clip_state = get_qk_clip_info(self.clip_config, n,
-                                                  qk_logits)
-                 scales_full = compute_scales(
-                     p, qk_clip_state) if qk_clip_state is not None else None
-                 if scales_full is not None:
-                     qk_clip(p, scales_full, qk_clip_state)
-
-             if not dtensor_params:
-                 return
-
-             # Group DTensors by (placements, mesh) for batched all-gather.
-             placement_groups: dict[tuple,
-                                    tuple[list,
-                                          list]] = defaultdict(lambda: ([], []))
-             for n, p in zip(dtensor_names, dtensor_params):
-                 key = (p.placements, p.device_mesh)
-                 placement_groups[key][0].append(n)
-                 placement_groups[key][1].append(p)
-
-             logger.info(
-                 "distributed_muon: %d placement groups, %d total dtensors",
-                 len(placement_groups), len(dtensor_params))
-
-             for (placements, mesh), (grp_names,
-                                      grp_params) in placement_groups.items():
-                 shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
-                     placements, mesh)
-                 rank = dist.get_rank(shard_pg)
-                 world_size = dist.get_world_size(shard_pg)
-
-                 logger.info(" group: %d params, placements=%s, world_size=%d",
-                             len(grp_params), placements, world_size)
-
-                 # Separate params that can be batched (all shard dims evenly
-                 # divisible) from those needing per-param full_tensor
-                 # (e.g. MoE gate weights with fewer rows than shard ranks).
-                 # all_gather_into_tensor requires equal buffer sizes across
-                 # ranks, so uneven splits must use DTensor full_tensor().
-                 batch_names, batch_params = [], []
-                 single_names, single_params = [], []
-                 for n, p in zip(grp_names, grp_params):
-                     even = all(p.shape[pl.dim] %
-                                shard_mesh.mesh.shape[dim_idx] == 0
-                                for dim_idx, pl in enumerate(shard_placements))
-                     if even:
-                         batch_names.append(n)
-                         batch_params.append(p)
-                     else:
-                         single_names.append(n)
-                         single_params.append(p)
-
-                 # Process uneven-split params per-param via full_tensor().
-                 for n, p in zip(single_names, single_params):
-                     with record_function("distributed_muon::newton_schulz"):
-                         g_full = p.grad.full_tensor().to(COMM_DTYPE)
-                         u_full = _zeropower_via_newtonschulz5(g_full,
-                                                               steps=ns_steps)
-                         del g_full
-                     with record_function("distributed_muon::update"):
-                         adjusted_lr = adjust_lr_for_muon(lr, p.shape)
-                         p._local_tensor.mul_(1 - lr * weight_decay)
-                         local_indices = get_slices_of_dtensor(
-                             p, rank, shard_mesh, shard_placements)
-                         u_local = u_full[local_indices]
-                         p._local_tensor.add_(u_local, alpha=-adjusted_lr)
-                         del u_full
-
-                         qk_clip_state = get_qk_clip_info(
-                             self.clip_config, n, qk_logits)
-                         scales_full = compute_scales(
-                             p, qk_clip_state
-                         ) if qk_clip_state is not None else None
-                         if scales_full is not None:
-                             ratio = p.shape[0] // scales_full.shape[0]
-                             idx0 = local_indices[0]
-                             if isinstance(idx0, slice):
-                                 start = idx0.start or 0
-                                 idx0 = torch.arange(start,
-                                                     idx0.stop,
-                                                     device=scales_full.device)
-                             row_scales = scales_full[idx0 // ratio]
-                             p._local_tensor.mul_(row_scales.view(-1, 1))
-
-                 if not batch_params:
-                     continue
-
-                 logger.info(" batched=%d, single=%d", len(batch_params),
-                             len(single_params))
-
-                 # Concat all local grad shards into a single flat buffer.
-                 with record_function("distributed_muon::gather"):
-                     grad_locals = [
-                         p.grad.to_local().to(COMM_DTYPE).flatten()
-                         for p in batch_params
-                     ]
-                     numels = [g.numel() for g in grad_locals]
-                     grad_concat = torch.cat(grad_locals)
-                     del grad_locals
-
-                     # Single all-gather (replaces N separate full_tensor).
-                     grad_gathered = torch.empty(
-                         grad_concat.numel() * world_size,
-                         dtype=COMM_DTYPE,
-                         device="cuda",
-                     )
-                     dist.all_gather_into_tensor(grad_gathered,
-                                                 grad_concat,
-                                                 group=shard_pg)
-
-                     total_numel = grad_concat.numel()
-                     del grad_concat
-
-                 # Precompute per-param offsets within the concat buffer.
-                 offsets = []
-                 off = 0
-                 for ne in numels:
-                     offsets.append(off)
-                     off += ne
-
-                 # Per-param: reconstruct full grad → NS → local update.
-                 for i, (n, p) in enumerate(zip(batch_names, batch_params)):
-                     with record_function("distributed_muon::newton_schulz"):
-                         g_full = torch.empty(p.shape,
-                                              dtype=COMM_DTYPE,
-                                              device="cuda")
-                         for r in range(world_size):
-                             r_start = r * total_numel + offsets[i]
-                             shard = grad_gathered[r_start:r_start + numels[i]]
-                             indices = get_slices_of_dtensor(
-                                 p, r, shard_mesh, shard_placements)
-                             g_full[indices] = shard.reshape(
-                                 g_full[indices].shape)
-
-                         u_full = _zeropower_via_newtonschulz5(g_full,
-                                                               steps=ns_steps)
-                         del g_full
-
-                     with record_function("distributed_muon::update"):
-                         adjusted_lr = adjust_lr_for_muon(lr, p.shape)
-                         p._local_tensor.mul_(1 - lr * weight_decay)
-                         local_indices = get_slices_of_dtensor(
-                             p, rank, shard_mesh, shard_placements)
-                         u_local = u_full[local_indices]
-                         p._local_tensor.add_(u_local, alpha=-adjusted_lr)
-                         del u_full
-
-                         qk_clip_state = get_qk_clip_info(
-                             self.clip_config, n, qk_logits)
-                         scales_full = compute_scales(
-                             p, qk_clip_state
-                         ) if qk_clip_state is not None else None
-                         if scales_full is not None:
-                             ratio = p.shape[0] // scales_full.shape[0]
-                             idx0 = local_indices[0]
-                             if isinstance(idx0, slice):
-                                 start = idx0.start or 0
-                                 idx0 = torch.arange(start,
-                                                     idx0.stop,
-                                                     device=scales_full.device)
-                             row_scales = scales_full[idx0 // ratio]
-                             p._local_tensor.mul_(row_scales.view(-1, 1))
-
-     def _setup_parallel(self, names, params, group, qk_logits):
-         """Compute (or retrieve cached) parallel pipeline metadata.
-
-         Returns:
-             (ordered_params, param_to_state, rank, chunk_size)
-         """
-         cache_key = tuple(names)
-
-         if cache_key not in self._parallel_cache:
-             # First call: compute metadata and populate cache.
-             param_to_state, ordered_params = self.init_state_and_assign_params(
-                 names, params, group, qk_logits)
-
-             shard_pg = param_to_state[id(ordered_params[0])].process_group
-             rank = dist.get_rank(group=shard_pg)
-
-             if self.chunk_size == -1:
-                 shard_ranks = dist.get_world_size(shard_pg)
-                 chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
-             elif self.chunk_size > 0:
-                 chunk_size = self.chunk_size
-             else:
-                 raise ValueError(
-                     "chunk_size must be -1 or a positive integer.")
-
-             ordered_names = [
-                 param_to_state[id(p)].name for p in ordered_params
-             ]
-             name_to_state = {
-                 param_to_state[id(p)].name: param_to_state[id(p)]
-                 for p in ordered_params
-             }
-             self._parallel_cache[cache_key] = {
-                 'ordered_names': ordered_names,
-                 'name_to_state': name_to_state,
-                 'rank': rank,
-                 'chunk_size': chunk_size,
-             }
-         else:
-             # Cached path: rebuild param_to_state with current id(p) keys.
-             cache = self._parallel_cache[cache_key]
-             rank = cache['rank']
-             chunk_size = cache['chunk_size']
-
-             name_to_param = dict(zip(names, params))
-             ordered_params = [name_to_param[n] for n in cache['ordered_names']]
-
-             param_to_state = {}
-             for p, n in zip(ordered_params, cache['ordered_names']):
-                 cached_state = cache['name_to_state'][n]
-                 param_to_state[id(p)] = _muon_state(
-                     worker_rank=cached_state.worker_rank,
-                     process_group=cached_state.process_group,
-                     rank_indices=cached_state.rank_indices,
-                     rank_numels=cached_state.rank_numels,
-                     name=n,
-                     qk_clip_state=get_qk_clip_info(self.clip_config, n,
-                                                    qk_logits),
-                 )
-
-         return ordered_params, param_to_state, rank, chunk_size
-
-     def parallel(self,
-                  names,
-                  params,
-                  group,
-                  lr,
-                  weight_decay,
-                  qk_logits,
-                  prelaunch_gather=None):
-         """
-         Perform a parallel optimization step using Muon.
-
-         Parameters are chunked and each chunk is processed by a
-         :func:`muon_chunk_pipeline` generator. :func:`run_pipeline`
-         interleaves multiple chunks so that communication and computation
-         overlap across chunks (the same overlap previously achieved by the
-         warmup + main-loop index scheduling).
-
-         If ``prelaunch_gather`` is provided, it is passed to the first
-         chunk's generator to skip re-launching the already in-flight
-         A2A gather.
-         """
-
-         # Momentum is already applied by _step_muon before this method.
-
-         ordered_params, param_to_state, rank, chunk_size = (
-             self._setup_parallel(names, params, group, qk_logits))
-
-         def pipelines():
-             first = True
-             for start in range(0, len(ordered_params), chunk_size):
-                 chunk = ordered_params[start:start + chunk_size]
-                 if chunk:
-                     kwargs = dict(
-                         params=chunk,
-                         param_to_state=param_to_state,
-                         rank=rank,
-                         ns_steps=group["ns_steps"],
-                         lr=lr,
-                         weight_decay=weight_decay,
-                         none_grad=group["none_grad"],
-                     )
-                     if first and prelaunch_gather is not None:
-                         kwargs['prelaunch_gather'] = prelaunch_gather
-                     first = False
-                     yield muon_chunk_pipeline(**kwargs)
-
-         with record_function("muon::pipeline"):
-             run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
-
-     def _step_muon(self, group, qk_logits=None):
-         params = group["params"]
-         lr = group["lr"]
-         weight_decay = group["weight_decay"]
-         momentum = group["momentum"]
-         names = group["names"]
-
-         # Apply momentum to all params before routing/expansion.
-         # Batched using _foreach_* ops (compiled, fullgraph=True).
-         with record_function("muon::momentum"):
-             active_params = [p for p in params if p.grad is not None]
-             if active_params:
-                 # Ensure momentum buffers exist (avoid zeros_like when already present).
-                 for p in active_params:
-                     if "momentum_buffer" not in self.state[p]:
-                         self.state[p]["momentum_buffer"] = torch.zeros_like(
-                             p.grad)
-
-                 # Extract local tensors for compiled batch function.
-                 local_grads = [
-                     p.grad._local_tensor
-                     if isinstance(p.grad, DTensor) else p.grad
-                     for p in active_params
-                 ]
-                 local_bufs = [
-                     self.state[p]["momentum_buffer"]._local_tensor
-                     if isinstance(self.state[p]["momentum_buffer"], DTensor)
-                     else self.state[p]["momentum_buffer"]
-                     for p in active_params
-                 ]
-
-                 # Wrap momentum as tensor for torch.compile.
-                 batch_pre_ortho(local_grads, local_bufs,
-                                 torch.tensor(momentum), group["nesterov"])
-
-                 # For non-nesterov, the result is the momentum buffer.
-                 if not group["nesterov"]:
-                     for p in active_params:
-                         p.grad = self.state[p]["momentum_buffer"]
-
-         # Identify batched experts for deferred NS.
-         # Detection is cheap (condition checks only); actual NS compute is
-         # deferred so it can overlap with the first chunk's A2A gather.
-         deferred_expert_work = []
-         if self.expert_keys:
-             batched_expert_indices = []
-             for i, (n, p) in enumerate(zip(names, params)):
-                 if not (is_expert_param(n, self.expert_keys)
-                         and p.grad is not None):
-                     continue
-                 # Eligible: plain tensor, or DTensor with no non-dim-0 shards.
-                 if isinstance(p.data, DTensor):
-                     has_tp = any(
-                         _is_shard(pl) and pl.dim != 0 for pl in p.placements)
-                     if has_tp:
-                         continue
-                 batched_expert_indices.append(i)
-
-             if batched_expert_indices:
-                 # Save refs for deferred NS; free grads from param list.
-                 for i in batched_expert_indices:
-                     p = params[i]
-                     g = p.grad
-                     local_g = (g._local_tensor
-                                if isinstance(g, DTensor) else g)
-                     local_data = (p.data._local_tensor if isinstance(
-                         p.data, DTensor) else p.data)
-                     deferred_expert_work.append((local_data, local_g))
-                     p.grad = None
-
-                 # Remove batched experts from lists before expansion.
-                 keep = sorted(
-                     set(range(len(params))) - set(batched_expert_indices))
-                 names = [names[i] for i in keep]
-                 params = [params[i] for i in keep]
-
-         def _run_deferred_expert_ns():
-             """Execute deferred batched expert NS."""
-             if not deferred_expert_work:
-                 return
-             with record_function("muon::batched_expert_ns"):
-                 ns_steps = group["ns_steps"]
-                 for local_data, local_g in deferred_expert_work:
-                     u = zeropower_via_newtonschulz5_batched(
-                         local_g.to(COMM_DTYPE), steps=ns_steps)
-                     adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
-                     local_data.mul_(1 - lr * weight_decay)
-                     local_data.add_(u, alpha=-adjusted_lr)
-
-         # Expand expert params by splitting on dim 0.
-         logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
-                      len(params), self.expert_keys)
-         if self.expert_keys:
-             cache_key = tuple(id(p) for p in params)
-             cache = self._expert_expand_cache.get(cache_key)
-
-             if cache is None:
-                 # Cold path: full expansion + build cache metadata.
-                 exp_names, exp_params = _expand_expert_params(
-                     names, params, self.expert_keys)
-
-                 # Build per-expert-group info for hot-path grad updates.
-                 grad_info = []
-                 exp_idx = 0
-                 for orig_idx, (n, p) in enumerate(zip(names, params)):
-                     if not is_expert_param(n, self.expert_keys):
-                         exp_idx += 1
-                         continue
-
-                     is_dt = isinstance(p.data, DTensor)
-                     num_experts = (p.to_local() if is_dt else p.data).shape[0]
-
-                     # Detect TP mesh from the first expanded expert param.
-                     tp_mesh = None
-                     tp_pls = None
-                     sample = exp_params[exp_idx]
-                     if isinstance(sample.data, DTensor):
-                         tp_mesh = sample.data.device_mesh
-                         tp_pls = list(sample.data.placements)
-
-                     grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
-                                       tp_mesh, tp_pls))
-                     exp_idx += num_experts
-
-                 self._expert_expand_cache[cache_key] = {
-                     'names': exp_names,
-                     'params': exp_params,
-                     'grad_info': grad_info,
-                 }
-                 names, params = exp_names, exp_params
-             else:
-                 # Hot path: reuse cached params, only update expert grads.
-                 for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
-                      tp_pls) in cache['grad_info']:
-                     p = params[orig_idx]
-                     g = p.grad
-                     local_grad = (g.to_local()
-                                   if is_dt and isinstance(g, DTensor) else g)
-                     for i in range(num_experts):
-                         expert_p = cache['params'][exp_start + i]
-                         sg = local_grad[i]
-                         if tp_mesh is not None:
-                             expert_p.grad = DTensor.from_local(
-                                 sg, device_mesh=tp_mesh, placements=tp_pls)
-                         else:
-                             expert_p.grad = sg
-                     p.grad = None
-
-                 names = cache['names']
-                 params = cache['params']
-         else:
-             names, params = _expand_expert_params(names, params,
-                                                   self.expert_keys)
-         logger.debug("[_step_muon] after expand: %d params", len(params))
-
-         param_dtensors = []
-         name_dtensors = []
-
-         param_tensors = []
-         name_tensors = []
-
-         # distributed_muon is a reference implementation for testing only.
-         # The parallel pipeline (all2all) path below is the production path.
-         if self.use_distributed_muon:
-             _run_deferred_expert_ns()
-             self.distributed_muon(names=names,
-                                   params=params,
-                                   group=group,
-                                   lr=lr,
-                                   weight_decay=weight_decay,
-                                   qk_logits=qk_logits)
-             return
-
-         for n, p in zip(names, params):
-             if p is None or p.grad is None:
-                 continue
-             if isinstance(p.data, DTensor):
-                 if all(
-                         isinstance(placement, Replicate)
-                         for placement in p.placements):
-                     logger.debug(
-                         "[route] %s → base (DTensor all-Replicate), "
-                         "shape=%s, placements=%s", n, p.shape, p.placements)
-                     param_tensors.append(p)
-                     name_tensors.append(n)
-                 else:
-                     logger.debug(
-                         "[route] %s → parallel (DTensor), shape=%s, "
-                         "placements=%s, mesh=%s", n, p.shape, p.placements,
-                         p.device_mesh.mesh_dim_names)
-                     param_dtensors.append(p)
-                     name_dtensors.append(n)
-             elif isinstance(p.data, torch.Tensor):
-                 logger.debug("[route] %s → base (plain tensor), shape=%s", n,
-                              p.data.shape)
-                 param_tensors.append(p)
-                 name_tensors.append(n)
-             else:
-                 raise TypeError(f"Unsupported parameter type: {type(p.data)}")
-
-         logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, "
-                      f"{len(param_tensors)} Tensors → base")
-
-         def group_dtensors(dtensors, names):
-             # To support different placements, we group parameters by placements
-             # and run parallel Muon on each group.
-
-             placement_to_params = defaultdict(lambda: ([], []))
-
-             assert len(dtensors) == len(names)
-             for p, n in zip(dtensors, names):
-                 placement_to_params[tuple([p.placements,
-                                            p.device_mesh])][0].append(n)
-                 placement_to_params[tuple([p.placements,
-                                            p.device_mesh])][1].append(p)
-             return placement_to_params
-
-         if len(param_dtensors) > 0:
-             if not dist.is_initialized():
-                 raise RuntimeError(
-                     "Parallel Muon requires torch.distributed to be initialized."
-                 )
-
-             dtensor_group = group_dtensors(param_dtensors, name_dtensors)
-
-             # Pre-launch the first chunk's A2A gather so that the NCCL
-             # communication overlaps with the (deferred) batched expert NS
-             # compute on the default CUDA stream.
-             prelaunch = None
-             if deferred_expert_work:
-                 first_names, first_params = next(iter(dtensor_group.values()))
-                 ordered, pts, rnk, csz = self._setup_parallel(
-                     first_names, first_params, group, qk_logits)
-                 first_chunk = ordered[:csz]
-                 if first_chunk:
-                     prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
-                                                        group["none_grad"])
-
-             _run_deferred_expert_ns()
-
-             first_group = True
-             for _, (names, params) in dtensor_group.items():
-                 pg = prelaunch if first_group else None
-                 first_group = False
-                 self.parallel(
-                     names,
-                     params,
-                     group,
-                     lr=lr,
-                     weight_decay=weight_decay,
-                     qk_logits=qk_logits,
-                     prelaunch_gather=pg,
-                 )
-         else:
-             _run_deferred_expert_ns()
-
-         if len(param_tensors) > 0:
-             self.base(
-                 name_tensors,
-                 param_tensors,
-                 group,
-                 lr=lr,
-                 weight_decay=weight_decay,
-                 qk_logits=qk_logits,
-             )
-
-     def _register_states_for_offload(self):
-         """Register all optimizer state tensors with the CPU offload pool.
-
-         Called once after the first step when states have been lazily created.
-         Offloads all param states (momentum buffers for Muon, moment1/moment2
-         for AdamW) to free GPU memory between steps.
-         """
-         pool = self._cpu_offload_pool
-         tracked = 0
-         for group in self.param_groups:
-             for p in group["params"]:
-                 if p not in self.state:
-                     continue
-                 state = self.state[p]
-                 if group.get("use_muon", False):
-                     if "momentum_buffer" in state:
-                         pool.track(state["momentum_buffer"])
-                         tracked += 1
-                 else:
-                     if "moment1" in state:
-                         pool.track(state["moment1"])
-                     if "moment2" in state:
-                         pool.track(state["moment2"])
-                     tracked += 1
-         logger.info("[CPUOffload] Registered %d param states for offload",
-                     tracked)
-
-     @torch.no_grad
-     def step(self, closure=None, qk_logits=None):
-         """Perform a single optimization step.
-
-         Args:
-             closure (Callable, optional): A closure that reevaluates the model
-                 and returns the loss.
-             qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
-                 to 1D tensors of shape (num_heads,), representing the maximum
-                 QK logits across all tokens, computed as
-                 (1 / sqrt(head_dim)) * (Q @ K^T).
-         """
-         loss = None
-         if closure is not None:
-             with torch.enable_grad():
-                 loss = closure()
-
-         # H2D: reload optimizer states from CPU before computation.
-         if self.cpu_offload and self._offload_initialized:
-             self._cpu_offload_pool.reload()
-
-         logger.debug("[Muon.step] expert_keys=%s, %d param groups",
-                      self.expert_keys, len(self.param_groups))
-
-         for i, group in enumerate(self.param_groups):
-             if group["use_muon"]:
-                 logger.debug("[Muon.step] group %d: use_muon=True, %d params",
-                              i, len(group["params"]))
-                 self._step_muon(group, qk_logits=qk_logits)
-             else:
-                 logger.debug(
-                     "[Muon.step] group %d: use_muon=False (AdamW), %d params",
-                     i, len(group["params"]))
-                 step_adamw(self.state, group)
-
-         # D2H: offload optimizer states to CPU after computation.
-         if self.cpu_offload:
-             if not self._offload_initialized:
-                 if self._cpu_offload_pool is None:
-                     self._cpu_offload_pool = CPUOffloadPool()
-                 self._register_states_for_offload()
-                 self._offload_initialized = True
-             self._cpu_offload_pool.offload()
-
-         return loss
-
-     # ------------------------------------------------------------------
-     # CPU offload public helpers
-     # ------------------------------------------------------------------
-
-     def turn_on_cpu_offload(self):
-         """Enable CPU offload for optimizer states."""
-         if self.cpu_offload:
-             return
-         logger.info("[Muon] turn_on_cpu_offload")
-         self.cpu_offload = True
-         if not self.state:
-             return
-         self._cpu_offload_pool = CPUOffloadPool()
-         self._offload_initialized = False
-         self._register_states_for_offload()
-         self._offload_initialized = True
-         self._cpu_offload_pool.offload()
-
-     def turn_off_cpu_offload(self):
-         """Disable CPU offload and keep optimizer states resident on GPU."""
-         if not self.cpu_offload:
-             return
-         logger.info("[Muon] turn_off_cpu_offload")
-         if self._offload_initialized:
-             self._cpu_offload_pool.reload()
-             torch.cuda.current_stream().synchronize()
-         self._cpu_offload_pool = None
-         self._offload_initialized = False
-         self.cpu_offload = False
-
-     # ------------------------------------------------------------------
-     # Checkpoint support for cpu_offload
-     # ------------------------------------------------------------------
-
-     def state_dict(self) -> dict:
-         if self.cpu_offload:
-             raise RuntimeError(
-                 "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
-             )
-         return super().state_dict()
-
-     def load_state_dict(self, state_dict: dict) -> None:
-         if self.cpu_offload:
-             raise RuntimeError(
-                 "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
-             )
-         super().load_state_dict(state_dict)
-
-         # Invalidate adamw.py's module-level tensor caches so that
-         # the next step rebuilds them with the newly loaded state tensors.
-         _placement_cache.clear()
-         _tensor_cache.clear()
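For orientation, a minimal single-process sketch of driving this optimizer end to end, following the snippet embedded in `instruction_code` above. The toy model, the import paths, and the hyperparameters are illustrative assumptions; plain (non-DTensor) parameters take the `base` path, so no process group is needed::

    import torch
    import torch.nn as nn

    from muon import Muon, get_default_muon_param_groups  # import path is an assumption

    model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(),
                          nn.Linear(512, 256)).cuda()
    params = get_default_muon_param_groups(model)  # tags each group with "use_muon"
    opt = Muon(params, lr=0.02, momentum=0.95)

    x = torch.randn(32, 256, device="cuda")
    model(x).square().mean().backward()
    opt.step()
    opt.zero_grad()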
build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py DELETED
@@ -1,240 +0,0 @@
- from itertools import repeat
- from math import inf, sqrt
-
- import numpy as np
- import torch
-
- from .matmul_transpose_triton import matmul_transpose_assign
-
- COMM_DTYPE = torch.bfloat16
- DEFAULT_CHUNK_SIZE_RATIO = 4
-
-
- def _optimal_quintic(l, u, max_iter=1000):
-     """
-     Use the simplified Remez algorithm to find the optimal odd quintic approximant
-     to the constant function x -> 1 over the interval [l, u].
-
-     Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
-     approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
-     two interior equioscillation nodes q, r until convergence. Returns the
-     closed-form equioscillating solution when l ≈ u.
-
-     Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
-     (NaN or inf). Raises RuntimeError if convergence is not reached within
-     max_iter iterations.
-     """
-     assert 0 <= l <= u
-     if 1 - 5e-6 <= l / u:
-         return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
-     q = (3 * l + u) / 4
-     r = (l + 3 * u) / 4
-     E = inf
-     for _ in range(max_iter):
-         old_E = E
-         LHS = np.array(
-             [
-                 [l, l**3, l**5, 1],
-                 [q, q**3, q**5, -1],
-                 [r, r**3, r**5, 1],
-                 [u, u**3, u**5, -1],
-             ]
-         )
-         a, b, c, E = np.linalg.solve(LHS, np.ones(4))
-         if not np.all(np.isfinite([a, b, c, E])):
-             raise ValueError(
-                 f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
-             )
-         q, r = np.sqrt(
-             (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
-         )
-         if not np.all(np.isfinite([q, r])):
-             raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
-         if abs(old_E - E) <= 1e-15:
-             break
-     else:
-         raise RuntimeError(
-             f"_optimal_quintic: did not converge after {max_iter} iterations"
-         )
-     return float(a), float(b), float(c)
-
-
- def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
-     """
-     Compute the Polar Express coefficient series for `num_iters` quintic iterations.
-
-     Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that
-     compose to map singular values from [l, 1] toward 1. At each step:
-       1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
-          prevents near-zero singular values from stalling by raising the effective
-          lower bound; if it is active (cushion*u > l), the coefficients are
-          rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u].
-       2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the
-          last iteration, providing numerical headroom at the cost of a slightly slower
-          final convergence step.
-       3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1).
-
-     Returns a list of (a, b, c) tuples, one per iteration.
-
-     Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
-     Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
-     """
-     u = 1
-     assert 0 <= l <= u
-     safety_factor = 1 + safety_factor_eps
-     coefficients = []
-     for iter in range(num_iters):
-         a, b, c = _optimal_quintic(max(l, cushion * u), u)
-         if cushion * u > l:
-             pl = a * l + b * l**3 + c * l**5
-             pu = a * u + b * u**3 + c * u**5
-             rescaler = 2 / (pl + pu)
-             a *= rescaler
-             b *= rescaler
-             c *= rescaler
-         if iter < num_iters - 1:
-             a /= safety_factor
-             b /= safety_factor**3
-             c /= safety_factor**5
-         coefficients.append((a, b, c))
-         l = a * l + b * l**3 + c * l**5
-         u = 2 - l
-     return coefficients
-
-
- # Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
- # iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic
- # approximant to x->1 over the current singular-value interval, computed once at
- # import time and reused across all optimizer steps.
- #
- # Contrast with the former hardcoded NS coefficients (5 fixed tuples):
- #   - Former: empirically tuned to maximize slope at zero; did not converge
- #     singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
- #     of the true polar factor UV^T.
- #   - Polar Express: analytically optimal per step, adapting to the shrinking
- #     singular-value interval [l, u] as iterations progress; converges all
- #     singular values to 1, producing the exact polar factor UV^T.
- _coeffs_list = _optimal_composition(
-     l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
- )
-
-
- # This code is adapted from:
- # KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
- # NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
- # matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
- @torch.no_grad()
- def _zeropower_via_newtonschulz5(G, steps):
-     """
-     Compute the polar factor of G via the Polar Express method.
-
-     Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c)
-     are the Polar Express coefficients from `_coeffs_list`. Each step is the
-     optimal odd quintic approximant to x -> 1 over the current singular-value
-     interval, minimizing the maximum approximation error (Remez / minimax criterion).
-     The composition maps singular values from [l, 1] to near 1, producing the
-     polar factor (orthogonal factor in the polar decomposition G = UP).
-
-     `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2,
-     cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated.
-
-     Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
-     Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
-     """
-     assert len(G.shape) == 2
-     assert G.dtype == COMM_DTYPE
-     X = G  # no manual typecast
-
-     if G.size(0) > G.size(1):
-         X = X.T
-
-     X = X / (X.norm() + 1e-7)
-     hs = _coeffs_list[:steps] + list(
-         repeat(_coeffs_list[-1], steps - len(_coeffs_list))
-     )
-     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-     # Perform the NS iterations
-     for a, b, c in hs:
-         matmul_transpose_assign(X, buf1)
-         matmul_transpose_assign(buf1, buf2)
-         buf1.mul_(b).add_(buf2, alpha=c)
-         X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
-
-     if G.size(0) > G.size(1):
-         X = X.T
-
-     return X
-
-
- @torch.no_grad()
- def _zeropower_via_newtonschulz5_batched(G, steps):
-     """Batched polar factor computation for 3D (E, out, in) tensors.
-
-     Same algorithm as ``_zeropower_via_newtonschulz5`` but uses
-     ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel,
-     processing all E expert matrices in a single batched call.
-     """
-     assert len(G.shape) == 3
-     assert G.dtype == COMM_DTYPE
-     X = G
-
-     if G.size(1) > G.size(2):
-         X = X.transpose(-2, -1)
-
-     # Per-expert Frobenius norm.
-     X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
-
-     hs = _coeffs_list[:steps] + list(
-         repeat(_coeffs_list[-1], steps - len(_coeffs_list))
-     )
-     for a, b, c in hs:
-         buf1 = torch.bmm(X, X.transpose(-2, -1))
-         buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
-         buf1.mul_(b).add_(buf2, alpha=c)
-         X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a)
-
-     if G.size(1) > G.size(2):
-         X = X.transpose(-2, -1)
-
-     return X
-
-
- _ns_per_shape: dict[tuple[int, ...], callable] = {}
- _use_compile = True
-
-
- def set_ns_compile(enabled: bool):
-     """Toggle torch.compile for Newton-Schulz iteration."""
-     global _use_compile
-     _use_compile = enabled
-
-
- def zeropower_via_newtonschulz5(G, steps=5):
-     if not _use_compile:
-         return _zeropower_via_newtonschulz5(G, steps)
-     key = G.shape
-     if key not in _ns_per_shape:
-         _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5,
-                                            options={
-                                                "triton.cudagraphs": True,
-                                                "shape_padding": False
-                                            })
-     torch.compiler.cudagraph_mark_step_begin()
-     return _ns_per_shape[key](G, steps).clone()
-
-
- def zeropower_via_newtonschulz5_batched(G, steps=5):
-     """Compile-cached batched Newton-Schulz for 3D expert tensors."""
-     if not _use_compile:
-         return _zeropower_via_newtonschulz5_batched(G, steps)
-     key = G.shape
-     if key not in _ns_per_shape:
-         _ns_per_shape[key] = torch.compile(
-             _zeropower_via_newtonschulz5_batched,
-             options={
-                 "triton.cudagraphs": True,
-                 "shape_padding": False
-             })
-     torch.compiler.cudagraph_mark_step_begin()
-     return _ns_per_shape[key](G, steps).clone()
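
Editor's note: the quintic recurrence above is easy to sanity-check without the Triton kernel or the coefficient solver. A minimal sketch using plain matmuls, with the classic fixed NS tuple from KellerJordan/Muon standing in for `_coeffs_list` (the deleted module instead derives a per-step Polar Express schedule at import time):

import torch

def ns_quintic_sketch(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Illustrative fixed coefficients (the "former hardcoded" tuple discussed
    # in the comment block above); with these, singular values settle into a
    # band around 1 rather than converging to it exactly.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)
    if X.size(0) > X.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T                            # buf1 in the kernelized version
        X = a * X + (b * A + c * (A @ A)) @ X  # X <- aX + bX^3 + cX^5 (odd quintic)
    if G.size(0) > G.size(1):
        X = X.T
    return X

G = torch.randn(64, 256, dtype=torch.float64)
print(torch.linalg.svdvals(ns_quintic_sketch(G))[:4])  # values near 1

Swapping in the Polar Express schedule drives the singular values to 1 rather than a band around it, which is exactly the contrast the coefficient comment block above describes.
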
build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib.util
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is: after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique, using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
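
Editor's note: the path-hashed module name above exists so that several identically named `__init__.py` files (one per build variant) can coexist in one process under `sys.modules`. A runnable sketch of the same idea, using temporary files and a hypothetical `load` helper:

import importlib.util
import sys
import tempfile
from pathlib import Path

# Two on-disk copies of "the same" module; a fixed sys.modules name would make
# the second import silently resolve to the first, which is why the shim
# derives the name from the path instead.
tmp = Path(tempfile.mkdtemp())
for variant, value in (("a", 1), ("b", 2)):
    (tmp / f"{variant}.py").write_text(f"VALUE = {value}\n")

def load(path: Path):
    name = f"mod_{hash(path.absolute()) & 0xFFFFFFFF:x}"  # path-derived name
    spec = importlib.util.spec_from_file_location(name, path)
    mod = importlib.util.module_from_spec(spec)
    sys.modules[name] = mod
    spec.loader.exec_module(mod)
    return mod

print(load(tmp / "a.py").VALUE, load(tmp / "b.py").VALUE)  # 1 2
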
build/torch210-cxx11-cu128-x86_64-linux/pipeline.py DELETED
@@ -1,468 +0,0 @@
- import logging
- from typing import Generator
-
- import torch
- import torch.distributed as dist
- from torch.distributed.tensor import DTensor
- from torch.profiler import record_function
-
- from .core import _muon_state, adjust_lr_for_muon
- from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5
- from .qk_clip import compute_scales
-
- logger = logging.getLogger(__name__)
-
- # ======================================================================
- # Stage helpers
- # ======================================================================
-
-
- def _launch_gather(
-     params: list[DTensor],
-     owned_params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     num_ranks: int,
-     process_group: dist.ProcessGroup,
- ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
-     """Allocate gather buffers, build send/recv, and launch async all-to-all.
-
-     Returns:
-         work: Async operation handle.
-         recv_buf: Flat receive buffer (needed by ``_complete_gather``).
-         gathered_grads: ``{id(p): empty_tensor}`` for owned params,
-             ``None`` for non-owned.
-         recv_counts: Per-source-rank element counts.
-     """
-     # Allocate gathered-grad buffers
-     gathered_grads: dict[int, torch.Tensor | None] = {}
-     for p in params:
-         state = param_to_state[id(p)]
-         if rank == state.worker_rank:
-             gathered_grads[id(p)] = torch.empty(p.shape,
-                                                 dtype=COMM_DTYPE,
-                                                 device="cuda")
-         else:
-             gathered_grads[id(p)] = None
-
-     # Build send buffer – batch grad copies via torch.cat
-     # (1-2 fused kernels vs N individual narrow().copy_() calls).
-     send_counts = [0] * num_ranks
-     for p in params:
-         state = param_to_state[id(p)]
-         send_counts[state.worker_rank] += state.rank_numels[rank]
-
-     total_send = sum(send_counts)
-     if total_send > 0:
-         # Group grad slices by destination rank in a single pass.
-         dst_to_grads = [[] for _ in range(num_ranks)]
-         for p in params:
-             state = param_to_state[id(p)]
-             n = state.rank_numels[rank]
-             if n > 0:
-                 g = p.grad.to_local()
-                 dst_to_grads[state.worker_rank].append(g.reshape(-1))
-
-         # Flatten in dst order and cat once.
-         all_slices = []
-         for dst in range(num_ranks):
-             all_slices.extend(dst_to_grads[dst])
-         send_buf = torch.cat(all_slices)
-         if send_buf.dtype != COMM_DTYPE:
-             send_buf = send_buf.to(COMM_DTYPE)
-     else:
-         send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
-
-     # Build recv buffer
-     recv_counts = [0] * num_ranks
-     for src in range(num_ranks):
-         total = 0
-         for p in owned_params:
-             state = param_to_state[id(p)]
-             assert state.worker_rank == rank
-             total += state.rank_numels[src]
-         recv_counts[src] = total
-
-     recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
-
-     # Launch async all-to-all
-     logger.debug(f"send_buf size: {send_buf.numel()}, "
-                  f"recv_buf size: {recv_buf.numel()}, "
-                  f"recv_counts: {recv_counts}, "
-                  f"send_counts: {send_counts}, "
-                  f"process_group: {str(process_group)}")
-     work = dist.all_to_all_single(
-         recv_buf,
-         send_buf,
-         output_split_sizes=recv_counts,
-         input_split_sizes=send_counts,
-         group=process_group,
-         async_op=True,
-     )
-
-     return work, recv_buf, gathered_grads, recv_counts
-
-
- def _complete_gather(
-     recv_buf: torch.Tensor,
-     recv_counts: list[int],
-     owned_params: list[DTensor],
-     gathered_grads: dict[int, torch.Tensor | None],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
- ) -> None:
-     """Reconstruct gathered grads from the recv buffer (in-place)."""
-     off = 0
-     for src in range(len(recv_counts)):
-         if recv_counts[src] == 0:
-             continue
-
-         block = recv_counts[src]
-         inner_off = 0
-         for p in owned_params:
-             state = param_to_state[id(p)]
-             assert state.worker_rank == rank
-
-             indices = state.rank_indices[src]
-
-             shard_view = gathered_grads[id(p)][indices]
-             n = shard_view.numel()
-             if n == 0:
-                 continue
-
-             sg = recv_buf.narrow(0, off + inner_off, n)
-             sg = sg.reshape(shard_view.shape)
-             gathered_grads[id(p)][indices] = sg
-
-             inner_off += n
-         assert inner_off == block
-         off += block
-
-
- def _compute_ns(
-     owned_params: list[DTensor],
-     gathered_grads: dict[int, torch.Tensor | None],
-     ns_steps: int,
- ) -> dict[int, torch.Tensor | None]:
-     """Run Newton-Schulz orthogonalization on owned parameters.
-
-     Returns:
-         computed_us: ``{id(p): orthogonalized_update}`` for owned params.
-     """
-     computed_us: dict[int, torch.Tensor | None] = {}
-     for p in owned_params:
-         u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
-         gathered_grads[id(p)] = None  # free gathered grad
-         computed_us[id(p)] = u
-     return computed_us
-
-
- def _launch_scatter(
-     params: list[DTensor],
-     owned_params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     num_ranks: int,
-     process_group: dist.ProcessGroup,
-     computed_us: dict[int, torch.Tensor | None],
- ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
-     """Allocate scatter buffers, build send/recv, and launch async all-to-all.
-
-     Returns:
-         work: Async operation handle.
-         recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
-         scattered_us: Empty dict, populated by ``_complete_scatter`` with
-             zero-copy views into ``recv_buf``.
-         recv_counts: Per-source-rank element counts.
-     """
-     # scattered_us is populated by _complete_scatter with zero-copy views
-     # into recv_buf, avoiding N empty_like allocations + N copy_ calls.
-     # Pre-seed entries for params whose local shard is empty (rank_numels == 0)
-     # so _update_params can iterate all params without KeyError.
-     scattered_us: dict[int, torch.Tensor] = {}
-     for p in params:
-         if param_to_state[id(p)].rank_numels[rank] == 0:
-             scattered_us[id(p)] = torch.empty_like(p.to_local(),
-                                                    dtype=COMM_DTYPE)
-
-     # Build send buffer – batch via torch.cat
-     # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls).
-     send_counts = [0] * num_ranks
-     if owned_params:
-         for p in owned_params:
-             state = param_to_state[id(p)]
-             for dst_rank in range(num_ranks):
-                 send_counts[dst_rank] += state.rank_numels[dst_rank]
-
-     total_send = sum(send_counts)
-     if total_send > 0:
-         # Cache u_full conversions to avoid redundant .to() per dst_rank.
-         u_fulls = {}
-         for p in owned_params:
-             u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous()
-
-         # Collect slices in dst order (matches all-to-all send layout).
-         all_slices = []
-         for dst_rank in range(num_ranks):
-             for p in owned_params:
-                 state = param_to_state[id(p)]
-                 su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten()
-                 if su.numel() > 0:
-                     all_slices.append(su)
-
-         send_buf = torch.cat(all_slices) if all_slices else torch.empty(
-             0, dtype=COMM_DTYPE, device="cuda")
-     else:
-         send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
-
-     # Build recv buffer
-     recv_counts = [0] * num_ranks
-     for src in range(num_ranks):
-         total = 0
-         for p in params:
-             state = param_to_state[id(p)]
-             if state.worker_rank != src:
-                 continue
-             total += state.rank_numels[rank]
-         recv_counts[src] = total
-
-     recv_total = sum(recv_counts)
-     recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-     # Launch async all-to-all
-     work = dist.all_to_all_single(
-         recv_buf,
-         send_buf,
-         output_split_sizes=recv_counts,
-         input_split_sizes=send_counts,
-         group=process_group,
-         async_op=True,
-     )
-
-     return work, recv_buf, scattered_us, recv_counts
-
-
- def _complete_scatter(
-     recv_buf: torch.Tensor,
-     recv_counts: list[int],
-     params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     scattered_us: dict[int, torch.Tensor],
- ) -> None:
-     """Populate scattered_us with zero-copy views into recv_buf.
-
-     Instead of pre-allocating tensors and copying, we assign views directly
-     from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls.
-     The underlying storage of ``recv_buf`` is kept alive through the views
-     until ``scattered_us`` is cleared after ``_update_params``.
-     """
-     off = 0
-     for src in range(len(recv_counts)):
-         block = recv_counts[src]
-         if block == 0:
-             continue
-
-         inner_off = 0
-         for p in params:
-             state = param_to_state[id(p)]
-             if state.worker_rank != src:
-                 continue
-             n = state.rank_numels[rank]
-             if n == 0:
-                 continue
-
-             scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off,
-                                                   n).view_as(p.to_local())
-
-             inner_off += n
-
-         assert inner_off == block
-         off += block
-
-
- def _update_params(
-     params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     scattered_us: dict[int, torch.Tensor],
-     lr: float,
-     weight_decay: float,
- ) -> None:
-     """Apply weight decay, Muon update, and optional QK clipping.
-
-     Uses batched ``_foreach_mul_`` for weight decay and batched
-     ``_foreach_add_`` for the Muon update, grouping parameters by
-     adjusted_lr to minimize kernel launches while preserving float32
-     precision for the alpha scaling.
-     """
-     if not params:
-         return
-
-     # Batched weight decay: p *= (1 - lr * wd) — single fused kernel.
-     p_locals = [p._local_tensor for p in params]
-     torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay)
-
-     # Group params by adjusted_lr so _foreach_add_ can use a single
-     # alpha per group (preserves float32 precision for alpha scaling).
-     lr_groups: dict[float, tuple[list, list]] = {}
-     for p in params:
-         adjusted_lr = adjust_lr_for_muon(lr, p.shape)
-         if adjusted_lr not in lr_groups:
-             lr_groups[adjusted_lr] = ([], [])
-         lr_groups[adjusted_lr][0].append(p._local_tensor)
-         lr_groups[adjusted_lr][1].append(scattered_us[id(p)])
-
-     for adjusted_lr, (p_group, u_group) in lr_groups.items():
-         torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr)
-
-     # QK clipping – applied directly on the local tensor to
-     # avoid DTensor sharding-propagation issues with _StridedShard.
-     for p in params:
-         state = param_to_state[id(p)]
-         if state.qk_clip_state is None:
-             continue
-         scales_full = compute_scales(p, state.qk_clip_state)
-         if scales_full is not None:
-             ratio = p.shape[0] // scales_full.shape[0]
-             idx0 = state.rank_indices[rank][0]
-             if isinstance(idx0, slice):
-                 start = idx0.start or 0
-                 idx0 = torch.arange(start,
-                                     idx0.stop,
-                                     device=scales_full.device)
-             row_scales = scales_full[idx0 // ratio]
-             p._local_tensor.mul_(row_scales.view(-1, 1))
-
-
- # ======================================================================
- # Pre-launch helper for overlapping first chunk's gather with other work.
- # ======================================================================
-
-
- @torch.no_grad()
- def prelaunch_first_gather(
-     params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     none_grad: bool,
- ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
-     """Launch the first chunk's A2A gather early for overlap with other compute.
-
-     Call this *before* expensive GPU work (e.g. batched expert NS) so that
-     the NCCL all-to-all runs concurrently on the NCCL stream while the
-     default stream executes compute.
-
-     Returns the same 4-tuple that ``_launch_gather`` produces, which should
-     be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`.
-     """
-     process_group = param_to_state[id(params[0])].process_group
-     num_ranks = dist.get_world_size(group=process_group)
-     owned_params = [
-         p for p in params if param_to_state[id(p)].worker_rank == rank
-     ]
-
-     with record_function("muon::prelaunch_gather"):
-         work, recv_buf, gathered_grads, recv_counts = _launch_gather(
-             params, owned_params, param_to_state, rank, num_ranks,
-             process_group)
-
-     if none_grad:
-         for p in params:
-             p.grad = None
-
-     return work, recv_buf, gathered_grads, recv_counts
-
-
- # ======================================================================
- # Main generator – thin orchestrator that wires stages together.
- # ======================================================================
-
-
- @torch.no_grad()
- def muon_chunk_pipeline(
-     params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     ns_steps: int,
-     lr: float,
-     weight_decay: float,
-     none_grad: bool,
-     prelaunch_gather: tuple | None = None,
- ) -> Generator[None, None, None]:
-     """Process one chunk of parameters through the full Muon pipeline.
-
-     Stages: gather -> compute (Newton-Schulz) -> scatter -> update.
-
-     Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
-     that communication and computation overlap across chunks. Async
-     communication is launched via ``async_op=True`` and completed after
-     the yield with ``work.wait()``.
-
-     Overlap happens because :func:`run_pipeline` admits one new chunk
-     per iteration (staggered admission). While chunk *N* does NS
-     compute on the default CUDA stream, chunk *N+1*'s async all-to-all
-     runs concurrently on the NCCL stream — no separate ``comm_stream``
-     is required.
-
-     If ``prelaunch_gather`` is provided, the gather was already launched
-     by :func:`prelaunch_first_gather` and we skip launching it again.
-
-     Yields exactly **2** times:
-
-     1. After launching async all-to-all gather (or immediately if pre-launched).
-     2. After launching async all-to-all scatter.
-     """
-     process_group = param_to_state[id(params[0])].process_group
-     num_ranks = dist.get_world_size(group=process_group)
-     owned_params = [
-         p for p in params if param_to_state[id(p)].worker_rank == rank
-     ]
-
-     if prelaunch_gather is not None:
-         # Gather was pre-launched; none_grad already handled by caller.
-         work, recv_buf, gathered_grads, recv_counts = prelaunch_gather
-     else:
-         # Normal path: launch async gather.
-         with record_function("muon::launch_gather"):
-             work, recv_buf, gathered_grads, recv_counts = _launch_gather(
-                 params, owned_params, param_to_state, rank, num_ranks,
-                 process_group)
-
-         if none_grad:
-             for p in params:
-                 p.grad = None
-
-     yield  # --- YIELD 1: other chunks can launch their gather ---
-
-     with record_function("muon::wait_gather"):
-         work.wait()
-         _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
-                          param_to_state, rank)
-         del recv_buf
-
-     # Stage 3: Newton-Schulz orthogonalization.
-     with record_function("muon::newton_schulz"):
-         computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
-         gathered_grads.clear()
-
-     # Stages 4-5: launch async scatter.
-     with record_function("muon::launch_scatter"):
-         work, recv_buf, scattered_us, recv_counts = _launch_scatter(
-             params, owned_params, param_to_state, rank, num_ranks,
-             process_group, computed_us)
-         computed_us.clear()
-
-     yield  # --- YIELD 2: other chunks can launch their scatter ---
-
-     with record_function("muon::wait_scatter"):
-         work.wait()
-         _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
-                           scattered_us)
-         del recv_buf
-
-     # Stage 6: apply parameter updates.
-     with record_function("muon::update_params"):
-         _update_params(params, param_to_state, rank, scattered_us, lr,
-                        weight_decay)
-         scattered_us.clear()
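
Editor's note: the launch / yield / wait() shape used throughout this file can be exercised standalone. A minimal sketch with a single-process gloo group (assuming gloo's CPU all_to_all support), where a scalar reduction stands in for the Newton-Schulz compute that would overlap the transfer:

import os
import torch
import torch.distributed as dist

def main() -> None:
    # Single-process world so the sketch runs without torchrun.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    send_buf = torch.arange(8, dtype=torch.float32)
    recv_buf = torch.empty(8)
    work = dist.all_to_all_single(recv_buf, send_buf,
                                  output_split_sizes=[8],
                                  input_split_sizes=[8],
                                  async_op=True)     # launch, don't block
    overlap = send_buf.square().sum()                # stand-in for NS compute
    work.wait()                                      # complete after the yield
    assert torch.equal(recv_buf, send_buf) and overlap >= 0
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
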
build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py DELETED
@@ -1,198 +0,0 @@
- import logging
- import math
- from dataclasses import dataclass
-
- import torch
- from torch.distributed.tensor import DTensor
-
- from .core import normalize_fqn
-
- logger = logging.getLogger(__name__)
-
-
- def parse_qk_layer(name: str) -> tuple[str | None, int]:
-     """
-     Parse a parameter name to check if it is a query/key projection layer
-     and return (kind, layer_index).
-
-     Supported kinds:
-         MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
-         MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)
-
-     Returns:
-         (kind, layer_idx) or (None, -1) if not matched.
-
-     Example:
-         'model.3.attn.wq.weight' -> ('wq', 3)
-         'model.5.attn.wk.weight' -> ('wk', 5)
-         'model.2.attn.q_proj.weight' -> ('q_proj', 2)
-         'model.7.attn.k_proj.weight' -> ('k_proj', 7)
-         'model.1.attn.wq_b.weight' -> ('wq_b', 1)
-         'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
-         'model.4.attn.v_proj.weight' -> (None, -1)
-     """
-     parts = normalize_fqn(name).split('.')
-     if len(parts) < 3:
-         return None, -1
-
-     kind = parts[-2]
-
-     layer_idx = -1
-     for part in reversed(parts):
-         if part.isdigit():
-             layer_idx = int(part)
-             break
-
-     if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
-         return kind, layer_idx
-
-     return None, -1
-
-
- @dataclass
- class QKClipInfo:
-     """Per-parameter dynamic info computed from config + runtime logits."""
-     kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
-     indices: list[int]  # which heads to consider for clipping
-     head_dim: int  # from config (qk_head_dim for MLA wq_b)
-     threshold: float  # from config
-     logit: torch.Tensor | None
-
-     # MLA-specific fields
-     is_mla: bool = False
-     qk_nope_head_dim: int = 0
-     qk_rope_head_dim: int = 0
-     v_head_dim: int = 0
-
-
- def get_qk_clip_info(clip_config, n, qk_logits):
-     """Extract QK clipping info for a named parameter.
-
-     Args:
-         clip_config: QK clipping configuration dict (or None).
-             MHA/GQA keys: head_dim, threshold, q_indices, k_indices
-             MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
-         n: Parameter name string.
-         qk_logits: Dict mapping layer indices to logit tensors (or None).
-
-     Returns:
-         QKClipInfo instance with clipping configuration for this parameter.
-     """
-     if clip_config is None:
-         return None
-
-     head_dim = clip_config.get('head_dim')
-     threshold = clip_config.get('threshold')
-     kind, layer_idx = parse_qk_layer(n)
-     is_mla = clip_config.get('is_mla', False)
-
-     logit, indices = None, []
-     if qk_logits is not None and kind is not None:
-         logit = qk_logits[layer_idx]
-         if isinstance(logit, DTensor):
-             # In TP settings, qk_logits may be DTensor
-             # We convert it to full tensor here for simplicity
-             logit = logit.full_tensor()
-
-         if kind in ('wq_b', 'wq', 'q_proj'):
-             indices = clip_config.get('q_indices', []) or []
-         elif kind in ('wkv_b', 'wk', 'k_proj'):
-             indices = clip_config.get('k_indices', []) or []
-
-     if is_mla:
-         return QKClipInfo(
-             kind=kind,
-             indices=indices,
-             head_dim=head_dim,
-             threshold=threshold,
-             logit=logit,
-             is_mla=True,
-             qk_nope_head_dim=clip_config['qk_nope_head_dim'],
-             qk_rope_head_dim=clip_config['qk_rope_head_dim'],
-             v_head_dim=clip_config['v_head_dim'],
-         )
-     else:
-         return QKClipInfo(
-             kind=kind,
-             indices=indices,
-             head_dim=head_dim,
-             threshold=threshold,
-             logit=logit,
-         )
-
-
- def compute_scales(p, qk_clip_state):
-     """Compute per-head scaling factors for QK clipping.
-
-     Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
-     For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
-     """
-     kind = qk_clip_state.kind
-     indices = qk_clip_state.indices
-     head_dim = qk_clip_state.head_dim
-     threshold = qk_clip_state.threshold
-     logit = qk_clip_state.logit
-
-     # Check if any head exceeds threshold before allocating.
-     head_scales = {}
-     for logit_idx, head_idx in enumerate(indices):
-         v_ele = float(logit[logit_idx])
-         if v_ele > threshold:
-             new_scale = math.sqrt(threshold / v_ele)
-             if head_idx not in head_scales or new_scale < head_scales[head_idx]:
-                 head_scales[head_idx] = new_scale
-                 logger.info(
-                     f"[{kind}] Head {head_idx} exceeded threshold "
-                     f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
-                 )
-
-     if not head_scales:
-         return None
-
-     # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
-     if qk_clip_state.is_mla and kind == 'wkv_b':
-         effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
-     else:
-         effective_head_dim = head_dim
-
-     H_global = p.shape[0] // effective_head_dim
-     scales_full = torch.ones(H_global, device=p.data.device)
-     for head_idx, scale in head_scales.items():
-         scales_full[head_idx] = scale
-     return scales_full
-
-
- def qk_clip(p, scales, info):
-     """Apply per-head scaling to a Q/K projection weight matrix.
-
-     Args:
-         p: Parameter (nn.Parameter or raw tensor).
-         scales: [n_heads] tensor, each element = √γ_h.
-         info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
-
-     MLA sub-region scaling per Algorithm 1 (MuonClip):
-         wq_b: q_nope rows → √γ, q_pe rows → γ
-         wkv_b: k_nope rows → √γ, v rows → unchanged
-     """
-     W = p.data if isinstance(p, torch.nn.Parameter) else p
-
-     if not info.is_mla:
-         # MHA/GQA: uniform √γ applied to all rows in each head
-         W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
-         return
-
-     # MLA: vectorized sub-region scaling within each head
-     if info.kind == 'wq_b':
-         qk_nope = info.qk_nope_head_dim
-         qk_head_dim = qk_nope + info.qk_rope_head_dim
-         W_3d = W.view(-1, qk_head_dim, W.shape[1])  # [H, qk_head_dim, in_dim]
-         W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # q_nope → √γ
-         W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, 1))  # q_pe → γ
-
-     elif info.kind == 'wkv_b':
-         qk_nope = info.qk_nope_head_dim
-         kv_stride = qk_nope + info.v_head_dim
-         W_3d = W.view(-1, kv_stride, W.shape[1])  # [H, kv_stride, in_dim]
-         W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # k_nope → √γ
-         # v rows: not touched (k_R shared rotary unchanged)
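
Editor's note: the per-head scale is the square root of the clip ratio because it is applied on both the query side and the key side, so the q·k logit shrinks by the full ratio. A worked example with hypothetical numbers:

import math

threshold, observed_logit = 100.0, 150.0
gamma = threshold / observed_logit    # target shrink factor for the q.k product
s = math.sqrt(gamma)                  # per-side scale (sqrt(gamma)) from compute_scales
print(round(s, 4), round(s * s * observed_logit, 6))
# 0.8165 100.0 (the logit is clipped back to the threshold)
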
build/torch210-cxx11-cu130-x86_64-linux/adamw.py DELETED
@@ -1,271 +0,0 @@
- import logging
- from collections import defaultdict
- from typing import cast
-
- import torch
- from torch.distributed.tensor import DTensor
- from torch.profiler import record_function
-
- logger = logging.getLogger(__name__)
-
-
- def fused_adamw(
-     params: list[torch.Tensor],
-     grads: list[torch.Tensor],
-     exp_avgs: list[torch.Tensor],
-     exp_avg_sqs: list[torch.Tensor],
-     max_exp_avg_sqs: list[torch.Tensor],
-     state_steps: list[torch.Tensor],
-     amsgrad: bool,
-     beta1: float,
-     beta2: float,
-     lr: float | torch.Tensor,
-     weight_decay: float,
-     eps: float,
-     maximize: bool,
- ) -> None:
-     if not params:
-         return
-
-     # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
-     # treating it as a scalar.
-     lr_dict: dict | None = ({
-         lr.device: lr
-     } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None)
-     grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
-         [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
-          state_steps]  # type: ignore[list-item]
-     )
-     for (device, _), (
-         (
-             device_params_,
-             device_grads_,
-             device_exp_avgs_,
-             device_exp_avg_sqs_,
-             device_max_exp_avg_sqs,
-             device_state_steps_,
-         ),
-         _,
-     ) in grouped_tensors.items():
-         device_params = cast(list[torch.Tensor], device_params_)
-         device_grads = cast(list[torch.Tensor], device_grads_)
-         device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
-         device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
-         device_state_steps = cast(list[torch.Tensor], device_state_steps_)
-
-         if lr_dict is not None and device not in lr_dict:
-             lr_dict[device] = lr.to(
-                 device=device, non_blocking=True)  # type: ignore[union-attr]
-             lr = lr_dict[device]
-         torch._foreach_add_(device_state_steps, 1)
-         func = torch._fused_adamw_
-         func(
-             device_params,
-             device_grads,
-             device_exp_avgs,
-             device_exp_avg_sqs,
-             device_max_exp_avg_sqs,  # type: ignore[arg-type]
-             device_state_steps,
-             amsgrad=amsgrad,
-             lr=lr,  # type: ignore[arg-type]
-             beta1=beta1,
-             beta2=beta2,
-             weight_decay=weight_decay,
-             eps=eps,
-             maximize=maximize,
-         )
-
-
- def _to_local(t):
-     """Unwrap DTensor to local tensor for fused ops."""
-     return t._local_tensor if isinstance(t, DTensor) else t
-
-
- # ---------------------------------------------------------------------------
- # Caches for eliminating per-step Python overhead.
- #
- # Placement grouping and tensor list assembly are identical every step
- # (params don't change placement, moment/step tensors are the same objects
- # after initialisation). We cache them keyed by id() of the param list
- # stored in param_groups (stable across steps).
- #
- # Only gradients change each step and must be collected fresh.
- # ---------------------------------------------------------------------------
-
- # id(group["params"]) → dict[placement_key, list[param]]
- _placement_cache: dict[int, dict[tuple, list]] = {}
-
- # id(placement_group_list) → (params_local, moment1, moment2, state_steps)
- _tensor_cache: dict[int, tuple[list, list, list, list]] = {}
-
-
- def _step_adamw_params_slow(optimizer_state, params, group):
-     """Uncached fallback for the rare case where some params lack grads."""
-     params_with_grads = []
-     grads = []
-     moment1 = []
-     moment2 = []
-     state_steps = []
-
-     for p in params:
-         g = p.grad
-         if g is None:
-             continue
-         state = optimizer_state[p]
-         params_with_grads.append(_to_local(p))
-         grads.append(_to_local(g))
-         if "step" not in state:
-             state["step"] = torch.zeros((),
-                                         dtype=torch.float32,
-                                         device=p.device)
-             state["moment1"] = torch.zeros_like(g)
-             state["moment2"] = torch.zeros_like(g)
-         moment1.append(_to_local(state["moment1"]))
-         moment2.append(_to_local(state["moment2"]))
-         if not isinstance(state["step"], torch.Tensor):
-             state["step"] = torch.tensor(state["step"],
-                                          dtype=torch.float32,
-                                          device=p.device)
-         state_steps.append(state["step"])
-
-     if not params_with_grads:
-         return
-
-     lr = group["lr"]
-     beta1, beta2 = group["adamw_betas"]
-     eps = group["adamw_eps"]
-     weight_decay = group["weight_decay"]
-
-     fused_adamw(
-         params_with_grads,
-         grads,
-         moment1,
-         moment2,
-         [],
-         state_steps,
-         amsgrad=False,
-         beta1=beta1,
-         beta2=beta2,
-         lr=lr,
-         weight_decay=weight_decay,
-         eps=eps,
-         maximize=False,
-     )
-
-
- def step_adamw_params(optimizer_state, params, group):
-     """Run fused AdamW on a list of parameters sharing the same placement.
-
-     After the first call, cached tensor lists (params_local, moment1,
-     moment2, state_steps) are reused — only gradients are collected fresh.
-
-     Args:
-         optimizer_state: The optimizer's state dict (self.state in Muon).
-         params: List of parameters to update.
-         group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
-     """
-     # Collect grads — the only thing that changes each step.
-     with record_function("adamw::collect_grads"):
-         grads = []
-         for p in params:
-             g = p.grad
-             if g is None:
-                 # Rare: fall back to slow path that filters per-param.
-                 _step_adamw_params_slow(optimizer_state, params, group)
-                 return
-             grads.append(_to_local(g))
-
-     tensor_key = id(params)
-     if tensor_key not in _tensor_cache:
-         with record_function("adamw::init_tensor_cache"):
-             params_local = []
-             moment1 = []
-             moment2 = []
-             state_steps = []
-
-             for p in params:
-                 state = optimizer_state[p]
-                 params_local.append(_to_local(p))
-                 if "step" not in state:
-                     state["step"] = torch.zeros((),
-                                                 dtype=torch.float32,
-                                                 device=p.device)
-                     state["moment1"] = torch.zeros_like(p.grad)
-                     state["moment2"] = torch.zeros_like(p.grad)
-                 moment1.append(_to_local(state["moment1"]))
-                 moment2.append(_to_local(state["moment2"]))
-                 if not isinstance(state["step"], torch.Tensor):
-                     state["step"] = torch.tensor(state["step"],
-                                                  dtype=torch.float32,
-                                                  device=p.device)
-                 state_steps.append(state["step"])
-
-             _tensor_cache[tensor_key] = (params_local, moment1, moment2,
-                                          state_steps)
-
-     params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key]
-
-     lr = group["lr"]
-     beta1, beta2 = group["adamw_betas"]
-     eps = group["adamw_eps"]
-     weight_decay = group["weight_decay"]
-
-     with record_function("adamw::fused_adamw"):
-         fused_adamw(
-             params_local,
-             grads,
-             moment1,
-             moment2,
-             [],
-             state_steps,
-             amsgrad=False,
-             beta1=beta1,
-             beta2=beta2,
-             lr=lr,
-             weight_decay=weight_decay,
-             eps=eps,
-             maximize=False,
-         )
-
-
- def step_adamw(optimizer_state, group):
-     """Dispatch AdamW step, grouping parameters by type and placement.
-
-     Placement grouping is cached after the first call since params never
-     change their placement between steps.
-
-     Args:
-         optimizer_state: The optimizer's state dict (self.state in Muon).
-         group: Parameter group dict.
-     """
-     params = group["params"]
-     placement_key = id(params)
-
-     if placement_key not in _placement_cache:
-         with record_function("adamw::group_by_placement"):
-             placement_to_params: dict[tuple,
-                                       list[torch.Tensor]] = defaultdict(list)
-             for p in params:
-                 match p:
-                     case DTensor():
-                         logger.debug(
-                             "[AdamW] DTensor param: shape=%s, placements=%s, "
-                             "mesh=%s, grad=%s", p.shape, p.placements,
-                             p.device_mesh.mesh_dim_names,
-                             p.grad.shape if p.grad is not None else None)
-                         placement_to_params[tuple(
-                             [p.placements, p.device_mesh])].append(p)
-                     case torch.Tensor():
-                         logger.debug(
-                             "[AdamW] plain param: shape=%s, grad=%s", p.shape,
-                             p.grad.shape if p.grad is not None else None)
-                         placement_to_params[tuple([torch.Tensor,
-                                                    None])].append(p)
-
-             logger.debug("[AdamW] %d placement groups, %d total params",
-                          len(placement_to_params), len(params))
-
-             _placement_cache[placement_key] = dict(placement_to_params)
-
-     for group_params in _placement_cache[placement_key].values():
-         step_adamw_params(optimizer_state, group_params, group)
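
Editor's note: the hand-rolled path above exists to cache tensor-list assembly across steps; the fused kernel itself is also reachable through the stock optimizer. A minimal sketch (fused=True dispatches to the same torch._fused_adamw_ on CUDA; the CPU fallback keeps it runnable anywhere):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 4, device=device)
# fused=True requires floating-point params on a supported device, hence the
# conditional; on CPU this falls back to the default (non-fused) path.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1,
                        fused=(device == "cuda"))
model(torch.randn(2, 4, device=device)).sum().backward()
opt.step()
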
build/torch210-cxx11-cu130-x86_64-linux/async_utils.py DELETED
@@ -1,77 +0,0 @@
- import logging
- from typing import Generator
-
- logger = logging.getLogger(__name__)
-
-
- class _Task:
-     """Internal: wraps a generator, advances one yield at a time."""
-
-     def __init__(self, generator: Generator[None, None, None], index: int):
-         self._generator = generator
-         self._index = index
-         self._steps_completed = 0
-         self.step()  # run to first yield
-
-     def step(self) -> bool:
-         try:
-             next(self._generator)
-             self._steps_completed += 1
-             logger.debug("pipeline[%d] completed stage %d", self._index,
-                          self._steps_completed)
-             return True
-         except StopIteration:
-             logger.debug("pipeline[%d] finished after %d stages", self._index,
-                          self._steps_completed)
-             return False
-
-     def close(self):
-         self._generator.close()
-
-
- def run_pipeline(
-     pipelines: Generator[Generator[None, None, None], None, None],
-     max_concurrent: int,
- ) -> None:
-     """Run generator-based pipelines with bounded concurrency.
-
-     Each pipeline is a generator that yields at stage boundaries.
-     The runtime interleaves pipelines so communication and computation
-     overlap across chunks.
-     """
-     if max_concurrent <= 0:
-         raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}")
-
-     have_new = True
-     task_index = 0
-     previous_tasks: list[_Task] = []
-
-     try:
-         while have_new or previous_tasks:
-             running_tasks: list[_Task] = []
-
-             # Admit one new pipeline per iteration (staggered admission).
-             # Admitting one at a time ensures that while chunk N does NS
-             # compute on the default stream, chunk N+1's NCCL all-to-all
-             # runs concurrently on the NCCL stream — creating real
-             # communication/computation overlap on the GPU.
-             if have_new and len(previous_tasks) < max_concurrent:
-                 try:
-                     gen = next(pipelines)
-                     task = _Task(gen, task_index)
-                     task_index += 1
-                     running_tasks.append(task)
-                 except StopIteration:
-                     have_new = False
-
-             # Advance every previously-yielded task by one step.
-             for task in previous_tasks:
-                 if task.step():
-                     running_tasks.append(task)
-
-             previous_tasks = running_tasks
-     except BaseException:
-         # Clean up all in-flight generators to release GPU resources.
-         for task in previous_tasks:
-             task.close()
-         raise
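
Editor's note: a toy driver makes the staggered admission visible (assuming `run_pipeline` above is importable; prints stand in for the NCCL and Newton-Schulz stages):

def chunk(tag: str):
    print(f"{tag}: launch gather")
    yield
    print(f"{tag}: NS compute + launch scatter")
    yield
    print(f"{tag}: update params")

run_pipeline((chunk(f"chunk{i}") for i in range(3)), max_concurrent=2)
# Interleaved output: chunk1 launches its gather while chunk0 is mid-compute.
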
build/torch210-cxx11-cu130-x86_64-linux/core.py DELETED
@@ -1,219 +0,0 @@
- import logging
- import math
- from dataclasses import dataclass
- from typing import List
-
- import torch
- from torch.distributed import ProcessGroup
- from torch.distributed.tensor import DTensor
-
- # torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
- # parameter FQNs. Activation checkpointing similarly inserts
- # "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
- # expert_keys, QK layer parsing) works regardless of wrapper nesting.
- _WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})
-
- logger = logging.getLogger(__name__)
-
-
- def normalize_fqn(name: str) -> str:
-     """Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
-     return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS)
-
-
- @dataclass
- class _muon_state:
-     worker_rank: int
-     process_group: ProcessGroup
-     rank_indices: dict[int, tuple]  # local_rank -> per-dim indices
-     rank_numels: dict[int, int]  # local_rank -> numel
-     name: str
-     qk_clip_state: torch.Tensor | None = None
-
-
- def _batch_momentum(
-     grads: List[torch.Tensor],
-     momentum_bufs: List[torch.Tensor],
-     momentum: torch.Tensor,
- ) -> None:
-     """Batched momentum update (no nesterov)."""
-     torch._foreach_mul_(momentum_bufs, momentum)
-     torch._foreach_add_(momentum_bufs, grads)
-
-
- def _batch_momentum_nesterov(
-     grads: List[torch.Tensor],
-     momentum_bufs: List[torch.Tensor],
-     momentum: torch.Tensor,
- ) -> None:
-     """Batched momentum update with nesterov correction."""
-     torch._foreach_mul_(momentum_bufs, momentum)
-     torch._foreach_add_(momentum_bufs, grads)
-     nesterov_terms = torch._foreach_mul(momentum_bufs, momentum)
-     torch._foreach_add_(grads, nesterov_terms)
-
-
- _compiled_momentum: dict[bool, callable] = {}
- _use_momentum_compile = True
-
-
- def set_momentum_compile(enabled: bool):
-     """Toggle torch.compile for batched momentum."""
-     global _use_momentum_compile
-     _use_momentum_compile = enabled
-
-
- def batch_pre_ortho(
-     grads: List[torch.Tensor],
-     momentum_bufs: List[torch.Tensor],
-     momentum: torch.Tensor,
-     nesterov: bool,
- ) -> None:
-     """Batched momentum update on lists of plain tensors.
-
-     Mirrors dion's ``muon_update_pre_orthogonalize``.
-     Inputs must be plain CUDA tensors (not DTensor).
-     Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place.
-
-     When compile is enabled, uses separately compiled functions for
-     nesterov=True/False to avoid graph breaks from the branch.
-     """
-     fn = _batch_momentum_nesterov if nesterov else _batch_momentum
-     if _use_momentum_compile:
-         if nesterov not in _compiled_momentum:
-             _compiled_momentum[nesterov] = torch.compile(fn)
-         fn = _compiled_momentum[nesterov]
-     fn(grads, momentum_bufs, momentum)
-
-
- def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay):
-     """Weight-decay + update on plain tensors.
-
-     Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache
-     lookup per call × 256+ params = massive overhead. The pipeline path uses
-     batched _foreach_* ops instead; this function remains for base() and
-     distributed_muon().
-     """
-     p_data.mul_(1 - lr * weight_decay)
-     p_data.add_(u_data, alpha=-adjusted_lr)
-
-
- def update_p(p, u, lr, adjusted_lr, weight_decay):
-     """Apply weight decay and orthogonalized update to parameter.
-
-     Args:
-         p: Parameter (torch.nn.Parameter or DTensor).
-         u: Orthogonalized update tensor.
-         lr: Base learning rate.
-         adjusted_lr: Size-adjusted learning rate.
-         weight_decay: Weight decay coefficient.
-     """
-     # Unwrap Parameter -> underlying data tensor.
-     p_data = p.data if isinstance(p, torch.nn.Parameter) else p
-     # Unwrap DTensor -> local CUDA tensor for compiled kernel.
-     if isinstance(p_data, DTensor):
-         p_data = p_data._local_tensor
-     u_data = u._local_tensor if isinstance(u, DTensor) else u
-     _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay)
-
-
- def adjust_lr_for_muon(lr, param_shape):
-     """Scale learning rate based on parameter matrix dimensions.
-
-     Args:
-         lr: Base learning rate.
-         param_shape: Shape of the parameter tensor.
-
-     Returns:
-         Adjusted learning rate.
-     """
-     A, B = param_shape[:2]
-     # We adjust the learning rate and weight decay based on the size of the parameter matrix
-     # as described in the paper
-     adjusted_ratio = 0.2 * math.sqrt(max(A, B))
-     adjusted_lr = lr * adjusted_ratio
-     return adjusted_lr
-
-
- def _match_key(parts, key):
-     """Check if key matches as contiguous components in parts.
-
-     Single-component keys (e.g. "experts") match any single component.
-     Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence.
-     """
-     key_parts = key.split(".")
-     key_len = len(key_parts)
-     if key_len == 1:
-         return key in parts
-     return any(parts[i:i + key_len] == key_parts
-                for i in range(len(parts) - key_len + 1))
-
-
- def is_expert_param(name, expert_keys):
-     """Check if a parameter name matches any expert key (component-level)."""
-     if not expert_keys:
-         return False
-     parts = normalize_fqn(name).split(".")
-     return any(_match_key(parts, key) for key in expert_keys)
-
-
- def default_is_muon(name, x, expert_keys=None):
-     normalized = normalize_fqn(name)
-     parts = normalized.split(".")
-     skip_keys = [
-         "embed_tokens",
-         "lm_head",
-         "tok_embeddings",
-         "output",
-         "mhc_attn",
-         "mhc_ffn",
-         "lambda_proj",
-     ]
-     if any(key in parts for key in skip_keys):
-         logger.info(
-             "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
-             normalized, name, x.ndim)
-         return False
-     effective_ndim = x.ndim
-     is_expert = is_expert_param(name, expert_keys)
-     if is_expert:
-         effective_ndim -= 1
-     result = effective_ndim >= 2
-     logger.info(
-         "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s",
-         normalized, name, x.ndim, is_expert, effective_ndim,
-         "Muon" if result else "AdamW")
-     return result
-
-
- def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
-     if is_muon_func is None:
-         is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
-
-     muon_params, muon_names = [], []
-     non_muon_params, non_muon_names = [], []
-
-     for n, p in model.named_parameters():
-         if not p.requires_grad:
-             continue
-         if is_muon_func(n, p):
-             muon_params.append(p)
-             muon_names.append(n)
-         else:
-             non_muon_params.append(p)
-             non_muon_names.append(n)
-
-     logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d",
-                 expert_keys, len(muon_names), len(non_muon_names))
-
-     return [
-         {
-             "params": muon_params,
-             "names": muon_names,
-             "use_muon": True,
-         },
-         {
-             "params": non_muon_params,
-             "use_muon": False,
-         },
-     ]
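
Editor's note: for concreteness, the 0.2 * sqrt(max(A, B)) rule in adjust_lr_for_muon at two typical transformer shapes:

import math

lr = 3e-4
for A, B in ((1024, 1024), (4096, 1024)):
    print((A, B), lr * 0.2 * math.sqrt(max(A, B)))
# (1024, 1024) 0.00192   (0.2 * 32 = 6.4x the base LR)
# (4096, 1024) 0.00384   (0.2 * 64 = 12.8x the base LR)
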
build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py DELETED
@@ -1,206 +0,0 @@
- """CPU offloading for optimizer states.
-
- Manages a pinned CPU memory pool and async CUDA streams to offload
- optimizer state tensors (momentum buffers, Adam moments) to CPU between
- optimizer steps, freeing GPU memory.
-
- All tracked tensors are packed into a single flat pinned CPU buffer
- (per dtype). D2H and H2D copies are performed per-tensor directly
- between individual GPU tensors and their slice of the CPU flat buffer
- — no GPU staging buffer is allocated, so there is **no temporary GPU
- memory spike** during offload or reload.
-
- Individual tensor storages are freed after offload via
- ``untyped_storage().resize_(0)``, preserving tensor identity so
- downstream caches remain valid.
- """
-
- import logging
- from collections import defaultdict
-
- import torch
- from torch.distributed.tensor import DTensor
-
- logger = logging.getLogger(__name__)
-
-
- class CPUOffloadPool:
-     """Pinned CPU memory pool for async optimizer state offloading.
-
-     Tracked tensors are grouped by dtype. Each group gets a single flat
-     pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of
-     the flat buffer) to avoid allocating a GPU staging buffer.
-     """
-
-     def __init__(self):
-         self._managed: list[torch.Tensor] = []
-         self._storage_nbytes: dict[int, int] = {}  # id(t) → bytes
-
-         # Per-dtype group: populated on first offload.
-         # dtype → dict with keys:
-         #   "indices" : list[int]             managed-list indices
-         #   "offsets" : list[tuple[int,int]]  (start, numel) in flat buf
-         #   "total"   : int                   total numel
-         #   "cpu_flat": Tensor                pinned CPU buffer
-         self._groups: dict[torch.dtype, dict] = {}
-
-         self._offload_stream: torch.cuda.Stream | None = None
-         self._device: torch.device | None = None
-         self._initialized: bool = False
-         self._logged: bool = False
-
-     # ------------------------------------------------------------------
-     @staticmethod
-     def _local(t: torch.Tensor) -> torch.Tensor:
-         """Unwrap DTensor to its local CUDA tensor."""
-         return t._local_tensor if isinstance(t, DTensor) else t
-
-     def _ensure_stream(self):
-         if self._offload_stream is None:
-             self._offload_stream = torch.cuda.Stream(device=self._device)
-
-     # ------------------------------------------------------------------
-     def track(self, tensor: torch.Tensor):
-         """Register a GPU tensor for CPU offloading. Idempotent."""
-         tid = id(tensor)
-         if tid in self._storage_nbytes:
-             return
-         local = self._local(tensor)
-         if self._device is None:
-             self._device = local.device
-         storage = local.untyped_storage()
-         # Skip tensors with empty storage (e.g. empty FSDP shards)
-         if storage.size() == 0:
-             return
-         self._storage_nbytes[tid] = storage.size()
-         self._managed.append(tensor)
-
-     # ------------------------------------------------------------------
-     def _init_buffers(self):
-         """Build per-dtype flat buffers on first offload."""
-         # Group managed tensors by dtype.
-         dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list)
-         for idx, t in enumerate(self._managed):
-             local = self._local(t)
-             dtype_map[local.dtype].append((idx, local.numel()))
-
-         total_cpu_bytes = 0
-         for dtype, entries in dtype_map.items():
-             offsets: list[tuple[int, int]] = []
-             indices: list[int] = []
-             off = 0
-             for idx, n in entries:
-                 indices.append(idx)
-                 offsets.append((off, n))
-                 off += n
-             cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
-             self._groups[dtype] = {
-                 "indices": indices,
-                 "offsets": offsets,
-                 "total": off,
-                 "cpu_flat": cpu_flat,
-             }
-             total_cpu_bytes += off * cpu_flat.element_size()
-
-         self._initialized = True
-         logger.info(
-             "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), "
-             "%.2f MB pinned CPU memory",
-             len(self._managed),
-             len(self._groups),
-             total_cpu_bytes / (1024**2),
-         )
-
-     # ------------------------------------------------------------------
-     def offload(self):
-         """Per-tensor async D2H into CPU flat buffer, then free GPU storage."""
-         if not self._managed:
-             return
-         if not self._initialized:
-             self._init_buffers()
-         self._ensure_stream()
-
-         # Offload stream waits for compute to finish.
-         compute_event = torch.cuda.current_stream(self._device).record_event()
-         self._offload_stream.wait_event(compute_event)
-
-         offloaded_bytes = 0
-
-         # Per-tensor D2H copies directly into CPU flat buffer slices.
-         # No GPU staging buffer → no temporary GPU memory spike.
-         with torch.cuda.stream(self._offload_stream):
-             for dtype, grp in self._groups.items():
-                 indices = grp["indices"]
-                 offsets = grp["offsets"]
-                 cpu_flat = grp["cpu_flat"]
-
-                 for i, mgd_idx in enumerate(indices):
-                     local = self._local(self._managed[mgd_idx])
-                     off, n = offsets[i]
-                     cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
-
-                 offloaded_bytes += grp["total"] * cpu_flat.element_size()
-
-         # Wait for all D2H copies to land, then free GPU storage.
-         self._offload_stream.synchronize()
-         for t in self._managed:
-             storage = self._local(t).untyped_storage()
-             if storage.size() != 0:
-                 storage.resize_(0)
-             else:
-                 raise RuntimeError(
-                     f"Tensor storage is already freed (size=0) before offload. "
-                     f"This indicates a double-free or external interference. "
-                     f"Tensor shape: {t.shape}, dtype: {t.dtype}"
-                 )
-
-         if not self._logged:
-             logger.info(
-                 "[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
-                 offloaded_bytes / (1024**2),
-             )
-
-     # ------------------------------------------------------------------
-     def reload(self):
-         """Per-tensor H2D from CPU flat buffer on the default stream.
-
-         Runs on the current (default) CUDA stream to avoid stream
-         interaction issues with the parallel Muon pipeline. Since
-         pinned CPU memory is the source, the copies overlap with
-         GPU idle time between steps.
-         """
-         if not self._managed or not self._initialized:
-             return
-
-         reloaded_bytes = 0
-
-         # Re-allocate all GPU storages first.
-         for t in self._managed:
-             local = self._local(t)
-             storage = local.untyped_storage()
-             if storage.size() != 0:
-                 raise RuntimeError(
-                     f"Storage should have been freed (size=0) before reload, "
-                     f"but got size={storage.size()}. "
-                     f"Tensor shape: {t.shape}, dtype: {t.dtype}"
-                 )
-             storage.resize_(self._storage_nbytes[id(t)])
-
-         # Per-tensor H2D copies from CPU flat buffer slices.
- # Per-tensor H2D copies from CPU flat buffer slices.
190
- # non_blocking=True with pinned source allows DMA overlap.
191
- for dtype, grp in self._groups.items():
192
- indices = grp["indices"]
193
- offsets = grp["offsets"]
194
- cpu_flat = grp["cpu_flat"]
195
-
196
- for i, mgd_idx in enumerate(indices):
197
- local = self._local(self._managed[mgd_idx])
198
- off, n = offsets[i]
199
- local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
200
-
201
- reloaded_bytes += grp["total"] * cpu_flat.element_size()
202
-
203
- if not self._logged:
204
- logger.info(
205
- "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
206
- )
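
The free-then-restore trick the module docstring describes can be exercised in isolation: copy into a pinned CPU buffer, shrink the GPU storage to zero bytes while keeping the tensor object alive, then resize and copy back. A minimal sketch of that pattern, assuming a CUDA device (`t` stands in for a tracked optimizer state tensor; this snippet is not part of the deleted file):

```python
import torch

t = torch.randn(1024, device="cuda")          # stand-in for a tracked state tensor
nbytes = t.untyped_storage().size()           # remember the storage size in bytes

# D2H into a pinned buffer; non_blocking is effective because the target is pinned.
cpu_buf = torch.empty(t.numel(), dtype=t.dtype, device="cpu", pin_memory=True)
cpu_buf.copy_(t.reshape(-1), non_blocking=True)
torch.cuda.synchronize()                      # make sure the copy has landed

# Free the GPU allocation but keep the Python tensor object valid.
t.untyped_storage().resize_(0)

# Later (reload): re-allocate the storage and copy back from the pinned buffer.
t.untyped_storage().resize_(nbytes)
t.reshape(-1).copy_(cpu_buf, non_blocking=True)
```

Because only the storage is resized, tensor identity is preserved, which is what keeps downstream caches keyed on these tensors valid across offload/reload cycles.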
build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py DELETED
@@ -1,232 +0,0 @@
1
- import torch
2
- import torch.distributed as dist
3
- from torch.distributed import ProcessGroup
4
- from torch.distributed.device_mesh import DeviceMesh
5
- from torch.distributed.tensor import DTensor
6
- from torch.distributed.tensor.placement_types import (Placement, Shard,
7
- _StridedShard)
8
-
9
-
10
- def _is_shard(placement: Placement) -> bool:
11
- """Check if a placement is a shard type (Shard or _StridedShard).
12
-
13
- In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so
14
- ``placement.is_shard()`` returns False for _StridedShard. This helper
15
- handles both old and new hierarchies.
16
- """
17
- return isinstance(placement, (Shard, _StridedShard))
18
-
19
-
20
- def get_slices_of_dtensor(
21
- target: DTensor | torch.Tensor,
22
- local_rank: int,
23
- shard_mesh: DeviceMesh,
24
- shard_placements: tuple[Placement],
25
- ) -> tuple[slice | torch.Tensor, ...]:
26
- """
27
- Get per-dimension indices for a given rank's shard of the target tensor.
28
-
29
- Uses ``Shard.local_shard_size_and_offset`` and
30
- ``_StridedShard.local_shard_size_and_offset`` for correct handling of
31
- both contiguous and strided (non-contiguous) sharding.
32
-
33
- Args:
34
- target (DTensor | torch.Tensor): The target tensor (for its shape).
35
- local_rank (int): The local rank within the shard group.
36
- shard_mesh (DeviceMesh): The shard mesh (only shard dimensions).
37
- shard_placements (tuple[Placement, ...]): The shard placements.
38
-
39
- Returns:
40
- A tuple of indices (one per tensor dim). Each element is either:
41
- - A ``slice`` (for contiguous or unsharded dims)
42
- - A 1-D ``torch.LongTensor`` of indices (for strided sharding)
43
- """
44
-
45
- # find the global rank of the local rank in the shard mesh
46
- rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
47
-
48
- rank_coords = (shard_mesh.mesh == rank).nonzero()
49
-
50
- assert len(rank_coords) == 1
51
- rank_coords = tuple(rank_coords[0].tolist())
52
-
53
- assert len(rank_coords) == len(shard_placements)
54
-
55
- # Track per-shard-dim indices.
56
- # None means "not yet sharded on this dim".
57
- dim_indices: dict[int, torch.Tensor] = {}
58
-
59
- # Caution: Assuming replicate-to-shard of the shard mesh goes with
60
- # left-to-right sharding. This is ensured by the sorting logic of
61
- # construct_shard_mesh function.
62
- for mesh_dim_idx, (rank_coord, placement) in enumerate(
63
- zip(rank_coords, shard_placements)):
64
- assert _is_shard(placement)
65
-
66
- num_chunks = shard_mesh.mesh.shape[mesh_dim_idx]
67
- shard_dim = placement.dim
68
-
69
- # Current effective size on this dim (may already be sub-sharded)
70
- if shard_dim in dim_indices:
71
- curr_size = len(dim_indices[shard_dim])
72
- else:
73
- curr_size = target.size()[shard_dim]
74
-
75
- # Compute indices for this level of sharding
76
- if isinstance(placement, _StridedShard):
77
- _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
78
- placement,
79
- curr_size,
80
- num_chunks,
81
- rank_coord,
82
- return_first_offset=False)
83
- new_indices = torch.tensor(offsets, dtype=torch.long)
84
- else:
85
- shard_size, offset = Shard.local_shard_size_and_offset(
86
- curr_size, num_chunks, rank_coord)
87
- new_indices = torch.arange(offset,
88
- offset + shard_size,
89
- dtype=torch.long)
90
-
91
- # Compose with previous indices on this dim
92
- if shard_dim in dim_indices:
93
- dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]
94
- else:
95
- dim_indices[shard_dim] = new_indices
96
-
97
- # Build result tuple
98
- result: list[slice | torch.Tensor] = []
99
- for d in range(len(target.size())):
100
- if d not in dim_indices:
101
- result.append(slice(None))
102
- else:
103
- indices = dim_indices[d]
104
- # Convert contiguous indices to slice for efficiency
105
- if len(indices) > 0:
106
- start = indices[0].item()
107
- expected = torch.arange(start,
108
- start + len(indices),
109
- dtype=torch.long)
110
- if torch.equal(indices, expected):
111
- result.append(slice(start, start + len(indices)))
112
- else:
113
- result.append(indices)
114
- else:
115
- result.append(slice(0, 0))
116
-
117
- return tuple(result)
118
-
119
-
120
- _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
121
- ProcessGroup]] = dict()
122
-
123
-
124
- def construct_shard_mesh(
125
- placements: tuple[Placement, ...],
126
- mesh: DeviceMesh,
127
- ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
128
- """Construct shard sub-mesh and ProcessGroup for all-to-all communication.
129
-
130
- Given a DTensor's placements and device mesh, extracts the "shard group"
131
- — the set of ranks that together hold all shards of the same replica —
132
- and creates a ProcessGroup for all-to-all among them.
133
-
134
- Steps:
135
- 1. Sort placements: Replicate first, then Shard by (dim, granularity).
136
- 2. Permute the mesh tensor to match the sorted order.
137
- 3. Collapse Replicate dims → list of shard sub-meshes (one per replica).
138
- 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh.
139
-
140
- Example — 8 GPUs, mesh shape (2, 2, 2),
141
- placements ``[Shard(0), Replicate, _StridedShard(0)]``::
142
-
143
- Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)]
144
- Permutation: [1, 2, 0]
145
-
146
- Step 2 — Permute mesh dims by [1, 2, 0]:
147
- Original: Permuted:
148
- [[[0,1],[2,3]], [[[0,4],[1,5]],
149
- [[4,5],[6,7]]] [[2,6],[3,7]]]
150
-
151
- Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
152
- sub-mesh 0 = [[0,4],[1,5]] (replica group 0)
153
- sub-mesh 1 = [[2,6],[3,7]] (replica group 1)
154
- shard_placements = (_StridedShard(0), Shard(0))
155
-
156
- Step 4 — Rank 0 → ProcessGroup([0,1,4,5])
157
- Rank 2 → ProcessGroup([2,3,6,7])
158
-
159
- Returns:
160
- ``(shard_mesh, process_group, shard_placements)``
161
- """
162
- my_rank = dist.get_rank()
163
- assert mesh.mesh.device.type == 'cpu'
164
-
165
- # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
- # overhead of dist.new_group(). The standard path below also handles
168
- # subset calls safely via use_local_synchronization=True, but this
169
- # fast path is still beneficial for the common 1D shard case.
170
- if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
- key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
- if key not in _ranks_to_dist_cache:
173
- _ranks_to_dist_cache[key] = (mesh, mesh.get_group())
174
- return (*_ranks_to_dist_cache[key], tuple(placements))
175
-
176
- mesh_tensor = mesh.mesh.clone()
177
-
178
- # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------
179
- # _StridedShard comes BEFORE regular Shard on the same dim so that
180
- # get_slices_of_dtensor applies the outer sharding first, matching
181
- # DTensor's left-to-right (outer-to-inner) composition order.
182
- def _sort_key(item):
183
- index, placement = item
184
- assert not placement.is_partial(), "Partial placement not supported"
185
- if placement.is_replicate():
186
- return (-1, 0, index)
187
- assert _is_shard(placement), f"Unsupported: {type(placement)}"
188
- split = (-1 / placement.split_factor if isinstance(
189
- placement, _StridedShard) else 0)
190
- return (placement.dim, split, index)
191
-
192
- indexed = sorted(enumerate(placements), key=_sort_key)
193
- perm, sorted_placements = zip(*indexed)
194
-
195
- # -- Step 2: Permute mesh to match sorted placement order. --------------
196
- sorted_mesh = mesh_tensor.permute(perm)
197
-
198
- # -- Step 3: Collapse replicate dims → list of shard sub-meshes. --------
199
- # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4)
200
- num_rep = sum(1 for p in sorted_placements if p.is_replicate())
201
- if num_rep > 0:
202
- if num_rep > 1:
203
- sorted_mesh = sorted_mesh.flatten(0, num_rep - 1)
204
- shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
205
- else:
206
- shard_meshes = [sorted_mesh]
207
- shard_placements = sorted_placements[num_rep:]
208
- assert len(shard_placements) == len(set(shard_placements))
209
-
210
- # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
- # Each rank only creates the group it belongs to, using
212
- # use_local_synchronization=True so that only group members need to
213
- # coordinate. This avoids deadlocks when different PP stages call
214
- # construct_shard_mesh for different parameters.
215
- def _cache_key(t: torch.Tensor) -> tuple:
216
- return (*t.shape, *t.flatten().tolist())
217
-
218
- my_key = None
219
- for sm in shard_meshes:
220
- if (my_rank == sm).any().item():
221
- key = _cache_key(sm)
222
- assert my_key is None, "Rank appears in multiple shard groups"
223
- my_key = key
224
- if key not in _ranks_to_dist_cache:
225
- pg = dist.new_group(sm.flatten().tolist(),
226
- use_local_synchronization=True)
227
- _ranks_to_dist_cache[key] = (
228
- DeviceMesh(device_type="cuda", mesh=sm),
229
- pg,
230
- )
231
-
232
- return (*_ranks_to_dist_cache[my_key], shard_placements)
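
The mesh bookkeeping in Steps 2 and 3 of `construct_shard_mesh` is plain tensor manipulation and can be checked without initializing any process groups. A standalone sketch reproducing the docstring's 8-GPU example (illustrative only; not part of the deleted file):

```python
import torch

mesh = torch.arange(8).reshape(2, 2, 2)       # ranks 0..7, mesh shape (2, 2, 2)

# Step 2: sorted placements [Replicate, _StridedShard(0), Shard(0)] -> perm [1, 2, 0]
permuted = mesh.permute(1, 2, 0)

# Step 3: unbind the replicate dim -> one shard sub-mesh per replica
sub_meshes = torch.unbind(permuted, dim=0)
print(sub_meshes[0].flatten().tolist())       # [0, 4, 1, 5] -> replica group {0, 1, 4, 5}
print(sub_meshes[1].flatten().tolist())       # [2, 6, 3, 7] -> replica group {2, 3, 6, 7}
```

The resulting groups match Step 4 of the docstring: rank 0 joins ProcessGroup({0, 1, 4, 5}) and rank 2 joins ProcessGroup({2, 3, 6, 7}).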
build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py DELETED
@@ -1,122 +0,0 @@
1
- # MIT License
2
- #
3
- # Copyright (c) 2025 Tianyang Lin
4
- #
5
- # Permission is hereby granted, free of charge, to any person obtaining a copy
6
- # of this software and associated documentation files (the "Software"), to deal
7
- # in the Software without restriction, including without limitation the rights
8
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- # copies of the Software, and to permit persons to whom the Software is
10
- # furnished to do so, subject to the following conditions:
11
- #
12
- # The above copyright notice and this permission notice shall be included in all
13
- # copies or substantial portions of the Software.
14
- #
15
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- # SOFTWARE.
22
-
23
- import torch
24
- import triton
25
- import triton.language as tl
26
-
27
-
28
- def get_autotune_config():
29
- return [
30
- triton.Config(
31
- {
32
- 'BLOCK_SIZE_M': blk_m,
33
- 'BLOCK_SIZE_K': blk_k,
34
- 'GROUP_SIZE_M': grp_sz
35
- },
36
- num_stages=n_stages,
37
- num_warps=n_warps) for blk_m in [32, 64, 128]
38
- for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
- for n_warps in [4, 8]
40
- ]
41
-
42
-
43
- @triton.autotune(
44
- configs=get_autotune_config(),
45
- key=['M', 'K'],
46
- restore_value=['y'],
47
- )
48
- @triton.jit
49
- def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
50
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
51
- GROUP_SIZE_M: tl.constexpr):
52
- """
53
- Core JIT kernel of matmul_transpose; computes y = x @ x.T.
55
- The code is adapted from the Triton `matmul` tutorial:
55
- https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
56
- """
57
- pid = tl.program_id(axis=0)
58
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
59
- num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
60
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
61
- group_id = pid // num_pid_in_group
62
- first_pid_m = group_id * GROUP_SIZE_M
63
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
64
- pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
65
- pid_n = (pid % num_pid_in_group) // group_size_m
66
- if pid_m > pid_n:
67
- return
68
-
69
- offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
- offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
71
- offs_k = tl.arange(0, BLOCK_SIZE_K)
72
- # we use a & b ptrs to denote different rows of x.
73
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
- b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
75
-
76
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
77
-
78
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
79
- a = tl.load(a_ptrs,
80
- mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
81
- other=0.0)
82
- b = tl.load(b_ptrs,
83
- mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
84
- other=0.0)
85
- accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
86
- a_ptrs += BLOCK_SIZE_K * stride_xk
87
- b_ptrs += BLOCK_SIZE_K * stride_xk
88
- # use dtype.element_ty to accommodate different input datatypes as in cpp templates
89
- # https://github.com/triton-lang/triton/issues/2252
90
- c = accumulator.to(x.dtype.element_ty)
91
-
92
- offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
- offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
94
- c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
95
- c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
96
- tl.store(c_ptrs, c, mask=c_mask)
97
-
98
- # transpose and copy
99
- if pid_m < pid_n:
100
- ct_ptrs = y + stride_ym * offs_cn[:,
101
- None] + stride_yn * offs_cm[None, :]
102
- ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
103
- tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
104
-
105
-
106
- @torch.library.custom_op("muon::matmul_transpose_assign",
107
- mutates_args=("d_out", ))
108
- def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
109
- """Compute d_out = d_in @ d_in.T using an optimized Triton kernel."""
110
- d_in = d_in.contiguous()
111
- M, K = d_in.shape
112
- grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
113
- M, META['BLOCK_SIZE_M']), )
114
- with torch.cuda.device(d_in.device.index):
115
- mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
116
- d_out.stride(0), d_out.stride(1))
117
-
118
-
119
- @matmul_transpose_assign.register_fake
120
- def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
121
- """FakeTensor impl: d_out is already allocated, mutation is declared."""
122
- pass
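
Since the kernel computes only the upper-triangular blocks and mirrors them (the `pid_m > pid_n` early exit), a correctness check against the dense reference is a useful companion. A minimal sketch, assuming a CUDA device with Triton available and the functions above in scope (shapes and tolerances are illustrative):

```python
import torch

x = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16).contiguous()
y = torch.empty(512, 512, device="cuda", dtype=torch.bfloat16)

matmul_transpose_assign(x, y)                  # Triton kernel: y = x @ x.T
torch.testing.assert_close(y, x @ x.T, rtol=1.6e-2, atol=1e-2)
```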
build/torch210-cxx11-cu130-x86_64-linux/metadata.json DELETED
@@ -1,3 +0,0 @@
1
- {
2
- "python-depends": []
3
- }
build/torch210-cxx11-cu130-x86_64-linux/muon.py DELETED
@@ -1,1068 +0,0 @@
1
- import logging
2
- import types
3
- from collections import defaultdict
4
- from typing import Any
5
-
6
- import torch
7
- import torch.distributed as dist
8
- from torch.distributed.tensor import DTensor, Replicate, Shard
9
- from torch.profiler import record_function
10
-
11
- from .adamw import _placement_cache, _tensor_cache, step_adamw
12
- from .async_utils import run_pipeline
13
- from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
- get_default_muon_param_groups, is_expert_param, update_p)
15
- from .cpu_offload import CPUOffloadPool
16
- from .distributed.utils import (_is_shard, construct_shard_mesh,
17
- get_slices_of_dtensor)
18
- from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
19
- _zeropower_via_newtonschulz5,
20
- zeropower_via_newtonschulz5,
21
- zeropower_via_newtonschulz5_batched)
22
- from .pipeline import muon_chunk_pipeline, prelaunch_first_gather
23
- from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
24
-
25
- logger = logging.getLogger(__name__)
26
-
27
-
28
- def _expand_expert_params(names, params, expert_keys):
29
- """Expand expert params by splitting on dim 0 (expert dimension).
30
-
31
- Params whose name matches any key in ``expert_keys`` are treated as
32
- expert-parallel tensors. Their outermost dimension is the expert
33
- dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D
34
- ``nn.Parameter`` views so that in-place updates propagate back to
35
- the original storage.
36
-
37
- Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` —
38
- if they are expert params, their key must be added to ``expert_keys``.
39
-
40
- The grad must already be set on each expert param (e.g. after momentum).
41
-
42
- For DTensor expert params, placements that shard on dim 0 (expert dim)
43
- are consumed by the split. Non-dim-0 shard placements (e.g. TP) are
44
- preserved: each 2D slice is wrapped as a DTensor on the corresponding
45
- submesh so the parallel pipeline handles the TP communication.
46
- """
47
- expanded_names = []
48
- expanded_params = []
49
-
50
- for n, p in zip(names, params):
51
- is_expert = is_expert_param(n, expert_keys)
52
- is_dtensor = isinstance(p.data, DTensor)
53
-
54
- if is_expert:
55
- if is_dtensor:
56
- logger.debug(
57
- "[expand_expert] %s: expert DTensor, shape=%s, "
58
- "placements=%s, mesh=%s, local_shape=%s", n, p.shape,
59
- p.placements, p.device_mesh.mesh_dim_names,
60
- p.to_local().shape)
61
- else:
62
- logger.debug(
63
- "[expand_expert] %s: expert plain tensor, shape=%s", n,
64
- p.data.shape)
65
-
66
- if not is_expert:
67
- assert p.data.ndim <= 2, (
68
- f"Param {n} has ndim={p.data.ndim} but does not match "
69
- f"expert_keys={expert_keys}. If this is an expert param, "
70
- f"add its key to expert_keys.")
71
- expanded_names.append(n)
72
- expanded_params.append(p)
73
- continue
74
-
75
- g = p.grad
76
- assert g is not None, (
77
- f"Expert param {n} must have grad set before expansion")
78
-
79
- tp_mesh = None
80
- tp_placements_2d = None
81
-
82
- if is_dtensor:
83
- local_data = p.to_local()
84
- local_grad = g.to_local() if isinstance(g, DTensor) else g
85
-
86
- # Find non-dim-0 shard placements (e.g. TP sharding).
87
- # After splitting on dim 0, Shard(k) becomes Shard(k-1).
88
- tp_dim_indices = []
89
- tp_placements_2d = []
90
- for i, pl in enumerate(p.placements):
91
- if _is_shard(pl) and pl.dim != 0:
92
- tp_dim_indices.append(i)
93
- tp_placements_2d.append(Shard(pl.dim - 1))
94
-
95
- if tp_dim_indices:
96
- tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i]
97
- for i in tp_dim_indices)
98
- if len(tp_dim_names) == 1:
99
- tp_mesh = p.device_mesh[tp_dim_names[0]]
100
- else:
101
- tp_mesh = p.device_mesh[tp_dim_names]
102
- else:
103
- local_data = p.data
104
- local_grad = g
105
-
106
- # Expand: split dim 0, reshape each slice to 2D.
107
- num_local_experts = local_data.shape[0]
108
- for i in range(num_local_experts):
109
- slice_data = local_data[i]
110
- slice_grad = local_grad[i]
111
-
112
- if tp_mesh is not None:
113
- # Wrap as DTensor on TP submesh so the pipeline handles
114
- # TP communication (gather/scatter across TP ranks).
115
- dt_data = DTensor.from_local(slice_data,
116
- device_mesh=tp_mesh,
117
- placements=tp_placements_2d)
118
- dt_grad = DTensor.from_local(slice_grad,
119
- device_mesh=tp_mesh,
120
- placements=tp_placements_2d)
121
- expert_param = torch.nn.Parameter(dt_data, requires_grad=False)
122
- expert_param.grad = dt_grad
123
- else:
124
- expert_param = torch.nn.Parameter(slice_data,
125
- requires_grad=False)
126
- expert_param.grad = slice_grad
127
-
128
- expanded_names.append(f"{n}[{i}]")
129
- expanded_params.append(expert_param)
130
-
131
- p.grad = None # allow expert grad storage to be freed after pipeline
132
-
133
- return expanded_names, expanded_params
134
-
135
-
136
- class Muon(torch.optim.Optimizer):
137
- """
138
- Muon - MomentUm Orthogonalized by Newton-schulz
139
-
140
- Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
141
- processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
142
- matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
143
- the advantage that it can be stably run in bfloat16 on the GPU.
144
-
145
- Some warnings:
146
- - We believe this optimizer is unlikely to work well for training with small batch size.
147
- - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
148
-
149
- Arguments:
150
- params: The parameter groups to be optimized. Each group must set the
152
- 'use_muon' key (see ``get_default_muon_param_groups``).
152
- lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
153
- momentum: The momentum used by the internal SGD. (0.95 is a good default)
154
- nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
155
- ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
156
- weight_decay: The weight decay for Muon and AdamW.
157
- Parameters that are {0, 1}-D or detected as the embed or lm_head are optimized by AdamW instead.
159
- adamw_betas: The betas for the internal AdamW.
160
- adamw_eps: The epsilon for the internal AdamW.
161
- none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
162
- debug: Whether to print debug information.
163
- clip_config : Configuration for QK clipping. Expected keys:
164
- - "q_indices" (list[int]): Indices of query heads to consider.
165
- - "k_indices" (list[int]): Indices of key heads to consider.
166
- - "head_dim" (int): Dimensionality of each attention head.
167
- - "threshold" (float): Threshold value; heads whose QK logits exceed
168
- this value will be scaled down.
169
- Default is:
170
- {
171
- "q_indices": [],
172
- "k_indices": [],
173
- "head_dim": 128,
174
- "threshold": 100
175
- }
176
- warmup_step : How many all2all gather/compute operations are launched in advance
177
- before the corresponding all2all scatter steps begin.
178
- A higher warmup_step increases memory usage but can improve
179
- performance by overlapping communication.
180
- Parallel muon only.
181
- chunk_size : Batch size of parameters to process in each
182
- all2all gather/compute/scatter step.
183
- Uses shard_ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
184
- use_distributed_muon: Use distributed muon by Liu et al. (2024).
185
- For testing purposes only.
186
- expert_keys: List of strings to identify expert-parallel parameters.
187
- If any key appears in a parameter's name, its outermost
188
- dimension is treated as the expert dimension and expanded
189
- into per-expert 2D params for Muon. For example,
190
- ``expert_keys=["experts"]`` matches any param whose name
191
- contains "experts". 3D+ params not matched by any key
192
- will raise an error.
193
- """
194
-
195
- def __init__(self,
196
- params,
197
- lr=1e-3,
198
- momentum=0.95,
199
- nesterov=True,
200
- ns_steps=5,
201
- weight_decay=0.1,
202
- adamw_betas=(0.9, 0.95),
203
- adamw_eps=1e-8,
204
- none_grad=True,
205
- debug=False,
206
- clip_config=None,
207
- warmup_step=5,
208
- chunk_size=-1,
209
- use_distributed_muon=False,
210
- expert_keys=None):
211
- defaults = dict(
212
- lr=lr,
213
- weight_decay=weight_decay,
214
- momentum=momentum,
215
- nesterov=nesterov,
216
- ns_steps=ns_steps,
217
- adamw_betas=adamw_betas,
218
- adamw_eps=adamw_eps,
219
- none_grad=none_grad,
220
- use_muon=True,
221
- )
222
- error_message = "The key 'use_muon' is not set in parameter group {idx}. Without this key, all parameters in the group are assumed to use Muon optimization, which may lead to unexpected behavior."
223
- instruction_code = "\n\nPlease follow this code snippet:\n```\noptimizer = get_kernel('motif-technologies/optimizer')\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)\n```"
224
-
225
- if isinstance(params, types.GeneratorType):
226
- raise ValueError(error_message.format(idx=0) + instruction_code)
227
- for _idx, param_group in enumerate(params):
228
- if param_group.get("use_muon", None) is None:
229
- raise ValueError(
230
- error_message.format(idx=_idx) + instruction_code)
231
- super().__init__(params, defaults)
232
-
233
- self.debug = debug
234
- self.clip_config = clip_config if clip_config is not None else {
235
- "q_indices": [],
236
- "k_indices": [],
237
- "head_dim": 128,
238
- "threshold": 100,
239
- }
240
- self.warmup_step = warmup_step
241
- self.chunk_size = chunk_size
242
- self.use_distributed_muon = use_distributed_muon
243
- self.expert_keys = expert_keys
244
- self.cpu_offload = False
245
- self._cpu_offload_pool: CPUOffloadPool | None = None
246
- self._offload_initialized = False
247
- self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
- self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
249
-
250
- def _calc_flops(self, G, steps):
251
- assert len(G.shape) == 2
252
- M, N = G.shape
253
- if M > N:
254
- M, N = N, M
255
-
256
- return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
257
-
258
- def get_shard_mesh(self, p):
259
- """
260
- Get the shard mesh for a parameter p on the given rank.
261
- """
262
- assert isinstance(
263
- p, DTensor), "Parallel Muon only supports DTensor parameters."
264
-
265
- shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
266
- p.placements, p.device_mesh)
267
-
268
- return shard_mesh, shard_pg, shard_placements
269
-
270
- def init_state_and_assign_params(self, names, params, group, qk_logits):
271
- param_to_state = {}
272
- param_to_flops = {}
273
-
274
- total_flops = 0
275
- for p in params:
276
- g = p.grad
277
- if g is None:
278
- continue
279
- assert g.ndim == 2, "Muon only supports 2D parameters."
280
-
281
- flops = self._calc_flops(g, group["ns_steps"])
282
- param_to_flops[id(p)] = flops
283
- total_flops += flops
284
-
285
- if self.debug:
286
- logger.debug("Total TFLOPs for Muon: %.2f TFLOPs",
287
- total_flops / 1e12)
288
-
289
- paired = list(zip(names, params))
290
-
291
- paired_sorted = sorted(paired,
292
- key=lambda x: param_to_flops[id(x[1])],
293
- reverse=True)
294
-
295
- names_sorted, params_sorted = zip(*paired_sorted)
296
- ordered_names = list(names_sorted)
297
- ordered_params = list(params_sorted)
298
-
299
- round_robin = 0
300
- mesh = ordered_params[0].device_mesh
301
- placements = ordered_params[0].placements
302
-
303
- shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
304
- ordered_params[0])
305
- shard_mesh_flattened = shard_mesh.mesh.flatten()
306
- num_ranks = dist.get_world_size(group=shard_pg)
307
-
308
- for n, p in zip(ordered_names, ordered_params):
309
- if mesh != p.device_mesh:
310
- raise ValueError("All parameters must be on the same mesh.")
311
- if placements != p.placements:
312
- raise ValueError("All parameters must have same placements.")
313
-
314
- worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
315
- round_robin = (round_robin + 1) % len(shard_mesh_flattened)
316
- qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
317
-
318
- # Precompute per-rank indices and numels for all-to-all.
319
- rank_indices: dict[int, tuple] = {}
320
- rank_numels: dict[int, int] = {}
321
- for r in range(num_ranks):
322
- indices = get_slices_of_dtensor(p, r, shard_mesh,
323
- shard_placements)
324
- rank_indices[r] = indices
325
- numel = 1
326
- for idx, dim_size in zip(indices, p.shape):
327
- if isinstance(idx, slice):
328
- start, stop, step = idx.indices(dim_size)
329
- numel *= max(0, (stop - start + (step - 1)) // step)
330
- else:
331
- numel *= len(idx)
332
- rank_numels[r] = numel
333
-
334
- param_to_state[id(p)] = _muon_state(
335
- worker_rank=worker_rank,
336
- process_group=shard_pg,
337
- rank_indices=rank_indices,
338
- rank_numels=rank_numels,
339
- name=n,
340
- qk_clip_state=qk_clip_state,
341
- )
342
-
343
- return param_to_state, ordered_params
344
-
345
- def base(self, names, params, group, lr, weight_decay, qk_logits):
346
- # Momentum is already applied by _step_muon before this method.
347
- for n, p in zip(names, params):
348
- g = p.grad
349
- if g is None:
350
- continue
351
-
352
- u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
353
- steps=group["ns_steps"])
354
-
355
- adjusted_lr = adjust_lr_for_muon(lr, p.shape)
356
- update_p(p, u, lr, adjusted_lr, weight_decay)
357
-
358
- qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
359
-
360
- scales_full = compute_scales(
361
- p, qk_clip_state) if qk_clip_state is not None else None
362
- if scales_full is not None:
363
- qk_clip(p, scales_full, qk_clip_state)
364
-
365
- def distributed_muon(
366
- self,
367
- names: list[str],
368
- params: list[torch.nn.Parameter],
369
- group: dict[str, Any],
370
- lr: float,
371
- weight_decay: float,
372
- qk_logits: list[torch.Tensor | DTensor] | None,
373
- ):
374
- """Batched Distributed Muon — for testing/correctness verification only.
375
-
376
- Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
377
- the full grad, then slices back to local shards. This is simpler but
378
- slower than the parallel pipeline (all2all) path, so it serves as a
379
- reference implementation for verifying correctness.
380
- """
381
- with record_function("distributed_muon"):
382
- # Momentum is already applied by _step_muon before this method.
383
- ns_steps = group["ns_steps"]
384
-
385
- # Separate plain tensors (no communication) from DTensors.
386
- plain_names, plain_params = [], []
387
- dtensor_names, dtensor_params = [], []
388
- for n, p in zip(names, params):
389
- if p.grad is None:
390
- continue
391
- if isinstance(p.data, DTensor):
392
- dtensor_names.append(n)
393
- dtensor_params.append(p)
394
- else:
395
- plain_names.append(n)
396
- plain_params.append(p)
397
-
398
- # Process plain tensors per-param (no communication).
399
- for n, p in zip(plain_names, plain_params):
400
- u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
401
- steps=ns_steps)
402
- adjusted_lr = adjust_lr_for_muon(lr, p.shape)
403
- update_p(p, u, lr, adjusted_lr, weight_decay)
404
-
405
- qk_clip_state = get_qk_clip_info(self.clip_config, n,
406
- qk_logits)
407
- scales_full = compute_scales(
408
- p, qk_clip_state) if qk_clip_state is not None else None
409
- if scales_full is not None:
410
- qk_clip(p, scales_full, qk_clip_state)
411
-
412
- if not dtensor_params:
413
- return
414
-
415
- # Group DTensors by (placements, mesh) for batched all-gather.
416
- placement_groups: dict[tuple,
417
- tuple[list,
418
- list]] = defaultdict(lambda: ([], []))
419
- for n, p in zip(dtensor_names, dtensor_params):
420
- key = (p.placements, p.device_mesh)
421
- placement_groups[key][0].append(n)
422
- placement_groups[key][1].append(p)
423
-
424
- logger.info(
425
- "distributed_muon: %d placement groups, %d total dtensors",
426
- len(placement_groups), len(dtensor_params))
427
-
428
- for (placements, mesh), (grp_names,
429
- grp_params) in placement_groups.items():
430
- shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
431
- placements, mesh)
432
- rank = dist.get_rank(shard_pg)
433
- world_size = dist.get_world_size(shard_pg)
434
-
435
- logger.info(" group: %d params, placements=%s, world_size=%d",
436
- len(grp_params), placements, world_size)
437
-
438
- # Separate params that can be batched (all shard dims evenly
439
- # divisible) from those needing per-param full_tensor
440
- # (e.g. MoE gate weights with fewer rows than shard ranks).
441
- # all_gather_into_tensor requires equal buffer sizes across
442
- # ranks, so uneven splits must use DTensor full_tensor().
443
- batch_names, batch_params = [], []
444
- single_names, single_params = [], []
445
- for n, p in zip(grp_names, grp_params):
446
- even = all(p.shape[pl.dim] %
447
- shard_mesh.mesh.shape[dim_idx] == 0
448
- for dim_idx, pl in enumerate(shard_placements))
449
- if even:
450
- batch_names.append(n)
451
- batch_params.append(p)
452
- else:
453
- single_names.append(n)
454
- single_params.append(p)
455
-
456
- # Process uneven-split params per-param via full_tensor().
457
- for n, p in zip(single_names, single_params):
458
- with record_function("distributed_muon::newton_schulz"):
459
- g_full = p.grad.full_tensor().to(COMM_DTYPE)
460
- u_full = _zeropower_via_newtonschulz5(g_full,
461
- steps=ns_steps)
462
- del g_full
463
- with record_function("distributed_muon::update"):
464
- adjusted_lr = adjust_lr_for_muon(lr, p.shape)
465
- p._local_tensor.mul_(1 - lr * weight_decay)
466
- local_indices = get_slices_of_dtensor(
467
- p, rank, shard_mesh, shard_placements)
468
- u_local = u_full[local_indices]
469
- p._local_tensor.add_(u_local, alpha=-adjusted_lr)
470
- del u_full
471
-
472
- qk_clip_state = get_qk_clip_info(
473
- self.clip_config, n, qk_logits)
474
- scales_full = compute_scales(
475
- p, qk_clip_state
476
- ) if qk_clip_state is not None else None
477
- if scales_full is not None:
478
- ratio = p.shape[0] // scales_full.shape[0]
479
- idx0 = local_indices[0]
480
- if isinstance(idx0, slice):
481
- start = idx0.start or 0
482
- idx0 = torch.arange(start,
483
- idx0.stop,
484
- device=scales_full.device)
485
- row_scales = scales_full[idx0 // ratio]
486
- p._local_tensor.mul_(row_scales.view(-1, 1))
487
-
488
- if not batch_params:
489
- continue
490
-
491
- logger.info(" batched=%d, single=%d", len(batch_params),
492
- len(single_params))
493
-
494
- # Concat all local grad shards into a single flat buffer.
495
- with record_function("distributed_muon::gather"):
496
- grad_locals = [
497
- p.grad.to_local().to(COMM_DTYPE).flatten()
498
- for p in batch_params
499
- ]
500
- numels = [g.numel() for g in grad_locals]
501
- grad_concat = torch.cat(grad_locals)
502
- del grad_locals
503
-
504
- # Single all-gather (replaces N separate full_tensor).
505
- grad_gathered = torch.empty(
506
- grad_concat.numel() * world_size,
507
- dtype=COMM_DTYPE,
508
- device="cuda",
509
- )
510
- dist.all_gather_into_tensor(grad_gathered,
511
- grad_concat,
512
- group=shard_pg)
513
-
514
- total_numel = grad_concat.numel()
515
- del grad_concat
516
-
517
- # Precompute per-param offsets within the concat buffer.
518
- offsets = []
519
- off = 0
520
- for ne in numels:
521
- offsets.append(off)
522
- off += ne
523
-
524
- # Per-param: reconstruct full grad → NS → local update.
525
- for i, (n, p) in enumerate(zip(batch_names, batch_params)):
526
- with record_function("distributed_muon::newton_schulz"):
527
- g_full = torch.empty(p.shape,
528
- dtype=COMM_DTYPE,
529
- device="cuda")
530
- for r in range(world_size):
531
- r_start = r * total_numel + offsets[i]
532
- shard = grad_gathered[r_start:r_start + numels[i]]
533
- indices = get_slices_of_dtensor(
534
- p, r, shard_mesh, shard_placements)
535
- g_full[indices] = shard.reshape(
536
- g_full[indices].shape)
537
-
538
- u_full = _zeropower_via_newtonschulz5(g_full,
539
- steps=ns_steps)
540
- del g_full
541
-
542
- with record_function("distributed_muon::update"):
543
- adjusted_lr = adjust_lr_for_muon(lr, p.shape)
544
- p._local_tensor.mul_(1 - lr * weight_decay)
545
- local_indices = get_slices_of_dtensor(
546
- p, rank, shard_mesh, shard_placements)
547
- u_local = u_full[local_indices]
548
- p._local_tensor.add_(u_local, alpha=-adjusted_lr)
549
- del u_full
550
-
551
- qk_clip_state = get_qk_clip_info(
552
- self.clip_config, n, qk_logits)
553
- scales_full = compute_scales(
554
- p, qk_clip_state
555
- ) if qk_clip_state is not None else None
556
- if scales_full is not None:
557
- ratio = p.shape[0] // scales_full.shape[0]
558
- idx0 = local_indices[0]
559
- if isinstance(idx0, slice):
560
- start = idx0.start or 0
561
- idx0 = torch.arange(start,
562
- idx0.stop,
563
- device=scales_full.device)
564
- row_scales = scales_full[idx0 // ratio]
565
- p._local_tensor.mul_(row_scales.view(-1, 1))
566
-
567
- def _setup_parallel(self, names, params, group, qk_logits):
568
- """Compute (or retrieve cached) parallel pipeline metadata.
569
-
570
- Returns:
571
- (ordered_params, param_to_state, rank, chunk_size)
572
- """
573
- cache_key = tuple(names)
574
-
575
- if cache_key not in self._parallel_cache:
576
- # First call: compute metadata and populate cache.
577
- param_to_state, ordered_params = self.init_state_and_assign_params(
578
- names, params, group, qk_logits)
579
-
580
- shard_pg = param_to_state[id(ordered_params[0])].process_group
581
- rank = dist.get_rank(group=shard_pg)
582
-
583
- if self.chunk_size == -1:
584
- shard_ranks = dist.get_world_size(shard_pg)
585
- chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
586
- elif self.chunk_size > 0:
587
- chunk_size = self.chunk_size
588
- else:
589
- raise ValueError(
590
- "chunk_size must be -1 or a positive integer.")
591
-
592
- ordered_names = [
593
- param_to_state[id(p)].name for p in ordered_params
594
- ]
595
- name_to_state = {
596
- param_to_state[id(p)].name: param_to_state[id(p)]
597
- for p in ordered_params
598
- }
599
- self._parallel_cache[cache_key] = {
600
- 'ordered_names': ordered_names,
601
- 'name_to_state': name_to_state,
602
- 'rank': rank,
603
- 'chunk_size': chunk_size,
604
- }
605
- else:
606
- # Cached path: rebuild param_to_state with current id(p) keys.
607
- cache = self._parallel_cache[cache_key]
608
- rank = cache['rank']
609
- chunk_size = cache['chunk_size']
610
-
611
- name_to_param = dict(zip(names, params))
612
- ordered_params = [name_to_param[n] for n in cache['ordered_names']]
613
-
614
- param_to_state = {}
615
- for p, n in zip(ordered_params, cache['ordered_names']):
616
- cached_state = cache['name_to_state'][n]
617
- param_to_state[id(p)] = _muon_state(
618
- worker_rank=cached_state.worker_rank,
619
- process_group=cached_state.process_group,
620
- rank_indices=cached_state.rank_indices,
621
- rank_numels=cached_state.rank_numels,
622
- name=n,
623
- qk_clip_state=get_qk_clip_info(self.clip_config, n,
624
- qk_logits),
625
- )
626
-
627
- return ordered_params, param_to_state, rank, chunk_size
628
-
629
- def parallel(self,
630
- names,
631
- params,
632
- group,
633
- lr,
634
- weight_decay,
635
- qk_logits,
636
- prelaunch_gather=None):
637
- """
638
- Perform a parallel optimization step using Muon.
639
-
640
- Parameters are chunked and each chunk is processed by a
641
- :func:`muon_chunk_pipeline` generator. :func:`run_pipeline`
642
- interleaves multiple chunks so that communication and computation
643
- overlap across chunks (the same overlap previously achieved by the
644
- warmup + main-loop index scheduling).
645
-
646
- If ``prelaunch_gather`` is provided, it is passed to the first
647
- chunk's generator to skip re-launching the already in-flight
648
- A2A gather.
649
- """
650
-
651
- # Momentum is already applied by _step_muon before this method.
652
-
653
- ordered_params, param_to_state, rank, chunk_size = (
654
- self._setup_parallel(names, params, group, qk_logits))
655
-
656
- def pipelines():
657
- first = True
658
- for start in range(0, len(ordered_params), chunk_size):
659
- chunk = ordered_params[start:start + chunk_size]
660
- if chunk:
661
- kwargs = dict(
662
- params=chunk,
663
- param_to_state=param_to_state,
664
- rank=rank,
665
- ns_steps=group["ns_steps"],
666
- lr=lr,
667
- weight_decay=weight_decay,
668
- none_grad=group["none_grad"],
669
- )
670
- if first and prelaunch_gather is not None:
671
- kwargs['prelaunch_gather'] = prelaunch_gather
672
- first = False
673
- yield muon_chunk_pipeline(**kwargs)
674
-
675
- with record_function("muon::pipeline"):
676
- run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
677
-
678
- def _step_muon(self, group, qk_logits=None):
679
- params = group["params"]
680
- lr = group["lr"]
681
- weight_decay = group["weight_decay"]
682
- momentum = group["momentum"]
683
- names = group["names"]
684
-
685
- # Apply momentum to all params before routing/expansion.
686
- # Batched using _foreach_* ops (compiled, fullgraph=True).
687
- with record_function("muon::momentum"):
688
- active_params = [p for p in params if p.grad is not None]
689
- if active_params:
690
- # Ensure momentum buffers exist (avoid zeros_like when already present).
691
- for p in active_params:
692
- if "momentum_buffer" not in self.state[p]:
693
- self.state[p]["momentum_buffer"] = torch.zeros_like(
694
- p.grad)
695
-
696
- # Extract local tensors for compiled batch function.
697
- local_grads = [
698
- p.grad._local_tensor
699
- if isinstance(p.grad, DTensor) else p.grad
700
- for p in active_params
701
- ]
702
- local_bufs = [
703
- self.state[p]["momentum_buffer"]._local_tensor
704
- if isinstance(self.state[p]["momentum_buffer"], DTensor)
705
- else self.state[p]["momentum_buffer"]
706
- for p in active_params
707
- ]
708
-
709
- # Wrap momentum as tensor for torch.compile.
710
- batch_pre_ortho(local_grads, local_bufs,
711
- torch.tensor(momentum), group["nesterov"])
712
-
713
- # For non-nesterov, the result is the momentum buffer.
714
- if not group["nesterov"]:
715
- for p in active_params:
716
- p.grad = self.state[p]["momentum_buffer"]
717
-
718
- # Identify batched experts for deferred NS.
719
- # Detection is cheap (condition checks only); actual NS compute is
720
- # deferred so it can overlap with the first chunk's A2A gather.
721
- deferred_expert_work = []
722
- if self.expert_keys:
723
- batched_expert_indices = []
724
- for i, (n, p) in enumerate(zip(names, params)):
725
- if not (is_expert_param(n, self.expert_keys)
726
- and p.grad is not None):
727
- continue
728
- # Eligible: plain tensor, or DTensor with no non-dim-0 shards.
729
- if isinstance(p.data, DTensor):
730
- has_tp = any(
731
- _is_shard(pl) and pl.dim != 0 for pl in p.placements)
732
- if has_tp:
733
- continue
734
- batched_expert_indices.append(i)
735
-
736
- if batched_expert_indices:
737
- # Save refs for deferred NS; free grads from param list.
738
- for i in batched_expert_indices:
739
- p = params[i]
740
- g = p.grad
741
- local_g = (g._local_tensor
742
- if isinstance(g, DTensor) else g)
743
- local_data = (p.data._local_tensor if isinstance(
744
- p.data, DTensor) else p.data)
745
- deferred_expert_work.append((local_data, local_g))
746
- p.grad = None
747
-
748
- # Remove batched experts from lists before expansion.
749
- keep = sorted(
750
- set(range(len(params))) - set(batched_expert_indices))
751
- names = [names[i] for i in keep]
752
- params = [params[i] for i in keep]
753
-
754
- def _run_deferred_expert_ns():
755
- """Execute deferred batched expert NS."""
756
- if not deferred_expert_work:
757
- return
758
- with record_function("muon::batched_expert_ns"):
759
- ns_steps = group["ns_steps"]
760
- for local_data, local_g in deferred_expert_work:
761
- u = zeropower_via_newtonschulz5_batched(
762
- local_g.to(COMM_DTYPE), steps=ns_steps)
763
- adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
764
- local_data.mul_(1 - lr * weight_decay)
765
- local_data.add_(u, alpha=-adjusted_lr)
766
-
767
- # Expand expert params by splitting on dim 0.
768
- logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
769
- len(params), self.expert_keys)
770
- if self.expert_keys:
771
- cache_key = tuple(id(p) for p in params)
772
- cache = self._expert_expand_cache.get(cache_key)
773
-
774
- if cache is None:
775
- # Cold path: full expansion + build cache metadata.
776
- exp_names, exp_params = _expand_expert_params(
777
- names, params, self.expert_keys)
778
-
779
- # Build per-expert-group info for hot-path grad updates.
780
- grad_info = []
781
- exp_idx = 0
782
- for orig_idx, (n, p) in enumerate(zip(names, params)):
783
- if not is_expert_param(n, self.expert_keys):
784
- exp_idx += 1
785
- continue
786
-
787
- is_dt = isinstance(p.data, DTensor)
788
- num_experts = (p.to_local() if is_dt else p.data).shape[0]
789
-
790
- # Detect TP mesh from the first expanded expert param.
791
- tp_mesh = None
792
- tp_pls = None
793
- sample = exp_params[exp_idx]
794
- if isinstance(sample.data, DTensor):
795
- tp_mesh = sample.data.device_mesh
796
- tp_pls = list(sample.data.placements)
797
-
798
- grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
799
- tp_mesh, tp_pls))
800
- exp_idx += num_experts
801
-
802
- self._expert_expand_cache[cache_key] = {
803
- 'names': exp_names,
804
- 'params': exp_params,
805
- 'grad_info': grad_info,
806
- }
807
- names, params = exp_names, exp_params
808
- else:
809
- # Hot path: reuse cached params, only update expert grads.
810
- for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
811
- tp_pls) in cache['grad_info']:
812
- p = params[orig_idx]
813
- g = p.grad
814
- local_grad = (g.to_local()
815
- if is_dt and isinstance(g, DTensor) else g)
816
- for i in range(num_experts):
817
- expert_p = cache['params'][exp_start + i]
818
- sg = local_grad[i]
819
- if tp_mesh is not None:
820
- expert_p.grad = DTensor.from_local(
821
- sg, device_mesh=tp_mesh, placements=tp_pls)
822
- else:
823
- expert_p.grad = sg
824
- p.grad = None
825
-
826
- names = cache['names']
827
- params = cache['params']
828
- else:
829
- names, params = _expand_expert_params(names, params,
830
- self.expert_keys)
831
- logger.debug("[_step_muon] after expand: %d params", len(params))
832
-
833
- param_dtensors = []
834
- name_dtensors = []
835
-
836
- param_tensors = []
837
- name_tensors = []
838
-
839
- # distributed_muon is a reference implementation for testing only.
840
- # The parallel pipeline (all2all) path below is the production path.
841
- if self.use_distributed_muon:
842
- _run_deferred_expert_ns()
843
- self.distributed_muon(names=names,
844
- params=params,
845
- group=group,
846
- lr=lr,
847
- weight_decay=weight_decay,
848
- qk_logits=qk_logits)
849
- return
850
-
851
- for n, p in zip(names, params):
852
- if p is None or p.grad is None:
853
- continue
854
- if isinstance(p.data, DTensor):
855
- if all(
856
- isinstance(placement, Replicate)
857
- for placement in p.placements):
858
- logger.debug(
859
- "[route] %s → base (DTensor all-Replicate), "
860
- "shape=%s, placements=%s", n, p.shape, p.placements)
861
- param_tensors.append(p)
862
- name_tensors.append(n)
863
- else:
864
- logger.debug(
865
- "[route] %s → parallel (DTensor), shape=%s, "
866
- "placements=%s, mesh=%s", n, p.shape, p.placements,
867
- p.device_mesh.mesh_dim_names)
868
- param_dtensors.append(p)
869
- name_dtensors.append(n)
870
- elif isinstance(p.data, torch.Tensor):
871
- logger.debug("[route] %s → base (plain tensor), shape=%s", n,
872
- p.data.shape)
873
- param_tensors.append(p)
874
- name_tensors.append(n)
875
- else:
876
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
877
-
878
- logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, "
879
- f"{len(param_tensors)} Tensors → base")
880
-
881
- def group_dtensors(dtensors, names):
882
- # To support different placements, we group parameters by placements
883
- # and run parallel Muon on each group.
884
-
885
- placement_to_params = defaultdict(lambda: ([], []))
886
-
887
- assert len(dtensors) == len(names)
888
- for p, n in zip(dtensors, names):
889
- placement_to_params[tuple([p.placements,
890
- p.device_mesh])][0].append(n)
891
- placement_to_params[tuple([p.placements,
892
- p.device_mesh])][1].append(p)
893
- return placement_to_params
894
-
895
- if len(param_dtensors) > 0:
896
- if not dist.is_initialized():
897
- raise RuntimeError(
898
- "Parallel Muon requires torch.distributed to be initialized."
899
- )
900
-
901
- dtensor_group = group_dtensors(param_dtensors, name_dtensors)
902
-
903
- # Pre-launch the first chunk's A2A gather so that the NCCL
904
- # communication overlaps with the (deferred) batched expert NS
905
- # compute on the default CUDA stream.
906
- prelaunch = None
907
- if deferred_expert_work:
908
- first_names, first_params = next(iter(dtensor_group.values()))
909
- ordered, pts, rnk, csz = self._setup_parallel(
910
- first_names, first_params, group, qk_logits)
911
- first_chunk = ordered[:csz]
912
- if first_chunk:
913
- prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
914
- group["none_grad"])
915
-
916
- _run_deferred_expert_ns()
917
-
918
-             first_group = True
-             for _, (names, params) in dtensor_group.items():
-                 pg = prelaunch if first_group else None
-                 first_group = False
-                 self.parallel(
-                     names,
-                     params,
-                     group,
-                     lr=lr,
-                     weight_decay=weight_decay,
-                     qk_logits=qk_logits,
-                     prelaunch_gather=pg,
-                 )
-         else:
-             _run_deferred_expert_ns()
-
-         if len(param_tensors) > 0:
-             self.base(
-                 name_tensors,
-                 param_tensors,
-                 group,
-                 lr=lr,
-                 weight_decay=weight_decay,
-                 qk_logits=qk_logits,
-             )
-
-     def _register_states_for_offload(self):
-         """Register all optimizer state tensors with the CPU offload pool.
-
-         Called once after the first step when states have been lazily created.
-         Offloads all param states (momentum buffers for Muon, moment1/moment2
-         for AdamW) to free GPU memory between steps.
-         """
-         pool = self._cpu_offload_pool
-         tracked = 0
-         for group in self.param_groups:
-             for p in group["params"]:
-                 if p not in self.state:
-                     continue
-                 state = self.state[p]
-                 if group.get("use_muon", False):
-                     if "momentum_buffer" in state:
-                         pool.track(state["momentum_buffer"])
-                         tracked += 1
-                 else:
-                     if "moment1" in state:
-                         pool.track(state["moment1"])
-                     if "moment2" in state:
-                         pool.track(state["moment2"])
-                     tracked += 1
-         logger.info("[CPUOffload] Registered %d param states for offload",
-                     tracked)
-
-     @torch.no_grad
-     def step(self, closure=None, qk_logits=None):
-         """Perform a single optimization step.
-
-         Args:
-             closure (Callable, optional): A closure that reevaluates the model
-                 and returns the loss.
-             qk_logits (dict[int, Tensor], optional): A dictionary mapping layer
-                 indices to 1D tensors of shape (num_heads,), representing the
-                 maximum QK logits across all tokens, computed as
-                 (1 / sqrt(head_dim)) * (Q @ K^T).
-         """
-         loss = None
-         if closure is not None:
-             with torch.enable_grad():
-                 loss = closure()
-
-         # H2D: reload optimizer states from CPU before computation.
-         if self.cpu_offload and self._offload_initialized:
-             self._cpu_offload_pool.reload()
-
-         logger.debug("[Muon.step] expert_keys=%s, %d param groups",
-                      self.expert_keys, len(self.param_groups))
-
-         for i, group in enumerate(self.param_groups):
-             if group["use_muon"]:
-                 logger.debug("[Muon.step] group %d: use_muon=True, %d params",
-                              i, len(group["params"]))
-                 self._step_muon(group, qk_logits=qk_logits)
-             else:
-                 logger.debug(
-                     "[Muon.step] group %d: use_muon=False (AdamW), %d params",
-                     i, len(group["params"]))
-                 step_adamw(self.state, group)
-
-         # D2H: offload optimizer states to CPU after computation.
-         if self.cpu_offload:
-             if not self._offload_initialized:
-                 if self._cpu_offload_pool is None:
-                     self._cpu_offload_pool = CPUOffloadPool()
-                 self._register_states_for_offload()
-                 self._offload_initialized = True
-             self._cpu_offload_pool.offload()
-
-         return loss
-
-     # ------------------------------------------------------------------
-     # CPU offload public helpers
-     # ------------------------------------------------------------------
-
-     def turn_on_cpu_offload(self):
-         """Enable CPU offload for optimizer states."""
-         if self.cpu_offload:
-             return
-         logger.info("[Muon] turn_on_cpu_offload")
-         self.cpu_offload = True
-         if not self.state:
-             return
-         self._cpu_offload_pool = CPUOffloadPool()
-         self._offload_initialized = False
-         self._register_states_for_offload()
-         self._offload_initialized = True
-         self._cpu_offload_pool.offload()
-
-     def turn_off_cpu_offload(self):
-         """Disable CPU offload and keep optimizer states resident on GPU."""
-         if not self.cpu_offload:
-             return
-         logger.info("[Muon] turn_off_cpu_offload")
-         if self._offload_initialized:
-             self._cpu_offload_pool.reload()
-             torch.cuda.current_stream().synchronize()
-         self._cpu_offload_pool = None
-         self._offload_initialized = False
-         self.cpu_offload = False
-
-     # ------------------------------------------------------------------
-     # Checkpoint support for cpu_offload
-     # ------------------------------------------------------------------
-
-     def state_dict(self) -> dict:
-         if self.cpu_offload:
-             raise RuntimeError(
-                 "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
-             )
-         return super().state_dict()
-
-     def load_state_dict(self, state_dict: dict) -> None:
-         if self.cpu_offload:
-             raise RuntimeError(
-                 "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
-             )
-         super().load_state_dict(state_dict)
-
-         # Invalidate adamw.py's module-level tensor caches so that
-         # the next step rebuilds them with the newly loaded state tensors.
-         _placement_cache.clear()
-         _tensor_cache.clear()
 
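The `state_dict()` / `load_state_dict()` guards above imply a specific checkpoint sequence when offload is active. A minimal sketch, assuming `optimizer` is a Muon instance with CPU offload enabled (the path and step count are illustrative):

optimizer.turn_off_cpu_offload()   # reload states to GPU, drop the pool
torch.save({"optim": optimizer.state_dict()}, "step_1000.pt")
optimizer.turn_on_cpu_offload()    # re-register states and offload again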
build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py DELETED
@@ -1,240 +0,0 @@
- from itertools import repeat
- from math import inf, sqrt
-
- import numpy as np
- import torch
-
- from .matmul_transpose_triton import matmul_transpose_assign
-
- COMM_DTYPE = torch.bfloat16
- DEFAULT_CHUNK_SIZE_RATIO = 4
-
-
- def _optimal_quintic(l, u, max_iter=1000):
-     """
-     Use the simplified Remez algorithm to find the optimal odd quintic approximant
-     to the constant function x -> 1 over the interval [l, u].
-
-     Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
-     approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
-     two interior equioscillation nodes q, r until convergence. Returns the
-     closed-form equioscillating solution when l ≈ u.
-
-     Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
-     (NaN or inf). Raises RuntimeError if convergence is not reached within
-     max_iter iterations.
-     """
-     assert 0 <= l <= u
-     if 1 - 5e-6 <= l / u:
-         return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
-     q = (3 * l + u) / 4
-     r = (l + 3 * u) / 4
-     E = inf
-     for _ in range(max_iter):
-         old_E = E
-         LHS = np.array(
-             [
-                 [l, l**3, l**5, 1],
-                 [q, q**3, q**5, -1],
-                 [r, r**3, r**5, 1],
-                 [u, u**3, u**5, -1],
-             ]
-         )
-         a, b, c, E = np.linalg.solve(LHS, np.ones(4))
-         if not np.all(np.isfinite([a, b, c, E])):
-             raise ValueError(
-                 f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
-             )
-         q, r = np.sqrt(
-             (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
-         )
-         if not np.all(np.isfinite([q, r])):
-             raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
-         if abs(old_E - E) <= 1e-15:
-             break
-     else:
-         raise RuntimeError(
-             f"_optimal_quintic: did not converge after {max_iter} iterations"
-         )
-     return float(a), float(b), float(c)
-
-
- def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
-     """
-     Compute the Polar Express coefficient series for `num_iters` quintic iterations.
-
-     Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that
-     compose to map singular values from [l, 1] toward 1. At each step:
-       1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
-          prevents near-zero singular values from stalling by raising the effective
-          lower bound; if it is active (cushion*u > l), the coefficients are
-          rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u].
-       2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the
-          last iteration, providing numerical headroom at the cost of a slightly slower
-          final convergence step.
-       3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1).
-
-     Returns a list of (a, b, c) tuples, one per iteration.
-
-     Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
-     Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
-     """
-     u = 1
-     assert 0 <= l <= u
-     safety_factor = 1 + safety_factor_eps
-     coefficients = []
-     for iter in range(num_iters):
-         a, b, c = _optimal_quintic(max(l, cushion * u), u)
-         if cushion * u > l:
-             pl = a * l + b * l**3 + c * l**5
-             pu = a * u + b * u**3 + c * u**5
-             rescaler = 2 / (pl + pu)
-             a *= rescaler
-             b *= rescaler
-             c *= rescaler
-         if iter < num_iters - 1:
-             a /= safety_factor
-             b /= safety_factor**3
-             c /= safety_factor**5
-         coefficients.append((a, b, c))
-         l = a * l + b * l**3 + c * l**5
-         u = 2 - l
-     return coefficients
-
-
- # Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
- # iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic
- # approximant to x->1 over the current singular-value interval, computed once at
- # import time and reused across all optimizer steps.
- #
- # Contrast with the former hardcoded NS coefficients (5 fixed tuples):
- #   - Former: empirically tuned to maximize slope at zero; did not converge
- #     singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
- #     of the true polar factor UV^T.
- #   - Polar Express: analytically optimal per step, adapting to the shrinking
- #     singular-value interval [l, u] as iterations progress; converges all
- #     singular values to 1, producing the exact polar factor UV^T.
- _coeffs_list = _optimal_composition(
-     l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
- )
-
-
- # This code is adapted from:
- #   KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
- #   NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
- #   matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
- @torch.no_grad()
- def _zeropower_via_newtonschulz5(G, steps):
-     """
-     Compute the polar factor of G via the Polar Express method.
-
-     Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c)
-     are the Polar Express coefficients from `_coeffs_list`. Each step is the
-     optimal odd quintic approximant to x -> 1 over the current singular-value
-     interval, minimizing the maximum approximation error (Remez / minimax
-     criterion). The composition maps singular values from [l, 1] to near 1,
-     producing the polar factor (orthogonal factor in the polar decomposition
-     G = UP).
-
-     `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2,
-     cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated.
-
-     Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
-     Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
-     """
-     assert len(G.shape) == 2
-     assert G.dtype == COMM_DTYPE
-     X = G  # no manual typecast
-
-     if G.size(0) > G.size(1):
-         X = X.T
-
-     X = X / (X.norm() + 1e-7)
-     hs = _coeffs_list[:steps] + list(
-         repeat(_coeffs_list[-1], steps - len(_coeffs_list))
-     )
-     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-     # Perform the NS iterations
-     for a, b, c in hs:
-         matmul_transpose_assign(X, buf1)
-         matmul_transpose_assign(buf1, buf2)
-         buf1.mul_(b).add_(buf2, alpha=c)
-         X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
-
-     if G.size(0) > G.size(1):
-         X = X.T
-
-     return X
-
-
- @torch.no_grad()
- def _zeropower_via_newtonschulz5_batched(G, steps):
-     """Batched polar factor computation for 3D (E, out, in) tensors.
-
-     Same algorithm as ``_zeropower_via_newtonschulz5`` but uses
-     ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel,
-     processing all E expert matrices in a single batched call.
-     """
-     assert len(G.shape) == 3
-     assert G.dtype == COMM_DTYPE
-     X = G
-
-     if G.size(1) > G.size(2):
-         X = X.transpose(-2, -1)
-
-     # Per-expert Frobenius norm.
-     X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
-
-     hs = _coeffs_list[:steps] + list(
-         repeat(_coeffs_list[-1], steps - len(_coeffs_list))
-     )
-     for a, b, c in hs:
-         buf1 = torch.bmm(X, X.transpose(-2, -1))
-         buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
-         buf1.mul_(b).add_(buf2, alpha=c)
-         X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a)
-
-     if G.size(1) > G.size(2):
-         X = X.transpose(-2, -1)
-
-     return X
-
-
- _ns_per_shape: dict[tuple[int, ...], callable] = {}
- _use_compile = True
-
-
- def set_ns_compile(enabled: bool):
-     """Toggle torch.compile for Newton-Schulz iteration."""
-     global _use_compile
-     _use_compile = enabled
-
-
- def zeropower_via_newtonschulz5(G, steps=5):
-     if not _use_compile:
-         return _zeropower_via_newtonschulz5(G, steps)
-     key = G.shape
-     if key not in _ns_per_shape:
-         _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5,
-                                            options={
-                                                "triton.cudagraphs": True,
-                                                "shape_padding": False
-                                            })
-     torch.compiler.cudagraph_mark_step_begin()
-     return _ns_per_shape[key](G, steps).clone()
-
-
- def zeropower_via_newtonschulz5_batched(G, steps=5):
-     """Compile-cached batched Newton-Schulz for 3D expert tensors."""
-     if not _use_compile:
-         return _zeropower_via_newtonschulz5_batched(G, steps)
-     key = G.shape
-     if key not in _ns_per_shape:
-         _ns_per_shape[key] = torch.compile(
-             _zeropower_via_newtonschulz5_batched,
-             options={
-                 "triton.cudagraphs": True,
-                 "shape_padding": False
-             })
-     torch.compiler.cudagraph_mark_step_begin()
-     return _ns_per_shape[key](G, steps).clone()
 
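For orientation, a minimal sketch of calling the orthogonalizer directly. The shape is illustrative; the 2D bfloat16 requirement comes from the assertions above, and even the non-compiled path needs a CUDA device plus this package's Triton `matmul_transpose_assign` kernel:

import torch

set_ns_compile(False)   # one-off call; skip the per-shape torch.compile cache
G = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
U = zeropower_via_newtonschulz5(G, steps=5)
# U approximates the polar factor of G, so U @ U.T is close to the identity.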
build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib.util  # bare "import importlib" does not expose importlib.util
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is: after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique, using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
 
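The net effect is that importing the `optimizer` subpackage re-exports whatever the parent package's `__init__.py` defines, loaded under a path-hashed module name so `sys.modules` stays unpolluted. A hedged sketch of a consumer (whether `Muon` is among the parent's exports is an assumption here):

# Hypothetical consumer; names resolve to symbols defined in ../__init__.py.
from optimizer import Muon  # assumption: the parent package exports Muon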
build/torch210-cxx11-cu130-x86_64-linux/pipeline.py DELETED
@@ -1,468 +0,0 @@
- import logging
- from typing import Generator
-
- import torch
- import torch.distributed as dist
- from torch.distributed.tensor import DTensor
- from torch.profiler import record_function
-
- from .core import _muon_state, adjust_lr_for_muon
- from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5
- from .qk_clip import compute_scales
-
- logger = logging.getLogger(__name__)
-
- # ======================================================================
- # Stage helpers
- # ======================================================================
-
-
- def _launch_gather(
-     params: list[DTensor],
-     owned_params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     num_ranks: int,
-     process_group: dist.ProcessGroup,
- ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
-     """Allocate gather buffers, build send/recv, and launch async all-to-all.
-
-     Returns:
-         work: Async operation handle.
-         recv_buf: Flat receive buffer (needed by ``_complete_gather``).
-         gathered_grads: ``{id(p): empty_tensor}`` for owned params,
-             ``None`` for non-owned.
-         recv_counts: Per-source-rank element counts.
-     """
-     # Allocate gathered-grad buffers
-     gathered_grads: dict[int, torch.Tensor | None] = {}
-     for p in params:
-         state = param_to_state[id(p)]
-         if rank == state.worker_rank:
-             gathered_grads[id(p)] = torch.empty(p.shape,
-                                                 dtype=COMM_DTYPE,
-                                                 device="cuda")
-         else:
-             gathered_grads[id(p)] = None
-
-     # Build send buffer – batch grad copies via torch.cat
-     # (1-2 fused kernels vs N individual narrow().copy_() calls).
-     send_counts = [0] * num_ranks
-     for p in params:
-         state = param_to_state[id(p)]
-         send_counts[state.worker_rank] += state.rank_numels[rank]
-
-     total_send = sum(send_counts)
-     if total_send > 0:
-         # Group grad slices by destination rank in a single pass.
-         dst_to_grads = [[] for _ in range(num_ranks)]
-         for p in params:
-             state = param_to_state[id(p)]
-             n = state.rank_numels[rank]
-             if n > 0:
-                 g = p.grad.to_local()
-                 dst_to_grads[state.worker_rank].append(g.reshape(-1))
-
-         # Flatten in dst order and cat once.
-         all_slices = []
-         for dst in range(num_ranks):
-             all_slices.extend(dst_to_grads[dst])
-         send_buf = torch.cat(all_slices)
-         if send_buf.dtype != COMM_DTYPE:
-             send_buf = send_buf.to(COMM_DTYPE)
-     else:
-         send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
-
-     # Build recv buffer
-     recv_counts = [0] * num_ranks
-     for src in range(num_ranks):
-         total = 0
-         for p in owned_params:
-             state = param_to_state[id(p)]
-             assert state.worker_rank == rank
-             total += state.rank_numels[src]
-         recv_counts[src] = total
-
-     recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
-
-     # Launch async all-to-all
-     logger.debug(f"send_buf size: {send_buf.numel()}, "
-                  f"recv_buf size: {recv_buf.numel()}, "
-                  f"recv_counts: {recv_counts}, "
-                  f"send_counts: {send_counts}, "
-                  f"process_group: {str(process_group)}")
-     work = dist.all_to_all_single(
-         recv_buf,
-         send_buf,
-         output_split_sizes=recv_counts,
-         input_split_sizes=send_counts,
-         group=process_group,
-         async_op=True,
-     )
-
-     return work, recv_buf, gathered_grads, recv_counts
-
-
- def _complete_gather(
-     recv_buf: torch.Tensor,
-     recv_counts: list[int],
-     owned_params: list[DTensor],
-     gathered_grads: dict[int, torch.Tensor | None],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
- ) -> None:
-     """Reconstruct gathered grads from the recv buffer (in-place)."""
-     off = 0
-     for src in range(len(recv_counts)):
-         if recv_counts[src] == 0:
-             continue
-
-         block = recv_counts[src]
-         inner_off = 0
-         for p in owned_params:
-             state = param_to_state[id(p)]
-             assert state.worker_rank == rank
-
-             indices = state.rank_indices[src]
-
-             shard_view = gathered_grads[id(p)][indices]
-             n = shard_view.numel()
-             if n == 0:
-                 continue
-
-             sg = recv_buf.narrow(0, off + inner_off, n)
-             sg = sg.reshape(shard_view.shape)
-             gathered_grads[id(p)][indices] = sg
-
-             inner_off += n
-         assert inner_off == block
-         off += block
-
-
- def _compute_ns(
-     owned_params: list[DTensor],
-     gathered_grads: dict[int, torch.Tensor | None],
-     ns_steps: int,
- ) -> dict[int, torch.Tensor | None]:
-     """Run Newton-Schulz orthogonalization on owned parameters.
-
-     Returns:
-         computed_us: ``{id(p): orthogonalized_update}`` for owned params.
-     """
-     computed_us: dict[int, torch.Tensor | None] = {}
-     for p in owned_params:
-         u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
-         gathered_grads[id(p)] = None  # free gathered grad
-         computed_us[id(p)] = u
-     return computed_us
-
-
- def _launch_scatter(
-     params: list[DTensor],
-     owned_params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     num_ranks: int,
-     process_group: dist.ProcessGroup,
-     computed_us: dict[int, torch.Tensor | None],
- ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
-     """Allocate scatter buffers, build send/recv, and launch async all-to-all.
-
-     Returns:
-         work: Async operation handle.
-         recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
-         scattered_us: Empty dict, populated by ``_complete_scatter`` with
-             zero-copy views into ``recv_buf``.
-         recv_counts: Per-source-rank element counts.
-     """
-     # scattered_us is populated by _complete_scatter with zero-copy views
-     # into recv_buf, avoiding N empty_like allocations + N copy_ calls.
-     # Pre-seed entries for params whose local shard is empty (rank_numels == 0)
-     # so _update_params can iterate all params without KeyError.
-     scattered_us: dict[int, torch.Tensor] = {}
-     for p in params:
-         if param_to_state[id(p)].rank_numels[rank] == 0:
-             scattered_us[id(p)] = torch.empty_like(p.to_local(),
-                                                    dtype=COMM_DTYPE)
-
-     # Build send buffer – batch via torch.cat
-     # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls).
-     send_counts = [0] * num_ranks
-     if owned_params:
-         for p in owned_params:
-             state = param_to_state[id(p)]
-             for dst_rank in range(num_ranks):
-                 send_counts[dst_rank] += state.rank_numels[dst_rank]
-
-     total_send = sum(send_counts)
-     if total_send > 0:
-         # Cache u_full conversions to avoid redundant .to() per dst_rank.
-         u_fulls = {}
-         for p in owned_params:
-             u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous()
-
-         # Collect slices in dst order (matches all-to-all send layout).
-         all_slices = []
-         for dst_rank in range(num_ranks):
-             for p in owned_params:
-                 state = param_to_state[id(p)]
-                 su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten()
-                 if su.numel() > 0:
-                     all_slices.append(su)
-
-         send_buf = torch.cat(all_slices) if all_slices else torch.empty(
-             0, dtype=COMM_DTYPE, device="cuda")
-     else:
-         send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
-
-     # Build recv buffer
-     recv_counts = [0] * num_ranks
-     for src in range(num_ranks):
-         total = 0
-         for p in params:
-             state = param_to_state[id(p)]
-             if state.worker_rank != src:
-                 continue
-             total += state.rank_numels[rank]
-         recv_counts[src] = total
-
-     recv_total = sum(recv_counts)
-     recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-     # Launch async all-to-all
-     work = dist.all_to_all_single(
-         recv_buf,
-         send_buf,
-         output_split_sizes=recv_counts,
-         input_split_sizes=send_counts,
-         group=process_group,
-         async_op=True,
-     )
-
-     return work, recv_buf, scattered_us, recv_counts
-
-
- def _complete_scatter(
-     recv_buf: torch.Tensor,
-     recv_counts: list[int],
-     params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     scattered_us: dict[int, torch.Tensor],
- ) -> None:
-     """Populate scattered_us with zero-copy views into recv_buf.
-
-     Instead of pre-allocating tensors and copying, we assign views directly
-     from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls.
-     The underlying storage of ``recv_buf`` is kept alive through the views
-     until ``scattered_us`` is cleared after ``_update_params``.
-     """
-     off = 0
-     for src in range(len(recv_counts)):
-         block = recv_counts[src]
-         if block == 0:
-             continue
-
-         inner_off = 0
-         for p in params:
-             state = param_to_state[id(p)]
-             if state.worker_rank != src:
-                 continue
-             n = state.rank_numels[rank]
-             if n == 0:
-                 continue
-
-             scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off,
-                                                   n).view_as(p.to_local())
-
-             inner_off += n
-
-         assert inner_off == block
-         off += block
-
-
- def _update_params(
-     params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     scattered_us: dict[int, torch.Tensor],
-     lr: float,
-     weight_decay: float,
- ) -> None:
-     """Apply weight decay, Muon update, and optional QK clipping.
-
-     Uses batched ``_foreach_mul_`` for weight decay and batched
-     ``_foreach_add_`` for the Muon update, grouping parameters by
-     adjusted_lr to minimize kernel launches while preserving float32
-     precision for the alpha scaling.
-     """
-     if not params:
-         return
-
-     # Batched weight decay: p *= (1 - lr * wd) — single fused kernel.
-     p_locals = [p._local_tensor for p in params]
-     torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay)
-
-     # Group params by adjusted_lr so _foreach_add_ can use a single
-     # alpha per group (preserves float32 precision for alpha scaling).
-     lr_groups: dict[float, tuple[list, list]] = {}
-     for p in params:
-         adjusted_lr = adjust_lr_for_muon(lr, p.shape)
-         if adjusted_lr not in lr_groups:
-             lr_groups[adjusted_lr] = ([], [])
-         lr_groups[adjusted_lr][0].append(p._local_tensor)
-         lr_groups[adjusted_lr][1].append(scattered_us[id(p)])
-
-     for adjusted_lr, (p_group, u_group) in lr_groups.items():
-         torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr)
-
-     # QK clipping – applied directly on the local tensor to
-     # avoid DTensor sharding-propagation issues with _StridedShard.
-     for p in params:
-         state = param_to_state[id(p)]
-         if state.qk_clip_state is None:
-             continue
-         scales_full = compute_scales(p, state.qk_clip_state)
-         if scales_full is not None:
-             ratio = p.shape[0] // scales_full.shape[0]
-             idx0 = state.rank_indices[rank][0]
-             if isinstance(idx0, slice):
-                 start = idx0.start or 0
-                 idx0 = torch.arange(start,
-                                     idx0.stop,
-                                     device=scales_full.device)
-             row_scales = scales_full[idx0 // ratio]
-             p._local_tensor.mul_(row_scales.view(-1, 1))
-
-
- # ======================================================================
- # Pre-launch helper for overlapping first chunk's gather with other work.
- # ======================================================================
-
-
- @torch.no_grad()
- def prelaunch_first_gather(
-     params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     none_grad: bool,
- ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
-     """Launch the first chunk's A2A gather early for overlap with other compute.
-
-     Call this *before* expensive GPU work (e.g. batched expert NS) so that
-     the NCCL all-to-all runs concurrently on the NCCL stream while the
-     default stream executes compute.
-
-     Returns the same 4-tuple that ``_launch_gather`` produces, which should
-     be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`.
-     """
-     process_group = param_to_state[id(params[0])].process_group
-     num_ranks = dist.get_world_size(group=process_group)
-     owned_params = [
-         p for p in params if param_to_state[id(p)].worker_rank == rank
-     ]
-
-     with record_function("muon::prelaunch_gather"):
-         work, recv_buf, gathered_grads, recv_counts = _launch_gather(
-             params, owned_params, param_to_state, rank, num_ranks,
-             process_group)
-
-     if none_grad:
-         for p in params:
-             p.grad = None
-
-     return work, recv_buf, gathered_grads, recv_counts
-
-
- # ======================================================================
- # Main generator – thin orchestrator that wires stages together.
- # ======================================================================
-
-
- @torch.no_grad()
- def muon_chunk_pipeline(
-     params: list[DTensor],
-     param_to_state: dict[int, _muon_state],
-     rank: int,
-     ns_steps: int,
-     lr: float,
-     weight_decay: float,
-     none_grad: bool,
-     prelaunch_gather: tuple | None = None,
- ) -> Generator[None, None, None]:
-     """Process one chunk of parameters through the full Muon pipeline.
-
-     Stages: gather -> compute (Newton-Schulz) -> scatter -> update.
-
-     Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
-     that communication and computation overlap across chunks. Async
-     communication is launched via ``async_op=True`` and completed after
-     the yield with ``work.wait()``.
-
-     Overlap happens because :func:`run_pipeline` admits one new chunk
-     per iteration (staggered admission). While chunk *N* does NS
-     compute on the default CUDA stream, chunk *N+1*'s async all-to-all
-     runs concurrently on the NCCL stream — no separate ``comm_stream``
-     is required.
-
-     If ``prelaunch_gather`` is provided, the gather was already launched
-     by :func:`prelaunch_first_gather` and we skip launching it again.
-
-     Yields exactly **2** times:
-
-     1. After launching async all-to-all gather (or immediately if pre-launched).
-     2. After launching async all-to-all scatter.
-     """
-     process_group = param_to_state[id(params[0])].process_group
-     num_ranks = dist.get_world_size(group=process_group)
-     owned_params = [
-         p for p in params if param_to_state[id(p)].worker_rank == rank
-     ]
-
-     if prelaunch_gather is not None:
-         # Gather was pre-launched; none_grad already handled by caller.
-         work, recv_buf, gathered_grads, recv_counts = prelaunch_gather
-     else:
-         # Normal path: launch async gather.
-         with record_function("muon::launch_gather"):
-             work, recv_buf, gathered_grads, recv_counts = _launch_gather(
-                 params, owned_params, param_to_state, rank, num_ranks,
-                 process_group)
-
-         if none_grad:
-             for p in params:
-                 p.grad = None
-
-     yield  # --- YIELD 1: other chunks can launch their gather ---
-
-     with record_function("muon::wait_gather"):
-         work.wait()
-         _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
-                          param_to_state, rank)
-         del recv_buf
-
-     # Stage 3: Newton-Schulz orthogonalization.
-     with record_function("muon::newton_schulz"):
-         computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
-         gathered_grads.clear()
-
-     # Stages 4-5: launch async scatter.
-     with record_function("muon::launch_scatter"):
-         work, recv_buf, scattered_us, recv_counts = _launch_scatter(
-             params, owned_params, param_to_state, rank, num_ranks,
-             process_group, computed_us)
-         computed_us.clear()
-
-     yield  # --- YIELD 2: other chunks can launch their scatter ---
-
-     with record_function("muon::wait_scatter"):
-         work.wait()
-         _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
-                           scattered_us)
-         del recv_buf
-
-     # Stage 6: apply parameter updates.
-     with record_function("muon::update_params"):
-         _update_params(params, param_to_state, rank, scattered_us, lr,
-                        weight_decay)
-         scattered_us.clear()
 
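A hedged sketch of how a caller drives the two-yield pipeline. `chunks_of_params` and the fully populated `param_to_state` map are assumed to be built by the optimizer in muon.py, and `run_pipeline` is the scheduler from async_utils.py:

def gens():
    for chunk in chunks_of_params:          # hypothetical pre-built chunks
        yield muon_chunk_pipeline(chunk, param_to_state, rank,
                                  ns_steps=5, lr=2e-2, weight_decay=0.1,
                                  none_grad=True)

run_pipeline(gens(), max_concurrent=4)      # staggered admission overlaps A2A with NS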
build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py DELETED
@@ -1,198 +0,0 @@
- import logging
- import math
- from dataclasses import dataclass
-
- import torch
- from torch.distributed.tensor import DTensor
-
- from .core import normalize_fqn
-
- logger = logging.getLogger(__name__)
-
-
- def parse_qk_layer(name: str) -> tuple[str | None, int]:
-     """
-     Parse a parameter name to check if it is a query/key projection layer
-     and return (kind, layer_index).
-
-     Supported kinds:
-         MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
-         MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)
-
-     Returns:
-         (kind, layer_idx) or (None, -1) if not matched.
-
-     Example:
-         'model.3.attn.wq.weight' -> ('wq', 3)
-         'model.5.attn.wk.weight' -> ('wk', 5)
-         'model.2.attn.q_proj.weight' -> ('q_proj', 2)
-         'model.7.attn.k_proj.weight' -> ('k_proj', 7)
-         'model.1.attn.wq_b.weight' -> ('wq_b', 1)
-         'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
-         'model.4.attn.v_proj.weight' -> (None, -1)
-     """
-     parts = normalize_fqn(name).split('.')
-     if len(parts) < 3:
-         return None, -1
-
-     kind = parts[-2]
-
-     layer_idx = -1
-     for part in reversed(parts):
-         if part.isdigit():
-             layer_idx = int(part)
-             break
-
-     if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
-         return kind, layer_idx
-
-     return None, -1
-
-
- @dataclass
- class QKClipInfo:
-     """Per-parameter dynamic info computed from config + runtime logits."""
-     kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
-     indices: list[int]  # which heads to consider for clipping
-     head_dim: int  # from config (qk_head_dim for MLA wq_b)
-     threshold: float  # from config
-     logit: torch.Tensor | None
-
-     # MLA-specific fields
-     is_mla: bool = False
-     qk_nope_head_dim: int = 0
-     qk_rope_head_dim: int = 0
-     v_head_dim: int = 0
-
-
- def get_qk_clip_info(clip_config, n, qk_logits):
-     """Extract QK clipping info for a named parameter.
-
-     Args:
-         clip_config: QK clipping configuration dict (or None).
-             MHA/GQA keys: head_dim, threshold, q_indices, k_indices
-             MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
-         n: Parameter name string.
-         qk_logits: Dict mapping layer indices to logit tensors (or None).
-
-     Returns:
-         QKClipInfo instance with clipping configuration for this parameter.
-     """
-     if clip_config is None:
-         return None
-
-     head_dim = clip_config.get('head_dim')
-     threshold = clip_config.get('threshold')
-     kind, layer_idx = parse_qk_layer(n)
-     is_mla = clip_config.get('is_mla', False)
-
-     logit, indices = None, []
-     if qk_logits is not None and kind is not None:
-         logit = qk_logits[layer_idx]
-         if isinstance(logit, DTensor):
-             # In TP settings, qk_logits may be DTensor.
-             # We convert it to a full tensor here for simplicity.
-             logit = logit.full_tensor()
-
-         if kind in ('wq_b', 'wq', 'q_proj'):
-             indices = clip_config.get('q_indices', []) or []
-         elif kind in ('wkv_b', 'wk', 'k_proj'):
-             indices = clip_config.get('k_indices', []) or []
-
-     if is_mla:
-         return QKClipInfo(
-             kind=kind,
-             indices=indices,
-             head_dim=head_dim,
-             threshold=threshold,
-             logit=logit,
-             is_mla=True,
-             qk_nope_head_dim=clip_config['qk_nope_head_dim'],
-             qk_rope_head_dim=clip_config['qk_rope_head_dim'],
-             v_head_dim=clip_config['v_head_dim'],
-         )
-     else:
-         return QKClipInfo(
-             kind=kind,
-             indices=indices,
-             head_dim=head_dim,
-             threshold=threshold,
-             logit=logit,
-         )
-
-
- def compute_scales(p, qk_clip_state):
-     """Compute per-head scaling factors for QK clipping.
-
-     Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
-     For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
-     """
-     kind = qk_clip_state.kind
-     indices = qk_clip_state.indices
-     head_dim = qk_clip_state.head_dim
-     threshold = qk_clip_state.threshold
-     logit = qk_clip_state.logit
-
-     # Check if any head exceeds threshold before allocating.
-     head_scales = {}
-     for logit_idx, head_idx in enumerate(indices):
-         v_ele = float(logit[logit_idx])
-         if v_ele > threshold:
-             new_scale = math.sqrt(threshold / v_ele)
-             if head_idx not in head_scales or new_scale < head_scales[head_idx]:
-                 head_scales[head_idx] = new_scale
-                 logger.info(
-                     f"[{kind}] Head {head_idx} exceeded threshold "
-                     f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
-                 )
-
-     if not head_scales:
-         return None
-
-     # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
-     if qk_clip_state.is_mla and kind == 'wkv_b':
-         effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
-     else:
-         effective_head_dim = head_dim
-
-     H_global = p.shape[0] // effective_head_dim
-     scales_full = torch.ones(H_global, device=p.data.device)
-     for head_idx, scale in head_scales.items():
-         scales_full[head_idx] = scale
-     return scales_full
-
-
- def qk_clip(p, scales, info):
-     """Apply per-head scaling to a Q/K projection weight matrix.
-
-     Args:
-         p: Parameter (nn.Parameter or raw tensor).
-         scales: [n_heads] tensor, each element = √γ_h.
-         info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
-
-     MLA sub-region scaling per Algorithm 1 (MuonClip):
-         wq_b: q_nope rows → √γ, q_pe rows → γ
-         wkv_b: k_nope rows → √γ, v rows → unchanged
-     """
-     W = p.data if isinstance(p, torch.nn.Parameter) else p
-
-     if not info.is_mla:
-         # MHA/GQA: uniform √γ applied to all rows in each head
-         W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
-         return
-
-     # MLA: vectorized sub-region scaling within each head
-     if info.kind == 'wq_b':
-         qk_nope = info.qk_nope_head_dim
-         qk_head_dim = qk_nope + info.qk_rope_head_dim
-         W_3d = W.view(-1, qk_head_dim, W.shape[1])  # [H, qk_head_dim, in_dim]
-         W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # q_nope → √γ
-         W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, 1))  # q_pe → γ
-
-     elif info.kind == 'wkv_b':
-         qk_nope = info.qk_nope_head_dim
-         kv_stride = qk_nope + info.v_head_dim
-         W_3d = W.view(-1, kv_stride, W.shape[1])  # [H, kv_stride, in_dim]
-         W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # k_nope → √γ
-         # v rows: not touched (k_R shared rotary unchanged)
 
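A worked example of the clipping rule in compute_scales (numbers are illustrative): a head whose max QK logit is 144 against a threshold of 100 gets γ = 100/144, and each of Q and K is scaled by √γ, so the product Q @ K^T comes down by exactly γ:

import math

v_ele, threshold = 144.0, 100.0
scale = math.sqrt(threshold / v_ele)            # √γ ≈ 0.8333 applied per side
assert abs(scale * scale * v_ele - threshold) < 1e-9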
build/torch210-cxx11-rocm70-x86_64-linux/adamw.py DELETED
@@ -1,271 +0,0 @@
- import logging
- from collections import defaultdict
- from typing import cast
-
- import torch
- from torch.distributed.tensor import DTensor
- from torch.profiler import record_function
-
- logger = logging.getLogger(__name__)
-
-
- def fused_adamw(
-     params: list[torch.Tensor],
-     grads: list[torch.Tensor],
-     exp_avgs: list[torch.Tensor],
-     exp_avg_sqs: list[torch.Tensor],
-     max_exp_avg_sqs: list[torch.Tensor],
-     state_steps: list[torch.Tensor],
-     amsgrad: bool,
-     beta1: float,
-     beta2: float,
-     lr: float | torch.Tensor,
-     weight_decay: float,
-     eps: float,
-     maximize: bool,
- ) -> None:
-     if not params:
-         return
-
-     # We only shuffle the lr around when it is a Tensor on a non-CPU device;
-     # otherwise we prefer treating it as a scalar.
-     lr_dict: dict | None = ({
-         lr.device: lr
-     } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None)
-     grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
-         [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
-          state_steps]  # type: ignore[list-item]
-     )
-     for (device, _), (
-         (
-             device_params_,
-             device_grads_,
-             device_exp_avgs_,
-             device_exp_avg_sqs_,
-             device_max_exp_avg_sqs,
-             device_state_steps_,
-         ),
-         _,
-     ) in grouped_tensors.items():
-         device_params = cast(list[torch.Tensor], device_params_)
-         device_grads = cast(list[torch.Tensor], device_grads_)
-         device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
-         device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
-         device_state_steps = cast(list[torch.Tensor], device_state_steps_)
-
-         if lr_dict is not None and device not in lr_dict:
-             lr_dict[device] = lr.to(
-                 device=device, non_blocking=True)  # type: ignore[union-attr]
-             lr = lr_dict[device]
-         torch._foreach_add_(device_state_steps, 1)
-         func = torch._fused_adamw_
-         func(
-             device_params,
-             device_grads,
-             device_exp_avgs,
-             device_exp_avg_sqs,
-             device_max_exp_avg_sqs,  # type: ignore[arg-type]
-             device_state_steps,
-             amsgrad=amsgrad,
-             lr=lr,  # type: ignore[arg-type]
-             beta1=beta1,
-             beta2=beta2,
-             weight_decay=weight_decay,
-             eps=eps,
-             maximize=maximize,
-         )
-
-
- def _to_local(t):
-     """Unwrap DTensor to local tensor for fused ops."""
-     return t._local_tensor if isinstance(t, DTensor) else t
-
-
- # ---------------------------------------------------------------------------
- # Caches for eliminating per-step Python overhead.
- #
- # Placement grouping and tensor list assembly are identical every step
- # (params don't change placement, moment/step tensors are the same objects
- # after initialisation). We cache them keyed by id() of the param list
- # stored in param_groups (stable across steps).
- #
- # Only gradients change each step and must be collected fresh.
- # ---------------------------------------------------------------------------
-
- # id(group["params"]) → dict[placement_key, list[param]]
- _placement_cache: dict[int, dict[tuple, list]] = {}
-
- # id(placement_group_list) → (params_local, moment1, moment2, state_steps)
- _tensor_cache: dict[int, tuple[list, list, list, list]] = {}
-
-
- def _step_adamw_params_slow(optimizer_state, params, group):
-     """Uncached fallback for the rare case where some params lack grads."""
-     params_with_grads = []
-     grads = []
-     moment1 = []
-     moment2 = []
-     state_steps = []
-
-     for p in params:
-         g = p.grad
-         if g is None:
-             continue
-         state = optimizer_state[p]
-         params_with_grads.append(_to_local(p))
-         grads.append(_to_local(g))
-         if "step" not in state:
-             state["step"] = torch.zeros((),
-                                         dtype=torch.float32,
-                                         device=p.device)
-             state["moment1"] = torch.zeros_like(g)
-             state["moment2"] = torch.zeros_like(g)
-         moment1.append(_to_local(state["moment1"]))
-         moment2.append(_to_local(state["moment2"]))
-         if not isinstance(state["step"], torch.Tensor):
-             state["step"] = torch.tensor(state["step"],
-                                          dtype=torch.float32,
-                                          device=p.device)
-         state_steps.append(state["step"])
-
-     if not params_with_grads:
-         return
-
-     lr = group["lr"]
-     beta1, beta2 = group["adamw_betas"]
-     eps = group["adamw_eps"]
-     weight_decay = group["weight_decay"]
-
-     fused_adamw(
-         params_with_grads,
-         grads,
-         moment1,
-         moment2,
-         [],
-         state_steps,
-         amsgrad=False,
-         beta1=beta1,
-         beta2=beta2,
-         lr=lr,
-         weight_decay=weight_decay,
-         eps=eps,
-         maximize=False,
-     )
-
-
- def step_adamw_params(optimizer_state, params, group):
-     """Run fused AdamW on a list of parameters sharing the same placement.
-
-     After the first call, cached tensor lists (params_local, moment1,
-     moment2, state_steps) are reused — only gradients are collected fresh.
-
-     Args:
-         optimizer_state: The optimizer's state dict (self.state in Muon).
-         params: List of parameters to update.
-         group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
-     """
-     # Collect grads — the only thing that changes each step.
-     with record_function("adamw::collect_grads"):
-         grads = []
-         for p in params:
-             g = p.grad
-             if g is None:
-                 # Rare: fall back to slow path that filters per-param.
-                 _step_adamw_params_slow(optimizer_state, params, group)
-                 return
-             grads.append(_to_local(g))
-
-     tensor_key = id(params)
-     if tensor_key not in _tensor_cache:
-         with record_function("adamw::init_tensor_cache"):
-             params_local = []
-             moment1 = []
-             moment2 = []
-             state_steps = []
-
-             for p in params:
-                 state = optimizer_state[p]
-                 params_local.append(_to_local(p))
-                 if "step" not in state:
-                     state["step"] = torch.zeros((),
-                                                 dtype=torch.float32,
-                                                 device=p.device)
-                     state["moment1"] = torch.zeros_like(p.grad)
-                     state["moment2"] = torch.zeros_like(p.grad)
-                 moment1.append(_to_local(state["moment1"]))
-                 moment2.append(_to_local(state["moment2"]))
-                 if not isinstance(state["step"], torch.Tensor):
-                     state["step"] = torch.tensor(state["step"],
-                                                  dtype=torch.float32,
-                                                  device=p.device)
-                 state_steps.append(state["step"])
-
-             _tensor_cache[tensor_key] = (params_local, moment1, moment2,
-                                          state_steps)
-
-     params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key]
-
-     lr = group["lr"]
-     beta1, beta2 = group["adamw_betas"]
-     eps = group["adamw_eps"]
-     weight_decay = group["weight_decay"]
-
-     with record_function("adamw::fused_adamw"):
-         fused_adamw(
-             params_local,
-             grads,
-             moment1,
-             moment2,
-             [],
-             state_steps,
-             amsgrad=False,
-             beta1=beta1,
-             beta2=beta2,
-             lr=lr,
-             weight_decay=weight_decay,
-             eps=eps,
-             maximize=False,
-         )
-
-
- def step_adamw(optimizer_state, group):
-     """Dispatch AdamW step, grouping parameters by type and placement.
-
-     Placement grouping is cached after the first call since params never
-     change their placement between steps.
-
-     Args:
-         optimizer_state: The optimizer's state dict (self.state in Muon).
-         group: Parameter group dict.
-     """
-     params = group["params"]
-     placement_key = id(params)
-
-     if placement_key not in _placement_cache:
-         with record_function("adamw::group_by_placement"):
-             placement_to_params: dict[tuple,
-                                       list[torch.Tensor]] = defaultdict(list)
-             for p in params:
-                 match p:
-                     case DTensor():
-                         logger.debug(
-                             "[AdamW] DTensor param: shape=%s, placements=%s, "
-                             "mesh=%s, grad=%s", p.shape, p.placements,
-                             p.device_mesh.mesh_dim_names,
-                             p.grad.shape if p.grad is not None else None)
-                         placement_to_params[tuple(
-                             [p.placements, p.device_mesh])].append(p)
-                     case torch.Tensor():
-                         logger.debug(
-                             "[AdamW] plain param: shape=%s, grad=%s", p.shape,
-                             p.grad.shape if p.grad is not None else None)
-                         placement_to_params[tuple([torch.Tensor,
-                                                    None])].append(p)
-
-             logger.debug("[AdamW] %d placement groups, %d total params",
-                          len(placement_to_params), len(params))
-
-             _placement_cache[placement_key] = dict(placement_to_params)
-
-     for group_params in _placement_cache[placement_key].values():
-         step_adamw_params(optimizer_state, group_params, group)
 
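A hedged sketch of driving this path standalone: the group keys mirror what step_adamw reads above, while `model`, `batch`, and the ndim-based split are illustrative stand-ins.

group = {
    "params": [p for p in model.parameters() if p.ndim < 2],  # illustrative split
    "lr": 3e-4,
    "adamw_betas": (0.9, 0.95),
    "adamw_eps": 1e-8,
    "weight_decay": 0.1,
}
state = {p: {} for p in group["params"]}   # stands in for Muon's self.state

loss = model(batch).mean()                 # placeholder forward pass
loss.backward()                            # populates .grad for every param
step_adamw(state, group)                   # groups by placement once, then runs fused_adamw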
build/torch210-cxx11-rocm70-x86_64-linux/async_utils.py DELETED
@@ -1,77 +0,0 @@
- import logging
- from typing import Generator
-
- logger = logging.getLogger(__name__)
-
-
- class _Task:
-     """Internal: wraps a generator, advances one yield at a time."""
-
-     def __init__(self, generator: Generator[None, None, None], index: int):
-         self._generator = generator
-         self._index = index
-         self._steps_completed = 0
-         self.step()  # run to first yield
-
-     def step(self) -> bool:
-         try:
-             next(self._generator)
-             self._steps_completed += 1
-             logger.debug("pipeline[%d] completed stage %d", self._index,
-                          self._steps_completed)
-             return True
-         except StopIteration:
-             logger.debug("pipeline[%d] finished after %d stages", self._index,
-                          self._steps_completed)
-             return False
-
-     def close(self):
-         self._generator.close()
-
-
- def run_pipeline(
-     pipelines: Generator[Generator[None, None, None], None, None],
-     max_concurrent: int,
- ) -> None:
-     """Run generator-based pipelines with bounded concurrency.
-
-     Each pipeline is a generator that yields at stage boundaries.
-     The runtime interleaves pipelines so communication and computation
-     overlap across chunks.
-     """
-     if max_concurrent <= 0:
-         raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}")
-
-     have_new = True
-     task_index = 0
-     previous_tasks: list[_Task] = []
-
-     try:
-         while have_new or previous_tasks:
-             running_tasks: list[_Task] = []
-
-             # Admit one new pipeline per iteration (staggered admission).
-             # Admitting one at a time ensures that while chunk N does NS
-             # compute on the default stream, chunk N+1's NCCL all-to-all
-             # runs concurrently on the NCCL stream — creating real
-             # communication/computation overlap on the GPU.
-             if have_new and len(previous_tasks) < max_concurrent:
-                 try:
-                     gen = next(pipelines)
-                     task = _Task(gen, task_index)
-                     task_index += 1
-                     running_tasks.append(task)
-                 except StopIteration:
-                     have_new = False
-
-             # Advance every previously-yielded task by one step.
-             for task in previous_tasks:
-                 if task.step():
-                     running_tasks.append(task)
-
-             previous_tasks = running_tasks
-     except BaseException:
-         # Clean up all in-flight generators to release GPU resources.
-         for task in previous_tasks:
-             task.close()
-         raise
 
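The scheduling order is easiest to see with toy generators. This CPU-only sketch prints how one chunk's second stage runs in the same sweep that admits the next chunk:

def stage_pair(tag):
    print(f"{tag}: launch gather")
    yield
    print(f"{tag}: launch scatter")
    yield
    print(f"{tag}: update params")

run_pipeline((stage_pair(f"chunk{i}") for i in range(3)), max_concurrent=2)
# chunk0's scatter launch lands in the same sweep that admits chunk1,
# which is the comm/compute overlap the docstring describes.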
build/torch210-cxx11-rocm70-x86_64-linux/core.py DELETED
@@ -1,219 +0,0 @@
1
- import logging
2
- import math
3
- from dataclasses import dataclass
4
- from typing import List
5
-
6
- import torch
7
- from torch.distributed import ProcessGroup
8
- from torch.distributed.tensor import DTensor
9
-
10
- # torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
11
- # parameter FQNs. Activation checkpointing similarly inserts
12
- # "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
13
- # expert_keys, QK layer parsing) works regardless of wrapper nesting.
14
- _WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})
15
-
16
- logger = logging.getLogger(__name__)
17
-
18
-
19
- def normalize_fqn(name: str) -> str:
20
- """Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
21
- return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS)
22
-
23
-
24
- @dataclass
25
- class _muon_state:
26
- worker_rank: int
27
- process_group: ProcessGroup
28
- rank_indices: dict[int, tuple] # local_rank -> per-dim indices
29
- rank_numels: dict[int, int] # local_rank -> numel
30
- name: str
31
- qk_clip_state: torch.Tensor | None = None
32
-
33
-
34
- def _batch_momentum(
35
- grads: List[torch.Tensor],
36
- momentum_bufs: List[torch.Tensor],
37
- momentum: torch.Tensor,
38
- ) -> None:
39
- """Batched momentum update (no nesterov)."""
40
- torch._foreach_mul_(momentum_bufs, momentum)
41
- torch._foreach_add_(momentum_bufs, grads)
42
-
43
-
44
- def _batch_momentum_nesterov(
45
- grads: List[torch.Tensor],
46
- momentum_bufs: List[torch.Tensor],
47
- momentum: torch.Tensor,
48
- ) -> None:
49
- """Batched momentum update with nesterov correction."""
50
- torch._foreach_mul_(momentum_bufs, momentum)
51
- torch._foreach_add_(momentum_bufs, grads)
52
- nesterov_terms = torch._foreach_mul(momentum_bufs, momentum)
53
- torch._foreach_add_(grads, nesterov_terms)
54
-
55
-
56
- _compiled_momentum: dict[bool, callable] = {}
57
- _use_momentum_compile = True
58
-
59
-
60
- def set_momentum_compile(enabled: bool):
61
- """Toggle torch.compile for batched momentum."""
62
- global _use_momentum_compile
63
- _use_momentum_compile = enabled
64
-
65
-
66
- def batch_pre_ortho(
67
- grads: List[torch.Tensor],
68
- momentum_bufs: List[torch.Tensor],
69
- momentum: torch.Tensor,
70
- nesterov: bool,
71
- ) -> None:
72
- """Batched momentum update on lists of plain tensors.
73
-
74
- Mirrors dion's ``muon_update_pre_orthogonalize``.
75
- Inputs must be plain CUDA tensors (not DTensor).
76
- Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place.
77
-
78
- When compile is enabled, uses separately compiled functions for
79
- nesterov=True/False to avoid graph breaks from the branch.
80
- """
81
- fn = _batch_momentum_nesterov if nesterov else _batch_momentum
82
- if _use_momentum_compile:
83
- if nesterov not in _compiled_momentum:
84
- _compiled_momentum[nesterov] = torch.compile(fn)
85
- fn = _compiled_momentum[nesterov]
86
- fn(grads, momentum_bufs, momentum)
87
-
88
-
89
- def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay):
90
- """Weight-decay + update on plain tensors.
91
-
92
- Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache
93
- lookup per call × 256+ params = massive overhead. The pipeline path uses
94
- batched _foreach_* ops instead; this function remains for base() and
95
- distributed_muon().
96
- """
97
- p_data.mul_(1 - lr * weight_decay)
98
- p_data.add_(u_data, alpha=-adjusted_lr)
99
-
100
-
101
- def update_p(p, u, lr, adjusted_lr, weight_decay):
102
- """Apply weight decay and orthogonalized update to parameter.
103
-
104
- Args:
105
- p: Parameter (torch.nn.Parameter or DTensor).
106
- u: Orthogonalized update tensor.
107
- lr: Base learning rate.
108
- adjusted_lr: Size-adjusted learning rate.
109
- weight_decay: Weight decay coefficient.
110
- """
111
- # Unwrap Parameter -> underlying data tensor.
112
- p_data = p.data if isinstance(p, torch.nn.Parameter) else p
113
- # Unwrap DTensor -> local CUDA tensor for compiled kernel.
114
- if isinstance(p_data, DTensor):
115
- p_data = p_data._local_tensor
116
- u_data = u._local_tensor if isinstance(u, DTensor) else u
117
- _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay)
118
-
119
-
120
- def adjust_lr_for_muon(lr, param_shape):
121
- """Scale learning rate based on parameter matrix dimensions.
122
-
123
- Args:
124
- lr: Base learning rate.
125
- param_shape: Shape of the parameter tensor.
126
-
127
- Returns:
128
- Adjusted learning rate.
129
- """
130
- A, B = param_shape[:2]
131
- # We adjust the learning rate and weight decay based on the size of the parameter matrix
132
- # as described in the paper
133
- adjusted_ratio = 0.2 * math.sqrt(max(A, B))
134
- adjusted_lr = lr * adjusted_ratio
135
- return adjusted_lr
136
-
137
-
138
- def _match_key(parts, key):
139
- """Check if key matches as contiguous components in parts.
140
-
141
- Single-component keys (e.g. "experts") match any single component.
142
- Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence.
143
- """
144
- key_parts = key.split(".")
145
- key_len = len(key_parts)
146
- if key_len == 1:
147
- return key in parts
148
- return any(parts[i:i + key_len] == key_parts
149
- for i in range(len(parts) - key_len + 1))
150
-
151
-
152
- def is_expert_param(name, expert_keys):
153
- """Check if a parameter name matches any expert key (component-level)."""
154
- if not expert_keys:
155
- return False
156
- parts = normalize_fqn(name).split(".")
157
- return any(_match_key(parts, key) for key in expert_keys)
158
-
159
-
160
- def default_is_muon(name, x, expert_keys=None):
-     normalized = normalize_fqn(name)
-     parts = normalized.split(".")
-     skip_keys = [
-         "embed_tokens",
-         "lm_head",
-         "tok_embeddings",
-         "output",
-         "mhc_attn",
-         "mhc_ffn",
-         "lambda_proj",
-     ]
-     if any(key in parts for key in skip_keys):
-         logger.info(
-             "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
-             normalized, name, x.ndim)
-         return False
-     effective_ndim = x.ndim
-     is_expert = is_expert_param(name, expert_keys)
-     if is_expert:
-         effective_ndim -= 1
-     result = effective_ndim >= 2
-     logger.info(
-         "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s",
-         normalized, name, x.ndim, is_expert, effective_ndim,
-         "Muon" if result else "AdamW")
-     return result
-
-
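The effective-ndim adjustment exists because stacked expert weights carry a leading expert dimension: a 3-D tensor of shape (num_experts, d_in, d_out) is really a batch of 2-D matrices, so it counts as effective_ndim = 2 and is routed to Muon, while a genuine 1-D tensor (bias, norm weight) still falls through to AdamW.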
- def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
-     if is_muon_func is None:
-         is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
-
-     muon_params, muon_names = [], []
-     non_muon_params, non_muon_names = [], []
-
-     for n, p in model.named_parameters():
-         if not p.requires_grad:
-             continue
-         if is_muon_func(n, p):
-             muon_params.append(p)
-             muon_names.append(n)
-         else:
-             non_muon_params.append(p)
-             non_muon_names.append(n)
-
-     logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d",
-                 expert_keys, len(muon_names), len(non_muon_names))
-
-     return [
-         {
-             "params": muon_params,
-             "names": muon_names,
-             "use_muon": True,
-         },
-         {
-             "params": non_muon_params,
-             "use_muon": False,
-         },
-     ]
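A hedged usage sketch; the `Muon` constructor signature below is an assumption for illustration, the point being only that the returned groups carry a per-group `use_muon` flag:

# Hypothetical wiring: route 2-D+ weights to Muon, the rest to AdamW.
groups = get_default_muon_param_groups(model, expert_keys=["experts"])
optimizer = Muon(groups, lr=3e-4, weight_decay=0.1)  # assumed signature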
build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py DELETED
@@ -1,206 +0,0 @@
- """CPU offloading for optimizer states.
-
- Manages a pinned CPU memory pool and async CUDA streams to offload
- optimizer state tensors (momentum buffers, Adam moments) to CPU between
- optimizer steps, freeing GPU memory.
-
- All tracked tensors are packed into a single flat pinned CPU buffer
- (per dtype). D2H and H2D copies are performed per-tensor, directly
- between individual GPU tensors and their slice of the CPU flat buffer;
- no GPU staging buffer is allocated, so there is **no temporary GPU
- memory spike** during offload or reload.
-
- Individual tensor storages are freed after offload via
- ``untyped_storage().resize_(0)``, preserving tensor identity so
- downstream caches remain valid.
- """
-
- import logging
- from collections import defaultdict
-
- import torch
- from torch.distributed.tensor import DTensor
-
- logger = logging.getLogger(__name__)
-
-
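The storage-resize trick the docstring describes is standard PyTorch: resizing the untyped storage to zero releases the device allocation while the Python tensor object, its shape, and its identity survive. A minimal demonstration (assumes a CUDA device; after re-allocating, the memory is uninitialized until data is copied back in):

import torch

t = torch.randn(1024, 1024, device="cuda")
nbytes = t.untyped_storage().size()

t.untyped_storage().resize_(0)   # frees the GPU allocation
assert t.untyped_storage().size() == 0
assert t.shape == (1024, 1024)   # tensor object and metadata survive

t.untyped_storage().resize_(nbytes)  # re-allocate; contents are garbage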
- class CPUOffloadPool:
-     """Pinned CPU memory pool for async optimizer state offloading.
-
-     Tracked tensors are grouped by dtype. Each group gets a single flat
-     pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of
-     the flat buffer) to avoid allocating a GPU staging buffer.
-     """
-
-     def __init__(self):
-         self._managed: list[torch.Tensor] = []
-         self._storage_nbytes: dict[int, int] = {}  # id(t) → bytes
-
-         # Per-dtype group: populated on first offload.
-         # dtype → dict with keys:
-         #   "indices"  : list[int]             managed-list indices
-         #   "offsets"  : list[tuple[int,int]]  (start, numel) in flat buf
-         #   "total"    : int                   total numel
-         #   "cpu_flat" : Tensor                pinned CPU buffer
-         self._groups: dict[torch.dtype, dict] = {}
-
-         self._offload_stream: torch.cuda.Stream | None = None
-         self._device: torch.device | None = None
-         self._initialized: bool = False
-         self._logged: bool = False
-
-     # ------------------------------------------------------------------
-     @staticmethod
-     def _local(t: torch.Tensor) -> torch.Tensor:
-         """Unwrap DTensor to its local CUDA tensor."""
-         return t._local_tensor if isinstance(t, DTensor) else t
-
-     def _ensure_stream(self):
-         if self._offload_stream is None:
-             self._offload_stream = torch.cuda.Stream(device=self._device)
-
-     # ------------------------------------------------------------------
-     def track(self, tensor: torch.Tensor):
-         """Register a GPU tensor for CPU offloading. Idempotent."""
-         tid = id(tensor)
-         if tid in self._storage_nbytes:
-             return
-         local = self._local(tensor)
-         if self._device is None:
-             self._device = local.device
-         storage = local.untyped_storage()
-         # Skip tensors with empty storage (e.g. empty FSDP shards).
-         if storage.size() == 0:
-             return
-         self._storage_nbytes[tid] = storage.size()
-         self._managed.append(tensor)
-
-     # ------------------------------------------------------------------
-     def _init_buffers(self):
-         """Build per-dtype flat buffers on first offload."""
-         # Group managed tensors by dtype.
-         dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list)
-         for idx, t in enumerate(self._managed):
-             local = self._local(t)
-             dtype_map[local.dtype].append((idx, local.numel()))
-
-         total_cpu_bytes = 0
-         for dtype, entries in dtype_map.items():
-             offsets: list[tuple[int, int]] = []
-             indices: list[int] = []
-             off = 0
-             for idx, n in entries:
-                 indices.append(idx)
-                 offsets.append((off, n))
-                 off += n
-             cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
-             self._groups[dtype] = {
-                 "indices": indices,
-                 "offsets": offsets,
-                 "total": off,
-                 "cpu_flat": cpu_flat,
-             }
-             total_cpu_bytes += off * cpu_flat.element_size()
-
-         self._initialized = True
-         logger.info(
-             "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), "
-             "%.2f MB pinned CPU memory",
-             len(self._managed),
-             len(self._groups),
-             total_cpu_bytes / (1024**2),
-         )
-
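As a concrete layout example: tracking three fp32 tensors of 1M, 2M, and 4M elements produces a single fp32 group with offsets [(0, 1M), (1M, 2M), (3M, 4M)], total 7M elements, i.e. 28 MB of pinned CPU memory at 4 bytes per element.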
-     # ------------------------------------------------------------------
-     def offload(self):
-         """Per-tensor async D2H into CPU flat buffer, then free GPU storage."""
-         if not self._managed:
-             return
-         if not self._initialized:
-             self._init_buffers()
-         self._ensure_stream()
-
-         # Offload stream waits for compute to finish.
-         compute_event = torch.cuda.current_stream(self._device).record_event()
-         self._offload_stream.wait_event(compute_event)
-
-         offloaded_bytes = 0
-
-         # Per-tensor D2H copies directly into CPU flat buffer slices.
-         # No GPU staging buffer → no temporary GPU memory spike.
-         with torch.cuda.stream(self._offload_stream):
-             for dtype, grp in self._groups.items():
-                 indices = grp["indices"]
-                 offsets = grp["offsets"]
-                 cpu_flat = grp["cpu_flat"]
-
-                 for i, mgd_idx in enumerate(indices):
-                     local = self._local(self._managed[mgd_idx])
-                     off, n = offsets[i]
-                     cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
-
-                 offloaded_bytes += grp["total"] * cpu_flat.element_size()
-
-         # Wait for all D2H copies to land, then free GPU storage.
-         self._offload_stream.synchronize()
-         for t in self._managed:
-             storage = self._local(t).untyped_storage()
-             if storage.size() != 0:
-                 storage.resize_(0)
-             else:
-                 raise RuntimeError(
-                     f"Tensor storage is already freed (size=0) before offload. "
-                     f"This indicates a double-free or external interference. "
-                     f"Tensor shape: {t.shape}, dtype: {t.dtype}"
-                 )
-
-         if not self._logged:
-             logger.info(
-                 "[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
-                 offloaded_bytes / (1024**2),
-             )
-
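The record-event / wait-event pair above is the standard way to order work across CUDA streams without blocking the host. A condensed, self-contained sketch of that handshake (tensor names are illustrative):

import torch

gpu_t = torch.randn(1 << 20, device="cuda")
cpu_buf = torch.empty(1 << 20, pin_memory=True)

side = torch.cuda.Stream()
evt = torch.cuda.current_stream().record_event()  # mark compute progress
side.wait_event(evt)              # side stream starts only after compute
with torch.cuda.stream(side):
    # Async D2H: pinned destination lets the copy run via DMA.
    cpu_buf.copy_(gpu_t.reshape(-1), non_blocking=True)
side.synchronize()                # copy has landed; safe to free gpu_t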
-     # ------------------------------------------------------------------
-     def reload(self):
-         """Per-tensor H2D from CPU flat buffer on the default stream.
-
-         Runs on the current (default) CUDA stream to avoid stream
-         interaction issues with the parallel Muon pipeline. Since
-         pinned CPU memory is the source, the copies overlap with
-         GPU idle time between steps.
-         """
-         if not self._managed or not self._initialized:
-             return
-
-         reloaded_bytes = 0
-
-         # Re-allocate all GPU storages first.
-         for t in self._managed:
-             local = self._local(t)
-             storage = local.untyped_storage()
-             if storage.size() != 0:
-                 raise RuntimeError(
-                     f"Storage should have been freed (size=0) before reload, "
-                     f"but got size={storage.size()}. "
-                     f"Tensor shape: {t.shape}, dtype: {t.dtype}"
-                 )
-             storage.resize_(self._storage_nbytes[id(t)])
-
-         # Per-tensor H2D copies from CPU flat buffer slices.
-         # non_blocking=True with pinned source allows DMA overlap.
-         for dtype, grp in self._groups.items():
-             indices = grp["indices"]
-             offsets = grp["offsets"]
-             cpu_flat = grp["cpu_flat"]
-
-             for i, mgd_idx in enumerate(indices):
-                 local = self._local(self._managed[mgd_idx])
-                 off, n = offsets[i]
-                 local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
-
-             reloaded_bytes += grp["total"] * cpu_flat.element_size()
-
-         if not self._logged:
-             logger.info(
-                 "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
-             )
-             # Fix: without this, _logged stays False forever and both the
-             # offload and reload messages repeat every step; set it after
-             # the first full offload/reload cycle so later steps stay quiet.
-             self._logged = True
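Finally, a hedged end-to-end sketch of how the pool is meant to be driven; the training-loop names are illustrative and the real integration in this repo may differ:

pool = CPUOffloadPool()
for buf in optimizer_state_tensors:  # e.g. momentum buffers, Adam moments
    pool.track(buf)

for step in range(num_steps):
    pool.reload()     # H2D; no-op before the first offload
    optimizer.step()  # states must be resident on GPU here
    pool.offload()    # async D2H, then free GPU storage until next step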