Kernels
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .github/actionlint.yaml +0 -3
  2. .github/workflows/build-and-commit.yml +0 -120
  3. .github/workflows/pre-commit.yml +0 -30
  4. .github/workflows/push-to-hf.yml +0 -40
  5. .gitignore +0 -21
  6. .pre-commit-config.yaml +0 -33
  7. README.md +4 -69
  8. build.toml +14 -24
  9. build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py +0 -175
  10. build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +0 -128
  11. build/torch210-cxx11-cu126-x86_64-linux/metadata.json +0 -1
  12. build/torch210-cxx11-cu126-x86_64-linux/muon.py +0 -1268
  13. build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py +0 -26
  14. build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py +0 -175
  15. build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +0 -128
  16. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +0 -1
  17. build/torch210-cxx11-cu128-x86_64-linux/muon.py +0 -1268
  18. build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py +0 -26
  19. build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py +0 -175
  20. build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py +0 -128
  21. build/torch210-cxx11-cu130-x86_64-linux/metadata.json +0 -1
  22. build/torch210-cxx11-cu130-x86_64-linux/muon.py +0 -1268
  23. build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py +0 -26
  24. build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py +0 -175
  25. build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py +0 -128
  26. build/torch210-cxx11-rocm70-x86_64-linux/metadata.json +0 -1
  27. build/torch210-cxx11-rocm70-x86_64-linux/muon.py +0 -1268
  28. build/torch210-cxx11-rocm70-x86_64-linux/optimizer/__init__.py +0 -26
  29. build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +0 -9
  30. build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +0 -3
  31. build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py +0 -175
  32. build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py +0 -128
  33. build/torch210-cxx11-rocm71-x86_64-linux/metadata.json +0 -1
  34. build/torch210-cxx11-rocm71-x86_64-linux/muon.py +0 -1268
  35. build/torch210-cxx11-rocm71-x86_64-linux/optimizer/__init__.py +0 -26
  36. build/{torch210-cxx11-cu126-x86_64-linux → torch26-cxx11-cu118-x86_64-linux/optimizer}/__init__.py +0 -0
  37. build/{torch210-cxx11-rocm70-x86_64-linux → torch26-cxx11-cu118-x86_64-linux/optimizer}/_ops.py +3 -3
  38. build/{torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so → torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so} +2 -2
  39. build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +455 -0
  40. build/{torch210-cxx11-cu128-x86_64-linux → torch26-cxx11-cu124-x86_64-linux/optimizer}/__init__.py +0 -0
  41. build/{torch210-cxx11-cu126-x86_64-linux → torch26-cxx11-cu124-x86_64-linux/optimizer}/_ops.py +3 -3
  42. build/{torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so → torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so} +2 -2
  43. build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +455 -0
  44. build/{torch210-cxx11-cu130-x86_64-linux → torch26-cxx11-cu126-x86_64-linux/optimizer}/__init__.py +0 -0
  45. build/{torch210-cxx11-cu130-x86_64-linux → torch26-cxx11-cu126-x86_64-linux/optimizer}/_ops.py +3 -3
  46. build/{torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so → torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so} +2 -2
  47. build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +455 -0
  48. build/{torch210-cxx11-rocm70-x86_64-linux → torch26-cxx11-rocm62-x86_64-linux/optimizer}/__init__.py +0 -0
  49. build/{torch210-cxx11-cu128-x86_64-linux → torch26-cxx11-rocm62-x86_64-linux/optimizer}/_ops.py +3 -3
  50. build/{torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so → torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so} +2 -2
.github/actionlint.yaml DELETED
@@ -1,3 +0,0 @@
- self-hosted-runner:
-   labels:
-     - docker-builder-01
.github/workflows/build-and-commit.yml DELETED
@@ -1,120 +0,0 @@
- name: Nix build and commit
-
- on:
-   pull_request:
-     types: [opened, synchronize, reopened]
-   workflow_dispatch:
-
- permissions:
-   contents: write
-
- jobs:
-   check-commit:
-     runs-on: ubuntu-latest
-     outputs:
-       skip: ${{ steps.check.outputs.skip }}
-     steps:
-       - uses: actions/checkout@v4
-         with:
-           fetch-depth: 0
-       - id: check
-         run: |
-           if [ "${{ github.event_name }}" = "pull_request" ]; then
-             msg=$(git log -1 --pretty=%B "${{ github.event.pull_request.head.sha }}")
-           else
-             msg="manual dispatch"
-           fi
-           echo "Commit message: $msg"
-           if echo "$msg" | grep -q '\[skip-build\]'; then
-             echo "skip=true" >> "$GITHUB_OUTPUT"
-           else
-             echo "skip=false" >> "$GITHUB_OUTPUT"
-           fi
-
-   build_and_commit:
-     needs: check-commit
-     if: needs.check-commit.outputs.skip == 'false'
-     runs-on: docker-builder-01
-     steps:
-       - name: Show disk usage
-         run: df -h
-
-       - name: Notify build start on Slack
-         id: slack_start
-         run: |
-           msg="*Build started* for \`${{ github.repository }}\`\nBranch: \`${{ github.ref_name }}\`\n<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Workflow>"
-           response=$(curl -s -X POST \
-             -H "Authorization: Bearer ${{ secrets.SLACK_TOKEN }}" \
-             -H "Content-type: application/json; charset=utf-8" \
-             --data "{\"channel\":\"${{ secrets.SLACK_CHANNEL_ID }}\",\"text\":\"$msg\"}" \
-             https://slack.com/api/chat.postMessage)
-           ts=$(echo "$response" | jq -r '.ts')
-           echo "thread_ts=$ts" >> "$GITHUB_OUTPUT"
-           echo "$response"
-
-       - name: Checkout repository
-         uses: actions/checkout@v4
-         with:
-           fetch-depth: 0
-           lfs: true
-           ref: ${{ github.head_ref || github.ref }}
-
-       - name: Install Nix
-         uses: cachix/install-nix-action@v31
-
-       - name: Setup huggingface cachix
-         uses: cachix/cachix-action@v15
-         with:
-           name: huggingface
-
-       - name: Clean build directory
-         run: |
-           rm -rf build
-
-       - name: Build with Nix
-         run: |
-           nix run .#build-and-copy \
-             --override-input kernel-builder github:huggingface/kernel-builder \
-             --max-jobs 8 \
-             -j 8 \
-             -L
-
-       - name: List built binaries
-         run: |
-           ls build
-
-       - name: Commit build artifact
-         run: |
-           git config user.name "github-actions[bot]"
-           git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
-           git add build/*
-           git commit -m "Add built binary [skip-build]"
-
-       - name: Push changes
-         run: |
-           git push origin HEAD:"$HEAD_REF"
-         env:
-           HEAD_REF: ${{ github.head_ref || github.ref }}
-           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
-       - name: Notify success on Slack (thread)
-         if: success()
-         run: |
-           ts="${{ steps.slack_start.outputs.thread_ts }}"
-           msg="*Build succeeded* for \`${{ github.repository }}\`\nBranch: \`${{ github.ref_name }}\`\n<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Workflow>"
-           curl -s -X POST \
-             -H "Authorization: Bearer ${{ secrets.SLACK_TOKEN }}" \
-             -H "Content-type: application/json; charset=utf-8" \
-             --data "{\"channel\":\"${{ secrets.SLACK_CHANNEL_ID }}\",\"text\":\"$msg\",\"thread_ts\":\"$ts\"}" \
-             https://slack.com/api/chat.postMessage
-
-       - name: Notify failure on Slack (thread)
-         if: failure()
-         run: |
-           ts="${{ steps.slack_start.outputs.thread_ts }}"
-           msg="*Build failed* for \`${{ github.repository }}\`\nBranch: \`${{ github.ref_name }}\`\n<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Workflow>"
-           curl -s -X POST \
-             -H "Authorization: Bearer ${{ secrets.SLACK_TOKEN }}" \
-             -H "Content-type: application/json; charset=utf-8" \
-             --data "{\"channel\":\"${{ secrets.SLACK_CHANNEL_ID }}\",\"text\":\"$msg\",\"thread_ts\":\"$ts\"}" \
-             https://slack.com/api/chat.postMessage
.github/workflows/pre-commit.yml DELETED
@@ -1,30 +0,0 @@
- name: pre-commit
-
- on:
-   pull_request:
-   push:
-     branches: [ main, master ]
-
- jobs:
-   run-pre-commit:
-     runs-on: ubuntu-latest
-     permissions:
-       contents: read
-       pull-requests: read
-     steps:
-       - uses: actions/checkout@v4
-
-       - uses: actions/setup-python@v5
-         with:
-           python-version: "3.11"
-
-       - name: Cache pre-commit
-         uses: actions/cache@v4
-         with:
-           path: ~/.cache/pre-commit
-           key: pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}
-           restore-keys: |
-             pre-commit-${{ runner.os }}-
-
-       - name: Run pre-commit
-         uses: pre-commit/action@v3.0.1
.github/workflows/push-to-hf.yml DELETED
@@ -1,40 +0,0 @@
- name: Push to HF Repo
-
- on:
-   push:
-     branches:
-       - main
-   workflow_dispatch:
-
- jobs:
-   push_to_hf:
-     runs-on: ubuntu-latest
-     steps:
-       # 1. Checkout the repo
-       - name: Checkout repository
-         uses: actions/checkout@v4
-         with:
-           fetch-depth: 0
-       - name: Install Git LFS
-         run: |
-           git lfs install
-           git lfs fetch --all
-           git lfs pull
-       # 2. Set up Git
-       - name: Configure Git
-         run: |
-           git config user.name "MotifTech"
-           git config user.email "huggingface@motiftech.io"
-
-       # 3. Add HF remote
-       - name: Add Hugging Face remote
-         run: |
-           git remote add hf https://huggingface.co/Motif-Technologies/optimizer
-           git fetch hf || true
-
-       # 4. Push to HF repo
-       - name: Push to Hugging Face
-         env:
-           HF_TOKEN: ${{ secrets.HF_TOKEN }}
-         run: |
-           git push "https://hf_token:${HF_TOKEN}@huggingface.co/Motif-Technologies/optimizer" HEAD:main
.gitignore DELETED
@@ -1,21 +0,0 @@
- __pycache__
- .idea
- .DS_Store
- *.egg-info
- outputs
- dist/*
- .vscode
-
- # data
- data
- out
- wandb
-
- torchtitan/datasets/**/*.model
- torchtitan/experiments/flux/assets/*
-
- # temp files
- *.log
- error.json
- _remote_module_non_scriptable.py
- .git_disabled/
.pre-commit-config.yaml DELETED
@@ -1,33 +0,0 @@
- default_install_hook_types:
-   - pre-commit
-   - commit-msg
- default_stages:
-   - pre-commit # Run locally
-   - manual # Run in CI
- exclude: '(build|result)/.*|__pycache__/.*|.*\.(png|html)$'
- repos:
-   - repo: https://github.com/google/yapf
-     rev: v0.43.0
-     hooks:
-       - id: yapf
-         args: [--in-place, --verbose]
-   - repo: https://github.com/crate-ci/typos
-     rev: v1.34.0
-     hooks:
-       - id: typos
-         exclude: '.gitattributes'
-   - repo: https://github.com/PyCQA/isort
-     rev: 6.0.1
-     hooks:
-       - id: isort
-   - repo: https://github.com/pre-commit/mirrors-clang-format
-     rev: v20.1.3
-     hooks:
-       - id: clang-format
-         types_or: [c++, cuda]
-         args: [--style=file, --verbose]
-   - repo: https://github.com/jackdewinter/pymarkdown
-     rev: v0.9.29
-     hooks:
-       - id: pymarkdown
-         args: [fix]
README.md CHANGED
@@ -1,7 +1,6 @@
  ---
  tags:
- - kernels
- license: apache-2.0
+ - kernel
  ---

  # Optimizer
@@ -10,14 +9,8 @@ Optimizer is a python package that provides:
  - PyTorch implementation of recent optimizer algorithms
  - with support for parallelism techniques for efficient large-scale training.

- ## Currently implemented
- - Parallel Muon with N-D sharding
-   - [arxiv URL](https://arxiv.org/abs/2511.07464)
- - Supports **general N-D sharding configurations**
-   - The implementation is not tied to any specific parallel strategy.
-   - Verified from basic FSDP2 setups up to hybrid configurations such as
-     **(2 TP + 2 DP-Replicate + 2 DP-Shard)**.
-   - Verified configurations can be found in [test_muon.py](./test/test_muon.py)
+ ### Currently implemented
+ - [Parallel Muon with FSDP2](./docs/muon/parallel_muon.pdf)

  ## Usage

@@ -27,72 +20,14 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  from kernels import get_kernel

  optimizer = get_kernel("motif-technologies/optimizer")
- get_default_muon_param_groups = optimizer.muon.get_default_muon_param_groups

  model = None # your model here
  fsdp_model = FSDP(model)

- # muon, in nature, cannot use 1-d tensor
- # we provide helper function to group such tensors
- # you can use your own function, if necessary
- params = get_default_muon_param_groups(model) # user can write own is_muon_func, if necessary
-
  optim = optimizer.Muon(
-     params,
+     fsdp_model.parameters(),
      lr=0.01,
      momentum=0.9,
      weight_decay=1e-4,
  )
  ```
-
- ## Test
- - Check [test/README.md](./test/README.md) for how to run the tests.
-
- ## Pre-commit Hooks
-
- This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
-
- ### Setup
-
- 1. Install pre-commit:
-
- ```bash
- pip install pre-commit
- ```
-
- 2. Install the git hooks:
-
- ```bash
- pre-commit install
- ```
-
- Once installed, the configured hooks will run automatically on each commit.
-
- ### Included Hooks
-
- The following tools are run via pre-commit:
-
- - **[yapf](https://github.com/google/yapf)** – Python code formatter
- - **[typos](https://github.com/crate-ci/typos)** – Spell checker for common typos
- - **[isort](https://github.com/PyCQA/isort)** – Organizes and sorts Python imports
- - **[clang-format](https://clang.llvm.org/docs/ClangFormat.html)** – Formats C++/CUDA code (`--style=file`)
- - **[pymarkdown](https://github.com/jackdewinter/pymarkdown)** – Lints and auto-fixes Markdown files
- - **[actionlint](https://github.com/rhysd/actionlint)** – Validates GitHub Actions workflows
-
- ### Usage
-
- - Run all checks on the entire codebase:
-
- ```bash
- pre-commit run --all-files
- ```
-
- - Run a specific hook (example: isort):
-
- ```bash
- pre-commit run isort --all-files
- ```
-
- ### Test
-
- - There is a [simple unittest for Parallel Muon](./test/test_muon/README.md)
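Note on the usage change above: the new README passes `fsdp_model.parameters()` directly, while the old one routed parameters through `get_default_muon_param_groups` (defined in the deleted `muon.py` further down, whose constructor rejects raw parameter generators). A minimal sketch of that older flow, with a stand-in two-layer model; it requires a CUDA-enabled build, since the optimizer allocates CUDA streams on construction:

```python
import torch
from kernels import get_kernel

optimizer = get_kernel("motif-technologies/optimizer")

# Muon cannot optimize 1-D tensors (biases, norms), so the helper splits
# parameters into a Muon group (ndim >= 2) and a non-Muon (AdamW) group.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))
param_groups = optimizer.muon.get_default_muon_param_groups(model)

optim = optimizer.Muon(param_groups, lr=0.01, momentum=0.9, weight_decay=1e-4)
```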
build.toml CHANGED
@@ -1,33 +1,23 @@
  [general]
  name = "optimizer"
- backends = [
-     "cuda",
-     "rocm",
- ]
+ universal = false

  [torch]
  src = [
-     "torch-ext/torch_binding.cpp",
-     "torch-ext/torch_binding.h",
+   "torch-ext/torch_binding.cpp",
+   "torch-ext/torch_binding.h",
  ]

- [kernel.optimizer]
- backend = "cuda"
- depends = ["torch"]
- src = ["optimizer/dummy.cu"]
-
- [kernel.optimizer_rocm]
+ [kernel.activation]
  backend = "rocm"
- rocm-archs = [
-     "gfx906",
-     "gfx908",
-     "gfx90a",
-     "gfx940",
-     "gfx941",
-     "gfx942",
-     "gfx1030",
-     "gfx1100",
-     "gfx1101",
- ]
- depends = ["torch"]
- src = ["optimizer/dummy.cu"]
+ src = [
+   "optimizer/dummy.cu",
+ ]
+ depends = [ "torch" ]
+
+ [kernel.activation_cuda]
+ backend = "cuda"
+ src = [
+   "optimizer/dummy.cu",
+ ]
+ depends = [ "torch" ]
build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py DELETED
@@ -1,175 +0,0 @@
- import torch
- import torch.distributed as dist
- from torch.distributed import ProcessGroup
- from torch.distributed.device_mesh import DeviceMesh
- from torch.distributed.tensor import DTensor
- from torch.distributed.tensor.placement_types import (Placement, Shard,
-                                                       _StridedShard)
-
-
- def get_slices_of_dtensor(
-     target: DTensor | torch.Tensor,
-     local_rank: int,
-     shard_mesh: DeviceMesh,
-     shard_placements: tuple[Placement],
- ) -> tuple[slice]:
-     """
-     Get the slice of local tensor for a given rank from a tensor.
-     Args:
-         target (DTensor | torch.Tensor): The target tensor.
-         rank (int): The local rank of the shard group.
-         shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
-         shard_placements (tuple[Placement]): The shard placements.
-     """
-
-     slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
-
-     # find the global rank of the local rank in the shard mesh
-     rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
-
-     rank_coords = (shard_mesh.mesh == rank).nonzero()
-
-     assert len(rank_coords) == 1
-     rank_coords = tuple(rank_coords[0].tolist())
-
-     assert len(rank_coords) == len(shard_placements)
-
-     # Caution: Assuming replicate-to-shard of the shard mesh goes with
-     # left-to-right sharding. This is ensured by the sorting logic of
-     # construct_shard_mesh function.
-     for i, (rank_coord,
-             placement) in enumerate(zip(rank_coords, shard_placements)):
-         assert isinstance(placement, Shard)
-
-         num_ranks = shard_mesh.mesh.shape[i]
-
-         dim = placement.dim
-         dim_size = (slices[dim].stop - slices[dim].start)
-
-         if dim_size % num_ranks != 0:
-             raise NotImplementedError(
-                 f"Dimension size {dim_size} is not divisible "
-                 f"by number of ranks {num_ranks} for shard "
-                 f"placement on dim {dim}. (shape: {target.shape})")
-
-         shard_size = dim_size // num_ranks
-
-         start = slices[dim].start + rank_coord * shard_size
-         end = start + shard_size
-
-         assert start < end <= slices[dim].stop
-
-         slices[dim] = slice(start, end)
-
-     return tuple(slices)
-
-
- _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
-                                                   ProcessGroup]] = dict()
-
-
- def construct_shard_mesh(
-     placements: tuple[Placement],
-     mesh: DeviceMesh,
- ) -> (DeviceMesh, ProcessGroup, tuple[Placement]):
-     """
-     Construct Shard Mesh and Placements for unsharding.
-     It removes Replicate placements and constructs a new Mesh and ProcessGroup.
-     """
-     my_rank = dist.get_rank()
-
-     assert mesh.mesh.device.type == 'cpu'
-
-     # Copy mesh to avoid modifying the original mesh
-     mesh = mesh.mesh.clone()
-
-     # 1. Sort placements. Replicate first, then Shard by dim ascending.
-
-     # For Shard, strided shard comes after regular shard on the same dim
-     # to preserve left-to-right order of replicate-to-shard.
-     # This is because that strided shard is using stride to represent
-     # more fine-grained sharding on the same dim.
-     # Please check the URL below for _StridedShard.
-     # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
-
-     def placement_sort_key(
-         placement_with_index: tuple[float, Placement]
-     ) -> tuple[int, float, int]:  # (dim, split factor, original index)
-         index, placement = placement_with_index
-         is_replicate = placement.is_replicate()
-         is_shard = placement.is_shard()
-         is_partial = placement.is_partial()
-
-         assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
-         assert not is_partial, "Partial placement is not supported."
-
-         if is_replicate:
-             return (-1.0, 0, index)
-         elif is_shard:
-             if isinstance(placement, _StridedShard):
-                 return (placement.dim, 1 / placement.split_factor, index)
-             return (placement.dim, 0, index)
-         else:
-             raise TypeError(f"Unknown placement type: {type(placement)}")
-
-     placements_with_index: list[tuple[int,
-                                       Placement]] = list(enumerate(placements))
-     placements_with_index = sorted(placements_with_index,
-                                    key=placement_sort_key)
-
-     sorted_indices, sorted_placements = zip(*placements_with_index)
-
-     # 2. Permute mesh according to sorted placements.
-     sorted_mesh = mesh.permute(sorted_indices)
-
-     # 3. Collect list of shard meshes by removing replicate dims
-     # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
-     # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
-     num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
-
-     # merge replicate dims
-     # shard_meshes became a list of shard meshes with a length of replicate degree
-     if num_replicates > 0:
-         sorted_mesh = sorted_mesh.flatten(
-             0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
-         shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
-     else:
-         shard_meshes = [sorted_mesh]
-     shard_placements = sorted_placements[num_replicates:]
-
-     # assume all shard placements are different
-     assert len(shard_placements) == len(set(shard_placements))
-
-     # 4. Construct ProcessGroups
-     # Caution: all groups should be created in the same order in all processes,
-     # even though each process only needs its own group.
-
-     # To use tensor as dict key, convert it to tuple
-     def tensor_to_tuple(t):
-         if isinstance(t, torch.Tensor):
-             t = t.tolist()
-         if isinstance(t, list):
-             return tuple(tensor_to_tuple(x) for x in t)
-         return t
-
-     my_shard_mesh_as_tuple = None
-     for shard_mesh in shard_meshes:
-         assert isinstance(shard_mesh, torch.Tensor)
-         shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
-
-         if (my_rank == shard_mesh).any().item():
-             assert my_shard_mesh_as_tuple is None
-             my_shard_mesh_as_tuple = shard_mesh_as_tuple
-
-         # update global cache
-         if shard_mesh_as_tuple not in _ranks_to_dist_cache:
-             shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
-             _ranks_to_dist_cache[shard_mesh_as_tuple] = (
-                 DeviceMesh(device_type="cuda", mesh=shard_mesh),
-                 shard_process_group,
-             )
-
-     my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
-         my_shard_mesh_as_tuple]
-
-     return my_shard_mesh, my_shard_process_group, shard_placements
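To make the per-placement arithmetic in the deleted `get_slices_of_dtensor` concrete: for each `Shard(dim)` placement, the rank at coordinate `rank_coord` along that mesh dimension owns a contiguous block of `dim_size // num_ranks` indices. A standalone sketch of that one step (`shard_slice` is a hypothetical name; no mesh or process group needed):

```python
def shard_slice(dim_size: int, num_ranks: int, rank_coord: int) -> slice:
    # mirrors one loop iteration of get_slices_of_dtensor above
    if dim_size % num_ranks != 0:
        # the real code raises NotImplementedError for uneven shards, too
        raise NotImplementedError("dimension not divisible by ranks")
    shard_size = dim_size // num_ranks
    start = rank_coord * shard_size
    return slice(start, start + shard_size)

# an 8 x 4 tensor with placements (Shard(0),) on a 2-rank mesh:
# the rank at coordinate 1 owns rows 4:8, i.e. (slice(4, 8), slice(0, 4))
assert shard_slice(8, 2, 1) == slice(4, 8)
```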
build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py DELETED
@@ -1,128 +0,0 @@
- # MIT License
- #
- # Copyright (c) 2025 Tianyang Lin
- #
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
-
- import torch
- import triton
- import triton.language as tl
-
-
- def get_autotune_config():
-     return [
-         triton.Config(
-             {
-                 'BLOCK_SIZE_M': blk_m,
-                 'BLOCK_SIZE_K': blk_k,
-                 'GROUP_SIZE_M': grp_sz
-             },
-             num_stages=n_stages,
-             num_warps=n_warps) for blk_m in [32, 64, 128]
-         for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
-         for n_warps in [4, 8]
-     ]
-
-
- @triton.autotune(
-     configs=get_autotune_config(),
-     key=['M', 'K'],
- )
- @triton.jit
- def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
-                BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-                GROUP_SIZE_M: tl.constexpr):
-     """
-     Core kernel jit function of matmul_transpose that computes y = x @ x.T
-     The code is a simple adaptation from the triton `matmul` tutorial:
-     https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
-     """
-     pid = tl.program_id(axis=0)
-     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-     num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
-     num_pid_in_group = GROUP_SIZE_M * num_pid_n
-     group_id = pid // num_pid_in_group
-     first_pid_m = group_id * GROUP_SIZE_M
-     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-     pid_n = (pid % num_pid_in_group) // group_size_m
-     if pid_m > pid_n:
-         return
-
-     offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-     offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-     offs_k = tl.arange(0, BLOCK_SIZE_K)
-     # we use a & b ptrs to denote different rows of x.
-     a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
-     b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
-
-     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
-
-     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-         a = tl.load(a_ptrs,
-                     mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
-                     other=0.0)
-         b = tl.load(b_ptrs,
-                     mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
-                     other=0.0)
-         accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
-         a_ptrs += BLOCK_SIZE_K * stride_xk
-         b_ptrs += BLOCK_SIZE_K * stride_xk
-     # use dtype.element_ty to accommodate different input datatypes as in cpp templates
-     # https://github.com/triton-lang/triton/issues/2252
-     c = accumulator.to(x.dtype.element_ty)
-
-     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-     offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-     c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
-     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
-     tl.store(c_ptrs, c, mask=c_mask)
-
-     # transpose and copy
-     if pid_m < pid_n:
-         ct_ptrs = y + stride_ym * offs_cn[:,
-                                           None] + stride_yn * offs_cm[None, :]
-         ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
-         tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
-
-
- def matmul_transpose_assign(d_in, d_out):
-     assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
-     assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
-     assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
-     assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
-     assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
-     assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
-     assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
-         "First dimension of `d_in` must match first and second dimension of `d_out`"
-
-     d_in = d_in.contiguous()
-     M, K = d_in.shape
-     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
-         M, META['BLOCK_SIZE_M']), )
-     with torch.cuda.device(d_in.device.index):
-         mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
-                          d_out.stride(0), d_out.stride(1))
-
-
- def matmul_transpose(d_in):
-     M, _ = d_in.shape
-     d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
-     matmul_transpose_assign(d_in, d_out)
-     return d_out
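A quick sanity check of the deleted kernel's contract: `matmul_transpose(x)` returns `x @ x.T` from a single launch, computing only the upper-triangular tiles and mirroring them. A minimal sketch, assuming a CUDA device and this module on the import path; tolerances are kept loose for fp16:

```python
import torch
from matmul_transpose_triton import matmul_transpose  # import path is an assumption

x = torch.randn(256, 512, device="cuda", dtype=torch.float16)
y = matmul_transpose(x)  # shape (256, 256); y = x @ x.T

ref = x @ x.T  # plain PyTorch reference
assert torch.allclose(y, ref, atol=1e-2, rtol=1e-2)
```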
build/torch210-cxx11-cu126-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
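The deleted `muon.py` below centers on `_zeropower_via_newtonschulz5`, which replaces each 2-D update with an approximately orthogonal matrix. In matrix form (matching the `buf1`/`buf2` buffers and the final `addmm` in that function), each of the five steps with coefficients $(a, b, c)$ computes:

```latex
X_{k+1} = a\,X_k + \bigl(b\,(X_k X_k^{\top}) + c\,(X_k X_k^{\top})^{2}\bigr)\,X_k,
\qquad
X_0 = \frac{G}{\lVert G \rVert_F + 10^{-7}}
```

where `matmul_transpose_assign` supplies $X_k X_k^{\top}$ and $(X_k X_k^{\top})^{2}$, and tall matrices are transposed first so these products stay at the smaller dimension.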
build/torch210-cxx11-cu126-x86_64-linux/muon.py DELETED
@@ -1,1268 +0,0 @@
1
- import logging
2
- import math
3
- import types
4
- from collections import defaultdict
5
- from dataclasses import dataclass
6
- from typing import Any, cast
7
-
8
- import torch
9
- import torch.distributed as dist
10
- from torch.distributed import ProcessGroup
11
- from torch.distributed.device_mesh import DeviceMesh
12
- from torch.distributed.tensor import DTensor, Replicate
13
- from torch.distributed.tensor.placement_types import Placement
14
-
15
- from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
- from .matmul_transpose_triton import matmul_transpose_assign
17
-
18
- logger = logging.getLogger(__name__)
19
-
20
- COMM_DTYPE = torch.bfloat16
21
- DEFAULT_CHUNK_SIZE_RATIO = 4
22
-
23
-
24
- # This code snippet is a modified version adapted from the following GitHub repositories:
25
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
26
- # Muon's Newton–Schulz iteration causes high variance in singular values
27
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
28
- @torch.no_grad()
29
- # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
30
- def _zeropower_via_newtonschulz5(G, steps):
31
- """
32
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
33
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
34
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
35
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
36
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
37
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
38
- performance at all relative to UV^T, where USV^T = G is the SVD.
39
- """
40
- assert len(G.shape) == 2
41
- assert G.dtype == COMM_DTYPE
42
- X = G # no manual typecast
43
-
44
- if G.size(0) > G.size(1):
45
- X = X.T
46
- # Ensure spectral norm is at most 1
47
- X = X / (X.norm() + 1e-7)
48
- buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
49
- buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
50
- # Perform the NS iterations
51
- for a, b, c in [
52
- (4.0848, -6.8946, 2.9270),
53
- (3.9505, -6.3029, 2.6377),
54
- (3.7418, -5.5913, 2.3037),
55
- (2.8769, -3.1427, 1.2046),
56
- (2.8366, -3.0525, 1.2012),
57
- ]:
58
- matmul_transpose_assign(X, buf1)
59
- matmul_transpose_assign(buf1, buf2)
60
- buf1.mul_(b).add_(buf2, alpha=c)
61
- X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
62
-
63
- if G.size(0) > G.size(1):
64
- X = X.T
65
- return X
66
-
67
-
68
- @dataclass
69
- class _muon_state:
70
- # TODO: use Optional
71
- worker_rank: int
72
- process_group: ProcessGroup
73
- shard_mesh: DeviceMesh
74
- shard_placements: tuple[Placement, ...]
75
- name: str
76
- qk_clip_state: torch.Tensor | None = None
77
- gathered_grad: torch.Tensor | None = None
78
- scattered_u: DTensor | None = None
79
- computed_u: torch.Tensor | None = None
80
- gather_event: torch.cuda.Event | None = None
81
- compute_event: torch.cuda.Event | None = None
82
- scatter_event: torch.cuda.Event | None = None
83
-
84
-
85
- def numel_for_rank(
86
- param: DTensor,
87
- local_rank: int,
88
- state: _muon_state,
89
- ) -> int:
90
- slices = get_slices_of_dtensor(
91
- param,
92
- local_rank,
93
- state.shard_mesh,
94
- state.shard_placements,
95
- )
96
-
97
- numel = 1
98
- for s, dim in zip(slices, param.shape):
99
- start, stop, step = s.indices(dim)
100
- length = max(0, (stop - start + (step - 1)) // step)
101
- numel *= length
102
-
103
- return numel
104
-
105
-
106
- @torch.no_grad()
107
- def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
108
- """
109
- Pre-allocate gathered_grad buffer on compute_stream
110
- before launching all2all gather
111
- """
112
- with torch.cuda.stream(compute_stream):
113
- for p in params:
114
- state = param_to_state[id(p)]
115
- if rank == state.worker_rank:
116
- state.gathered_grad = torch.empty(p.shape,
117
- dtype=COMM_DTYPE,
118
- device="cuda")
119
- else:
120
- state.gathered_grad = None
121
-
122
- alloc_event = torch.cuda.Event()
123
- alloc_event.record(compute_stream)
124
- return alloc_event
125
-
126
-
127
- @torch.no_grad()
128
- def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
129
- alloc_event):
130
- """
131
- All2all gathers shards so each owner rank reconstructs its full gradient
132
- """
133
- with torch.cuda.stream(comm_stream):
134
- process_group = param_to_state[id(params[0])].process_group
135
- num_ranks = dist.get_world_size(group=process_group)
136
-
137
- # Construct sending buffers
138
- per_dst = [[] for _ in range(num_ranks)]
139
- send_counts = [0] * num_ranks
140
-
141
- for p in params:
142
- state = param_to_state[id(p)]
143
- dst = state.worker_rank
144
- assert dst < num_ranks
145
- shard_elems = numel_for_rank(p, rank, state)
146
- g = p.grad
147
- g = g.to_local().to(COMM_DTYPE).contiguous()
148
- assert g.numel() == shard_elems
149
- per_dst[dst].append(g.view(-1))
150
- send_counts[dst] += shard_elems
151
-
152
- assert any(
153
- len(v) > 0 for v in per_dst
154
- ), "At least one destination rank must receive a sharded tensor"
155
- # list[list[Tensor]] -> list[Tensor]
156
- per_dst = [t for dst in per_dst for t in dst]
157
-
158
- send_buf = torch.cat(per_dst, dim=0)
159
-
160
- owned_params = [
161
- p for p in params if param_to_state[id(p)].worker_rank == rank
162
- ]
163
-
164
- # Compute receive sizes and allocate receiving buffers
165
- recv_counts = [0] * num_ranks
166
-
167
- for src in range(num_ranks):
168
- total = 0
169
- for p in owned_params:
170
- state = param_to_state[id(p)]
171
- assert state.worker_rank == rank
172
- total += numel_for_rank(p, src, state)
173
- recv_counts[src] = total
174
-
175
- recv_total = sum(recv_counts)
176
- recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
-
178
- #All2All
179
- logger.debug(f"send_buf size: {send_buf.numel()}, "
180
- f"recv_buf size: {recv_buf.numel()}, "
181
- f"recv_counts: {recv_counts}, "
182
- f"send_counts: {send_counts}, "
183
- f"process_group: {str(process_group)}")
184
- dist.all_to_all_single(
185
- recv_buf,
186
- send_buf,
187
- output_split_sizes=recv_counts,
188
- input_split_sizes=send_counts,
189
- group=process_group,
190
- )
191
-
192
- # Reconstructs gathered grad from the received buffer
193
- #
194
- # recv_buf (num ranks = 3)
195
- #
196
- # From rank 0 From rank 1 From rank 2
197
- # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
198
- #
199
- # Outer loop:
200
- # rank 0 -> rank 1 -> rank2
201
- #
202
- # Inner loop:
203
- # p1_n -> p2_n -> p3_n
204
-
205
- comm_stream.wait_event(alloc_event)
206
-
207
- off = 0
208
- for src in range(num_ranks):
209
- if recv_counts[src] == 0:
210
- continue
211
-
212
- block = recv_counts[src]
213
- inner_off = 0
214
- for p in owned_params:
215
- state = param_to_state[id(p)]
216
- assert state.worker_rank == rank
217
-
218
- # get the slice of the full dtensor corresponding to rank src.
219
- slices = get_slices_of_dtensor(state.gathered_grad, src,
220
- state.shard_mesh,
221
- state.shard_placements)
222
-
223
- dst = state.gathered_grad[slices]
224
- assert dst._base is state.gathered_grad
225
-
226
- n = dst.numel()
227
- assert n > 0
228
-
229
- sg = recv_buf.narrow(0, off + inner_off, n)
230
- sg = sg.reshape_as(dst)
231
- dst.copy_(sg)
232
-
233
- inner_off += n
234
- off += block
235
-
236
- for p in params:
237
- state = param_to_state[id(p)]
238
- if state.worker_rank == rank:
239
- state.gather_event = torch.cuda.Event()
240
- state.gather_event.record(comm_stream)
241
- else:
242
- state.gathered_grad = None
243
- state.gather_event = None
244
- if none_grad:
245
- p.grad = None
246
-
247
-
248
- @torch.no_grad()
249
- def _compute_u(p, state, steps, rank, compute_stream):
250
- """
251
- On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
252
- """
253
- with torch.cuda.stream(compute_stream):
254
- if rank == state.worker_rank:
255
- if state.gather_event is None:
256
- raise RuntimeError("Gather event must be set before compute.")
257
- compute_stream.wait_event(state.gather_event)
258
- u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
259
- state.gathered_grad = None
260
- state.computed_u = u
261
- state.compute_event = torch.cuda.Event()
262
- state.compute_event.record()
263
- else:
264
- state.computed_u = None
265
- state.compute_event = None
266
-
267
-
268
- @torch.no_grad()
269
- def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
270
- """
271
- Pre-allocate scattered_u buffer on compute_stream
272
- before launching all2all gather
273
- """
274
- with torch.cuda.stream(compute_stream):
275
- for p in params:
276
- state = param_to_state[id(p)]
277
- state.scattered_u = torch.empty_like(p.to_local(),
278
- dtype=COMM_DTYPE)
279
-
280
- alloc_event = torch.cuda.Event()
281
- alloc_event.record(compute_stream)
282
- return alloc_event
283
-
284
-
285
- def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
286
- """
287
- All2all scatters full gradients to all ranks
288
- """
289
- with torch.cuda.stream(comm_stream):
290
- process_group = param_to_state[id(params[0])].process_group
291
- num_ranks = dist.get_world_size(group=process_group)
292
- owned_params = [
293
- p for p in params if param_to_state[id(p)].worker_rank == rank
294
- ]
295
-
296
- # Construct sending buffer
297
- per_dst = [[] for _ in range(num_ranks)]
298
- send_counts = [0] * num_ranks
299
-
300
- if owned_params:
301
- for p in owned_params:
302
- state = param_to_state[id(p)]
303
- if state.compute_event is None:
304
- raise RuntimeError(
305
- "Compute event must be set before scatter.")
306
- comm_stream.wait_event(state.compute_event)
307
- state.gathered_grad = None
308
-
309
- assert state.computed_u is not None
310
-
311
- u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
-
313
- offset = 0
314
- for dst in range(num_ranks):
315
- # get the slice of the full tensor corresponding to rank dst.
316
- slices = get_slices_of_dtensor(u_full, dst,
317
- state.shard_mesh,
318
- state.shard_placements)
319
- su = u_full[slices].flatten()
320
-
321
- n = su.numel()
322
- assert n > 0
323
-
324
- per_dst[dst].append(su)
325
- send_counts[dst] += n
326
- offset += n
327
-
328
- assert offset == u_full.numel()
329
-
330
- lengths = [len(v) for v in per_dst]
331
- if all(l > 0 for l in lengths):
332
- assert all(
333
- l == lengths[0] for l in lengths
334
- ), "All destination ranks must have the same number of sharded tensor"
335
- # list[list[Tensor]] -> list[Tensor]
336
- per_dst = [t for dst in per_dst for t in dst]
337
- send_buf = torch.cat(per_dst, dim=0)
338
- else:
339
- # all_to_all requires participation from all ranks
340
- # Even non-owner ranks must join the collective call
341
- send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
342
-
343
- # Compute receive sizes and allocate receiving buffers
344
- recv_counts = [0] * num_ranks
345
-
346
- for src in range(num_ranks):
347
- total = 0
348
- for p in params:
349
- state = param_to_state[id(p)]
350
- if state.worker_rank != src:
351
- continue
352
- total += numel_for_rank(p, rank, state)
353
- recv_counts[src] = total
354
-
355
- recv_total = sum(recv_counts)
356
- assert recv_total > 0
357
- recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
358
-
359
- #All2All
360
- dist.all_to_all_single(
361
- recv_buf,
362
- send_buf,
363
- output_split_sizes=recv_counts,
364
- input_split_sizes=send_counts,
365
- group=process_group,
366
- )
367
-
368
- # Copy to pre-allocated scattered_u buffer from the received buffer
369
- #
370
- # recv_buf (num ranks = 3, local_rank = 0)
371
- #
372
- # From rank 0 From rank 1 From rank 2
373
- # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
374
- #
375
- # Outer loop:
376
- # rank 0 -> rank 1 -> rank2
377
- #
378
- # Inner loop:
379
- # src(0) : p1_0 -> p2_0 -> p3_0
380
- # src(1) : p4_0
381
- # src(2) : p5_0 -> p6_0
382
-
383
- comm_stream.wait_event(alloc_event)
384
-
385
- off = 0
386
- for src in range(num_ranks):
387
- block = recv_counts[src]
388
- if block == 0:
389
- continue
390
-
391
- inner_off = 0
392
- for p in params:
393
- state = param_to_state[id(p)]
394
- if state.worker_rank != src:
395
- continue
396
- n = numel_for_rank(p, rank, state)
397
- assert n > 0
398
-
399
- flat_local = recv_buf.narrow(0, off + inner_off,
400
- n).view_as(p.to_local())
401
- state.scattered_u.copy_(flat_local)
402
-
403
- state.scatter_event = torch.cuda.Event()
404
- state.scatter_event.record(comm_stream)
405
- inner_off += n
406
-
407
- assert inner_off == block
408
- off += block
409
-
410
-
411
- def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
412
- compute_stream):
413
- """
414
- Update sharded parameter p with the scattered_u.
415
- Only worker_rank frees computed_u.
416
- """
417
- with torch.cuda.stream(compute_stream):
418
- if state.scatter_event is None:
419
- raise RuntimeError("Scatter event must be set before update")
420
- compute_stream.wait_event(state.scatter_event)
421
- u_dtensor = DTensor.from_local(
422
- state.scattered_u,
423
- placements=p.placements,
424
- device_mesh=p.device_mesh,
425
- )
426
-
427
- state.scattered_u = u_dtensor
428
-
429
- if rank == state.worker_rank:
430
- # Free computed_u
431
- state.computed_u = None
432
-
433
- Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
434
- state.scattered_u = None
435
- u_dtensor = None
436
-
437
- scales_full = Muon._compute_scales(
438
- p,
439
- state.qk_clip_state) if state.qk_clip_state is not None else None
440
- if scales_full is not None:
441
- # Have to slice scales_full among dim 0
442
- weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
- state.shard_placements)
444
- ratio = p.shape[0] // scales_full.shape[0]
445
- scales_slice = slice(
446
- None if weight_slices[0].start is None else
447
- weight_slices[0].start // ratio,
448
- None if weight_slices[0].stop is None else
449
- weight_slices[0].stop // ratio,
450
- None,
451
- )
452
-
453
- scales_local = scales_full[scales_slice]
454
- scales_local = DTensor.from_local(
455
- scales_local,
456
- placements=p.placements,
457
- device_mesh=p.device_mesh,
458
- )
459
- Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
460
-
461
-
462
- def default_is_muon(name, x):
463
- skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
464
- return x.ndim >= 2 and not any(key in name for key in skip_keys)
465
-
466
-
467
- def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
468
- muon_params, muon_names = [], []
469
- non_muon_params = []
470
-
471
- for n, p in model.named_parameters():
472
- if not p.requires_grad:
473
- continue
474
- if is_muon_func(n, p):
475
- muon_params.append(p)
476
- muon_names.append(n)
477
- else:
478
- non_muon_params.append(p)
479
-
480
- return [
481
- {
482
- "params": muon_params,
483
- "names": muon_names,
484
- "use_muon": True,
485
- },
486
- {
487
- "params": non_muon_params,
488
- "use_muon": False,
489
- },
490
- ]
491
-
492
-
493
- def parse_qk_layer(name: str) -> tuple[str | None, int]:
494
- """
495
- Parse a parameter name to check if it is a query/key projection layer
496
- ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
497
-
498
- Returns:
499
- (kind, layer_idx) or (None, -1) if not matched.
500
-
501
- Example:
502
- 'model.3.attn.wq.weight' -> ('wq', 3)
503
- 'model.5.attn.wk.weight' -> ('wk', 5)
504
- 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
505
- 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
506
- 'model.4.attn.v_proj.weight' -> (None, -1)
507
- """
508
- parts = name.split('.')
509
- if len(parts) < 3:
510
- return None, -1
511
-
512
- kind = parts[-2]
513
-
514
- layer_idx = -1
515
- for part in reversed(parts):
516
- if part.isdigit():
517
- layer_idx = int(part)
518
- break
519
-
520
- if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
521
- return kind, layer_idx
522
-
523
- return None, -1
524
-
525
-
526
- @dataclass
527
- class QKClipInfo:
528
- """Per-parameter dynamic info computed from config + runtime logits."""
529
- kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
- indices: list[int] # which heads to consider for clipping
531
- head_dim: int # from config
532
- threshold: float # from config
533
- logit: torch.Tensor | None
534
-
535
-
536
- class Muon(torch.optim.Optimizer):
537
- """
538
- Muon - MomentUm Orthogonalized by Newton-schulz
539
-
540
- Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
541
- processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
542
- matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
543
- the advantage that it can be stably run in bfloat16 on the GPU.
544
-
545
- Some warnings:
546
- - We believe this optimizer is unlikely to work well for training with small batch size.
547
- - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
548
-
549
- Arguments:
550
- model: The model to be optimized by Muon.
551
- is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
552
- lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
553
- momentum: The momentum used by the internal SGD. (0.95 is a good default)
554
- nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
555
- ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
556
- weight_decay: The weight decay for Muon and AdamW.
557
- {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
558
- adamw_lr: The learning rate for the internal AdamW.
559
- adamw_betas: The betas for the internal AdamW.
560
- adamw_eps: The epsilon for the internal AdamW.
561
- none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
562
- debug: Whether to print debug information.
563
- clip_info : Configuration for QK clipping. Expected keys:
564
- - "q_indices" (list[int]): Indices of query heads to consider.
565
- - "k_indices" (list[int]): Indices of key heads to consider.
566
- - "head_dim" (int): Dimensionality of each attention head.
567
- - "threshold" (float): Threshold value; heads whose QK logits exceed
568
- this value will be scaled down.
569
- Default is:
570
- {
571
- "q_indices": [],
572
- "k_indices": [],
573
- "head_dim": 128,
574
- "threshold": 100
575
- }
576
- warmup_step : How many all2all gather, compute operations are launched in advance
577
- before the corresponding all2all scatter steps begin.
578
- A higher warmup_step increases memory usage but can improve
579
- performance by overlapping communication.
580
- Parallel muon only.
581
- chunk_size : Batch size of parameters to process in each
582
- all2all gather/compute/scatter step.
583
- Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
- use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
- For testing purpose only.
586
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
587
- """
588
-
589
- def __init__(self,
590
- params,
591
- lr=1e-3,
592
- momentum=0.95,
593
- nesterov=True,
594
- ns_steps=5,
595
- weight_decay=0.1,
596
- adamw_betas=(0.9, 0.95),
597
- adamw_eps=1e-8,
598
- none_grad=True,
599
- debug=False,
600
- clip_config={
601
- "q_indices": [],
602
- "k_indices": [],
603
- "head_dim": 128,
604
- "threshold": 100
605
- },
606
- warmup_step=5,
607
- chunk_size=-1,
608
- use_distributed_muon=False,
609
- small_param_numel_threshold=65536):
610
- defaults = dict(
611
- lr=lr,
612
- weight_decay=weight_decay,
613
- momentum=momentum,
614
- nesterov=nesterov,
615
- ns_steps=ns_steps,
616
- adamw_betas=adamw_betas,
617
- adamw_eps=adamw_eps,
618
- none_grad=none_grad,
619
- use_muon=True,
620
- )
621
- error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
622
- instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
623
-
624
- if isinstance(params, types.GeneratorType):
625
- raise ValueError(error_message.format(idx=0) + instruction_code)
626
- for _idx, param_group in enumerate(params):
627
- if param_group.get("use_muon", None) is None:
628
- raise ValueError(
629
- error_message.format(idx=_idx) + instruction_code)
630
-
631
- super().__init__(params, defaults)
632
-
633
- self.rank = None
634
-
635
- self.comm_stream = torch.cuda.Stream()
636
- self.compute_stream = torch.cuda.Stream()
637
- self.debug = debug
638
- self.clip_config = clip_config
639
- self.warmup_step = warmup_step
640
- self.chunk_size = chunk_size
641
- self.use_distributed_muon = use_distributed_muon
642
- self.small_param_numel_threshold = small_param_numel_threshold
643
-
644
- def _calc_flops(self, G, steps):
645
- assert len(G.shape) == 2
646
- M, N = G.shape
647
- if M > N:
648
- M, N = N, M
649
-
650
- return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
651
-
652
- def adjust_lr_for_muon(self, lr, param_shape):
653
- A, B = param_shape[:2]
654
- # We adjust the learning rate and weight decay based on the size of the parameter matrix
655
- # as describted in the paper
656
- adjusted_ratio = 0.2 * math.sqrt(max(A, B))
657
- adjusted_lr = lr * adjusted_ratio
658
- return adjusted_lr
659
-
660
- def set_rank_once(self, rank):
661
- if self.rank is None:
662
- self.rank = rank
663
- else:
664
- assert self.rank == rank
665
-
666
- def get_shard_mesh(self, p):
667
- """
668
- Get the shard mesh for a parameter p on the given rank.
669
- """
670
- assert isinstance(
671
- p, DTensor), "Parallel Muon only supports DTensor parameters."
672
-
673
- shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
674
- p.placements, p.device_mesh)
675
-
676
- # set rank with the local rank in the shard process group
677
- self.set_rank_once(dist.get_rank(group=shard_pg))
678
-
679
- return shard_mesh, shard_pg, shard_placements
680
-
681
- def init_state_and_assign_params(self, names, params, group, qk_logits):
682
- param_to_state = {}
683
- param_to_flops = {}
684
-
685
- total_flops = 0
686
- for p in params:
687
- g = p.grad
688
- if g is None:
689
- continue
690
- assert g.ndim == 2, "Muon only supports 2D parameters."
691
-
692
- flops = self._calc_flops(g, group["ns_steps"])
693
- param_to_flops[id(p)] = flops
694
- total_flops += flops
695
-
696
- if self.debug:
697
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
698
- flush=True)
699
-
700
- paired = list(zip(names, params))
701
-
702
- paired_sorted = sorted(paired,
703
- key=lambda x: param_to_flops[id(x[1])],
704
- reverse=True)
705
-
706
- names_sorted, params_sorted = zip(*paired_sorted)
707
- ordered_names = list(names_sorted)
708
- ordered_params = list(params_sorted)
709
-
710
- round_robin = 0
711
- mesh = ordered_params[0].device_mesh
712
- placements = ordered_params[0].placements
713
-
714
- shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
715
- ordered_params[0])
716
- shard_mesh_flattened = shard_mesh.mesh.flatten()
717
- num_ranks = dist.get_world_size(group=shard_pg)
718
-
719
- for n, p in zip(ordered_names, ordered_params):
720
- if mesh != p.device_mesh:
721
- raise ValueError("All parameters must be on the same mesh.")
722
- if placements != p.placements:
723
- raise ValueError("All parameters must have same placements.")
724
-
725
- worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
726
- round_robin = (round_robin + 1) % len(shard_mesh_flattened)
727
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
728
-
729
- param_to_state[id(p)] = _muon_state(
730
- worker_rank=worker_rank,
731
- process_group=shard_pg,
732
- shard_mesh=shard_mesh,
733
- shard_placements=shard_placements,
734
- name=n,
735
- qk_clip_state=qk_clip_state,
736
- )
737
-
738
- return param_to_state, ordered_params
739
-
740
- def base(self, names, params, group, lr, weight_decay, momentum,
741
- qk_logits):
742
- # generate weight updates in distributed fashion
743
- for n, p in zip(names, params):
744
- g = p.grad
745
- if g is None:
746
- continue
747
- if g.ndim > 2:
748
- g = g.view(g.size(0), -1)
749
- assert g is not None
750
-
751
- g = self._update_g(p, g, group, momentum)
752
-
753
- u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
754
- steps=group["ns_steps"])
755
-
756
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
757
- Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
758
-
759
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
760
-
761
- scales_full = self._compute_scales(
762
- p, qk_clip_state) if qk_clip_state is not None else None
763
- if scales_full is not None:
764
- Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
765
-
766
- def distributed_muon(
767
- self,
768
- names: list[str],
769
- params: list[torch.nn.Parameter],
770
- group: dict[str, Any],
771
- lr: float,
772
- weight_decay: float,
773
- momentum: float,
774
- qk_logits: list[torch.Tensor | DTensor] | None,
775
- ):
776
- """ Implementation of Distributed Muon by Liu et al. """
777
-
778
- for n, p in zip(names, params):
779
- g = p.grad
780
- if g is None:
781
- continue
782
- if g.ndim > 2:
783
- g = g.view(g.size(0), -1)
784
- assert g is not None
785
-
786
- g = self._update_g(p, g, group, momentum)
787
-
788
- # Gather G
789
- if isinstance(p.data, DTensor):
790
- g_full = g.full_tensor()
791
- p_full = p.data.full_tensor()
792
- else:
793
- g_full = g
794
- p_full = p
795
-
796
- u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
797
- steps=group["ns_steps"])
798
-
799
- adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
800
- Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
801
-
802
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
803
-
804
- scales_full = self._compute_scales(
805
- p_full, qk_clip_state) if qk_clip_state is not None else None
806
-
807
- if scales_full is not None:
808
- Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
809
-
810
- if isinstance(p.data, DTensor):
811
- ndims = len(p.device_mesh.mesh.shape)
812
- p_replicate = DTensor.from_local(
813
- p_full,
814
- device_mesh=p.device_mesh,
815
- placements=[Replicate() for _ in range(ndims)],
816
- )
817
-
818
- p_sharded = p_replicate.redistribute(
819
- device_mesh=p.device_mesh,
820
- placements=p.placements,
821
- )
822
-
823
- p.copy_(p_sharded)
824
-
825
-     def _update_g(self, p, g, group, momentum):
-         # calc update
-         state = self.state[p]
-         buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
-         torch.add(g, buf, alpha=momentum, out=buf)
-         if group["nesterov"]:
-             g.add_(buf, alpha=momentum)
-             return g
-         return buf
-
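For reference, a minimal sketch of the momentum arithmetic in `_update_g` above, on plain CPU tensors with made-up values (not part of the original file):

    import torch

    g = torch.tensor([1.0, 2.0])    # incoming gradient
    buf = torch.tensor([0.5, 0.5])  # momentum buffer carried between steps
    momentum = 0.95

    torch.add(g, buf, alpha=momentum, out=buf)  # buf <- g + momentum * buf
    nesterov = g + momentum * buf               # what the nesterov branch returns
    heavy_ball = buf                            # what the non-nesterov branch returns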
-     @staticmethod
-     def _update_p(p, u, lr, adjusted_lr, weight_decay):
-         if isinstance(p, torch.nn.Parameter):
-             # apply weight decay
-             p.data.mul_(1 - lr * weight_decay)
-             # apply update
-             p.data.add_(u, alpha=-adjusted_lr)
-         else:
-             p.mul_(1 - lr * weight_decay)
-             p.add_(u, alpha=-adjusted_lr)
-
-     def get_qk_clip_info(self, n, qk_logits):
-         if self.clip_config is None:
-             return None
-
-         head_dim = self.clip_config.get('head_dim')
-         threshold = self.clip_config.get('threshold')
-         kind, layer_idx = parse_qk_layer(n)
-
-         logit, indices = None, []
-         if qk_logits is not None and kind is not None:
-             logit = qk_logits[layer_idx]
-             indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-             indices = self.clip_config.get(indices_key, []) or []
-
-         if isinstance(logit, DTensor):
-             # In TP settings, qk_logits may be a DTensor.
-             # We convert it to a full tensor here for simplicity.
-             logit = logit.full_tensor()
-
-         return QKClipInfo(
-             kind=kind,
-             indices=indices,
-             head_dim=head_dim,
-             threshold=threshold,
-             logit=logit,
-         )
-
-     @staticmethod
-     def _compute_scales(p, qk_clip_state):
-         kind = qk_clip_state.kind
-         indices = qk_clip_state.indices
-         head_dim = qk_clip_state.head_dim
-         threshold = qk_clip_state.threshold
-         logit = qk_clip_state.logit
-
-         H_global = p.shape[0] // head_dim
-         scales_full = torch.ones(H_global, device=p.data.device)
-         scaling = 0
-
-         for logit_idx, head_idx in enumerate(indices):
-             v_ele = float(logit[logit_idx])
-             if v_ele > threshold:
-                 new_scale = math.sqrt(threshold / v_ele)
-                 if new_scale < scales_full[head_idx]:
-                     scales_full[head_idx] = new_scale
-                     logger.info(
-                         f"[{kind}] Head {head_idx} exceeded threshold "
-                         f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
-                     )
-                 scaling += 1
-
-         return scales_full if scaling > 0 else None
-
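A worked instance of the scaling rule above (illustrative numbers only): a head whose max QK logit is 144 against a threshold of 100 gets scale sqrt(100/144) ≈ 0.833; because the same scale is applied to both the q and the k projection of that head, the logit itself shrinks by the squared factor and lands exactly on the threshold.

    import math

    threshold, v_ele = 100.0, 144.0
    scale = math.sqrt(threshold / v_ele)             # ~0.8333, as in _compute_scales
    assert abs(scale**2 * v_ele - threshold) < 1e-9  # q-scale * k-scale recovers the threshold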
-     @staticmethod
-     def _qk_clip(p, scales, head_dim):
-         if isinstance(p, torch.nn.Parameter):
-             W = p.data.view(-1, head_dim, p.data.shape[1])
-             W.mul_(scales.view(-1, 1, 1))
-         else:
-             W = p.view(-1, head_dim, p.shape[1])
-             W.mul_(scales.view(-1, 1, 1))
-
-     def parallel(self, names, params, group, lr, weight_decay, momentum,
-                  qk_logits):
-         """
-         Perform a parallel optimization step using Muon.
-         """
-
-         for p in params:
-             g = p.grad
-             if g is None:
-                 continue
-             if g.ndim > 2:
-                 g = g.view(g.size(0), -1)
-
-             # Update g in the local rank
-             g = self._update_g(
-                 p,
-                 g,
-                 group,
-                 momentum=momentum,
-             )
-             p.grad = g
-
-         param_to_state, ordered_params = self.init_state_and_assign_params(
-             names, params, group, qk_logits)
-
-         assert self.rank is not None
-
-         def enqueue_all2all_gather(start_idx, chunk_size):
-             target_params = ordered_params[start_idx:start_idx + chunk_size]
-             if target_params:
-                 alloc_event = _alloc_gathered_grad(target_params,
-                                                    param_to_state, self.rank,
-                                                    self.compute_stream)
-                 _all2all_gather(target_params, param_to_state, self.rank,
-                                 self.comm_stream, group["none_grad"],
-                                 alloc_event)
-
-         def enqueue_computes(start_idx, chunk_size):
-             for p in ordered_params[start_idx:start_idx + chunk_size]:
-                 state = param_to_state[id(p)]
-                 _compute_u(p, state, group["ns_steps"], self.rank,
-                            self.compute_stream)
-
-         def enqueue_all2all_scatter(start_idx, chunk_size):
-             target_params = ordered_params[start_idx:start_idx + chunk_size]
-             if target_params:
-                 alloc_event = _alloc_scattered_u(target_params, param_to_state,
-                                                  self.rank,
-                                                  self.compute_stream)
-                 _all2all_scatter(target_params, param_to_state, self.rank,
-                                  self.comm_stream, alloc_event)
-
-         def enqueue_update_param(start_idx, chunk_size):
-             for p in ordered_params[start_idx:start_idx + chunk_size]:
-                 state = param_to_state[id(p)]
-                 adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-                 _update_param(p, state, lr, adjusted_lr, weight_decay,
-                               self.rank, self.compute_stream)
-
-         if self.chunk_size == -1:
-             shard_ranks = dist.get_world_size(param_to_state[id(
-                 params[0])].process_group)
-             chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
-         elif self.chunk_size > 0:
-             chunk_size = self.chunk_size
-         else:
-             raise ValueError("chunk_size must be -1 or a positive integer.")
-
-         # Wait for the gradient updates to finish
-         self.comm_stream.wait_stream(torch.cuda.current_stream())
-
-         warmup_step = self.warmup_step
-         for i in range(0, warmup_step):
-             enqueue_all2all_gather(i * chunk_size, chunk_size)
-             enqueue_computes(i * chunk_size, chunk_size)
-
-         for i in range(0, len(params) + chunk_size - 1, chunk_size):
-             enqueue_all2all_scatter(i, chunk_size)
-             enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
-             enqueue_update_param(i, chunk_size)
-             enqueue_computes(i + warmup_step * chunk_size, chunk_size)
-
-         # Wait for the last update_param to finish
-         torch.cuda.current_stream().wait_stream(self.compute_stream)
-
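The two loops above form a software pipeline: warmup_step chunks are gathered and orthogonalized ahead of time, and each subsequent iteration retires one chunk (scatter + parameter update) while prefetching the chunk warmup_step positions ahead. A rough pure-Python trace of the schedule, with made-up sizes (not from the original file):

    num_params, chunk_size, warmup_step = 8, 2, 1

    for i in range(warmup_step):
        print(f"warmup: gather+compute params [{i * chunk_size}:{(i + 1) * chunk_size}]")

    for i in range(0, num_params + chunk_size - 1, chunk_size):
        j = i + warmup_step * chunk_size
        print(f"scatter+update [{i}:{i + chunk_size}] | gather+compute [{j}:{j + chunk_size}]")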
-     @staticmethod
-     def _fused_adamw(
-         params: list[torch.Tensor],
-         grads: list[torch.Tensor],
-         exp_avgs: list[torch.Tensor],
-         exp_avg_sqs: list[torch.Tensor],
-         max_exp_avg_sqs: list[torch.Tensor],
-         state_steps: list[torch.Tensor],
-         amsgrad: bool,
-         beta1: float,
-         beta2: float,
-         lr: float | torch.Tensor,
-         weight_decay: float,
-         eps: float,
-         maximize: bool,
-     ) -> None:
-         if not params:
-             return
-
-         # We only shuffle around the lr when it is a Tensor on a CUDA device;
-         # otherwise we prefer treating it as a scalar.
-         lr_dict: DeviceDict | None = ({
-             lr.device: lr
-         } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
-                                       None)
-         grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
-             [
-                 params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
-                 state_steps
-             ]  # type: ignore[list-item]
-         )
-         for (device, _), (
-             (
-                 device_params_,
-                 device_grads_,
-                 device_exp_avgs_,
-                 device_exp_avg_sqs_,
-                 device_max_exp_avg_sqs,
-                 device_state_steps_,
-             ),
-             _,
-         ) in grouped_tensors.items():
-             device_params = cast(list[torch.Tensor], device_params_)
-             device_grads = cast(list[torch.Tensor], device_grads_)
-             device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
-             device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
-             device_state_steps = cast(list[torch.Tensor], device_state_steps_)
-
-             if lr_dict is not None and device not in lr_dict:
-                 lr_dict[device] = lr.to(
-                     device=device,
-                     non_blocking=True)  # type: ignore[union-attr]
-                 lr = lr_dict[device]
-             torch._foreach_add_(device_state_steps, 1)
-             func = torch._fused_adamw_
-             func(
-                 device_params,
-                 device_grads,
-                 device_exp_avgs,
-                 device_exp_avg_sqs,
-                 device_max_exp_avg_sqs,  # type: ignore[arg-type]
-                 device_state_steps,
-                 amsgrad=amsgrad,
-                 lr=lr,  # type: ignore[arg-type]
-                 beta1=beta1,
-                 beta2=beta2,
-                 weight_decay=weight_decay,
-                 eps=eps,
-                 maximize=maximize,
-             )
-
-     def _step_muon(self, group, qk_logits=None):
-         params = group["params"]
-         lr = group["lr"]
-         weight_decay = group["weight_decay"]
-         momentum = group["momentum"]
-         names = group["names"]
-
-         param_dtensors = []
-         name_dtensors = []
-
-         param_tensors = []
-         name_tensors = []
-
-         param_dtensors_small = []
-         name_dtensors_small = []
-
-         if self.use_distributed_muon:
-             self.distributed_muon(names=names,
-                                   params=params,
-                                   group=group,
-                                   lr=lr,
-                                   weight_decay=weight_decay,
-                                   momentum=momentum,
-                                   qk_logits=qk_logits)
-             return
-
-         # For simplicity, we use distributed Muon for small parameters
-         # whose number of elements is below a threshold.
-         for n, p in zip(names, params):
-             if p is None or p.grad is None:
-                 continue
-             if isinstance(p.data, DTensor):
-                 if all(
-                         isinstance(placement, Replicate)
-                         for placement in p.placements):
-                     param_tensors.append(p)
-                     name_tensors.append(n)
-                 elif p.data.numel() <= self.small_param_numel_threshold:
-                     param_dtensors_small.append(p)
-                     name_dtensors_small.append(n)
-                 else:
-                     param_dtensors.append(p)
-                     name_dtensors.append(n)
-             elif isinstance(p.data, torch.Tensor):
-                 param_tensors.append(p)
-                 name_tensors.append(n)
-             else:
-                 raise TypeError(f"Unsupported parameter type: {type(p.data)}")
-
-         logger.debug(
-             f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
-             f"{len(param_dtensors_small)} Small DTensors")
-
-         def group_dtensors(dtensors, names):
-             # To support different placements, we group parameters by placements
-             # and run parallel Muon on each group.
-             placement_to_params = defaultdict(lambda: ([], []))
-             # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
-
-             assert len(dtensors) == len(names)
-             for p, n in zip(dtensors, names):
-                 placement_to_params[tuple([p.placements,
-                                            p.device_mesh])][0].append(n)
-                 placement_to_params[tuple([p.placements,
-                                            p.device_mesh])][1].append(p)
-             return placement_to_params
-
-         if len(param_dtensors_small) > 0:
-             if not dist.is_initialized():
-                 raise RuntimeError(
-                     "Parallel Muon requires torch.distributed to be initialized."
-                 )
-
-             self.distributed_muon(
-                 params=param_dtensors_small,
-                 names=name_dtensors_small,
-                 group=group,
-                 lr=lr,
-                 weight_decay=weight_decay,
-                 momentum=momentum,
-                 qk_logits=qk_logits,
-             )
-
-         if len(param_dtensors) > 0:
-             if not dist.is_initialized():
-                 raise RuntimeError(
-                     "Parallel Muon requires torch.distributed to be initialized."
-                 )
-
-             dtensor_group = group_dtensors(param_dtensors, name_dtensors)
-             for _, (names, params) in dtensor_group.items():
-                 self.parallel(
-                     names,
-                     params,
-                     group,
-                     lr=lr,
-                     weight_decay=weight_decay,
-                     momentum=momentum,
-                     qk_logits=qk_logits,
-                 )
-
-         if len(param_tensors) > 0:
-             self.base(
-                 name_tensors,
-                 param_tensors,
-                 group,
-                 lr=lr,
-                 weight_decay=weight_decay,
-                 momentum=momentum,
-                 qk_logits=qk_logits,
-             )
-
-     def _step_adamw_params(self, params, group):
-         params_with_grads = []
-         grads = []
-         moment1 = []
-         moment2 = []
-         max_exp_avg_sqs = []
-         state_steps = []
-         lr = group["lr"]
-         beta1, beta2 = group["adamw_betas"]
-         eps = group["adamw_eps"]
-         weight_decay = group["weight_decay"]
-
-         for p in params:
-             g = p.grad
-             if g is None:
-                 continue
-             state = self.state[p]
-             params_with_grads.append(p)
-             grads.append(g)
-             if "step" not in state:
-                 state["step"] = (torch.zeros((),
-                                              dtype=torch.float32,
-                                              device=p.device))
-                 state["moment1"] = torch.zeros_like(g)
-                 state["moment2"] = torch.zeros_like(g)
-             moment1.append(state["moment1"])
-             moment2.append(state["moment2"])
-             if not isinstance(state["step"], torch.Tensor):
-                 step_tensor = torch.tensor(state["step"],
-                                            dtype=torch.float32,
-                                            device=p.device)
-             else:
-                 step_tensor = state["step"]
-             state_steps.append(step_tensor)
-
-         self._fused_adamw(
-             params_with_grads,
-             grads,
-             moment1,
-             moment2,
-             max_exp_avg_sqs,
-             state_steps,
-             amsgrad=False,
-             beta1=beta1,
-             beta2=beta2,
-             lr=lr,
-             weight_decay=weight_decay,
-             eps=eps,
-             maximize=False,
-         )
-
-     def _step_adamw(self, group):
-         params = group["params"]
-
-         # Group params by their type and placement
-         placement_to_params: dict[tuple[Placement | type,
-                                         DeviceMesh | None]] = defaultdict(list)
-         for p in params:
-             match p:
-                 case DTensor():
-                     placement_to_params[tuple([p.placements,
-                                                p.device_mesh])].append(p)
-                 case torch.Tensor():
-                     placement_to_params[tuple([torch.Tensor, None])].append(p)
-
-         for params in placement_to_params.values():
-             self._step_adamw_params(params, group)
-
-     @torch.no_grad
-     def step(self, closure=None, qk_logits=None):
-         """Perform a single optimization step.
-
-         Args:
-             closure (Callable, optional): A closure that reevaluates the model
-                 and returns the loss.
-             qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
-                 to 1D tensors of shape (num_heads,), representing the maximum
-                 QK logits across all tokens, computed as
-                 (1 / sqrt(head_dim)) * (Q @ K^T).
-         """
-         loss = None
-         if closure is not None:
-             with torch.enable_grad():
-                 loss = closure()
-
-         for group in self.param_groups:
-             if group["use_muon"]:
-                 self._step_muon(group, qk_logits=qk_logits)
-             else:
-                 self._step_adamw(group)
-
-         return loss
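Taken together, a minimal single-GPU sketch of driving this optimizer (no DTensor sharding, so the `base` path runs; the toy model and shapes are assumptions, and in the kernels-hub setting the class is reached via get_kernel('motif-technologies/optimizer') as the error message in __init__ shows):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8)).cuda()

    param_groups = get_default_muon_param_groups(model)  # tags 2D weights with use_muon=True
    opt = Muon(param_groups, lr=0.02, momentum=0.95)

    x = torch.randn(16, 64, device="cuda")
    loss = model(x).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()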
build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib.util
- from pathlib import Path
- from types import ModuleType
-
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is; after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique, using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
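A hedged usage sketch of the helper above (the path is hypothetical; any importable .py file works):

    from pathlib import Path

    mod = _import_from_path(Path("/tmp/example_module.py"))  # hypothetical file
    print(mod.__name__)  # hex-encoded hash of the absolute path, unique per file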
build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py DELETED
@@ -1,175 +0,0 @@
- import torch
- import torch.distributed as dist
- from torch.distributed import ProcessGroup
- from torch.distributed.device_mesh import DeviceMesh
- from torch.distributed.tensor import DTensor
- from torch.distributed.tensor.placement_types import (Placement, Shard,
-                                                       _StridedShard)
-
-
- def get_slices_of_dtensor(
-     target: DTensor | torch.Tensor,
-     local_rank: int,
-     shard_mesh: DeviceMesh,
-     shard_placements: tuple[Placement],
- ) -> tuple[slice]:
-     """
-     Get the slice of the local tensor for a given rank from a tensor.
-
-     Args:
-         target (DTensor | torch.Tensor): The target tensor.
-         local_rank (int): The local rank in the shard group.
-         shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
-         shard_placements (tuple[Placement]): The shard placements.
-     """
-
-     slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
-
-     # find the global rank of the local rank in the shard mesh
-     rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
-
-     rank_coords = (shard_mesh.mesh == rank).nonzero()
-
-     assert len(rank_coords) == 1
-     rank_coords = tuple(rank_coords[0].tolist())
-
-     assert len(rank_coords) == len(shard_placements)
-
-     # Caution: Assuming replicate-to-shard of the shard mesh goes with
-     # left-to-right sharding. This is ensured by the sorting logic of
-     # the construct_shard_mesh function.
-     for i, (rank_coord,
-             placement) in enumerate(zip(rank_coords, shard_placements)):
-         assert isinstance(placement, Shard)
-
-         num_ranks = shard_mesh.mesh.shape[i]
-
-         dim = placement.dim
-         dim_size = (slices[dim].stop - slices[dim].start)
-
-         if dim_size % num_ranks != 0:
-             raise NotImplementedError(
-                 f"Dimension size {dim_size} is not divisible "
-                 f"by number of ranks {num_ranks} for shard "
-                 f"placement on dim {dim}. (shape: {target.shape})")
-
-         shard_size = dim_size // num_ranks
-
-         start = slices[dim].start + rank_coord * shard_size
-         end = start + shard_size
-
-         assert start < end <= slices[dim].stop
-
-         slices[dim] = slice(start, end)
-
-     return tuple(slices)
-
-
- _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
-                                                   ProcessGroup]] = dict()
-
-
- def construct_shard_mesh(
-     placements: tuple[Placement],
-     mesh: DeviceMesh,
- ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement]]:
-     """
-     Construct a shard mesh and placements for unsharding.
-     It removes Replicate placements and constructs a new Mesh and ProcessGroup.
-     """
-     my_rank = dist.get_rank()
-
-     assert mesh.mesh.device.type == 'cpu'
-
-     # Copy mesh to avoid modifying the original mesh
-     mesh = mesh.mesh.clone()
-
-     # 1. Sort placements: Replicate first, then Shard by dim ascending.
-
-     # For Shard, a strided shard comes after a regular shard on the same dim
-     # to preserve the left-to-right order of replicate-to-shard.
-     # This is because a strided shard uses a stride to represent
-     # more fine-grained sharding on the same dim.
-     # Please check the URL below for _StridedShard.
-     # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
-
-     def placement_sort_key(
-         placement_with_index: tuple[int, Placement]
-     ) -> tuple[float, float, int]:  # (dim, split factor, original index)
-         index, placement = placement_with_index
-         is_replicate = placement.is_replicate()
-         is_shard = placement.is_shard()
-         is_partial = placement.is_partial()
-
-         assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
-         assert not is_partial, "Partial placement is not supported."
-
-         if is_replicate:
-             return (-1.0, 0, index)
-         elif is_shard:
-             if isinstance(placement, _StridedShard):
-                 return (placement.dim, 1 / placement.split_factor, index)
-             return (placement.dim, 0, index)
-         else:
-             raise TypeError(f"Unknown placement type: {type(placement)}")
-
-     placements_with_index: list[tuple[int,
-                                       Placement]] = list(enumerate(placements))
-     placements_with_index = sorted(placements_with_index,
-                                    key=placement_sort_key)
-
-     sorted_indices, sorted_placements = zip(*placements_with_index)
-
-     # 2. Permute mesh according to sorted placements.
-     sorted_mesh = mesh.permute(sorted_indices)
-
-     # 3. Collect the list of shard meshes by removing replicate dims.
-     # For example, a mesh of shape (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
-     # yields 2 * 3 = 6 shard meshes of shape (4, 4).
-     num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
-
-     # Merge replicate dims.
-     # shard_meshes becomes a list of shard meshes whose length is the replicate degree.
-     if num_replicates > 0:
-         sorted_mesh = sorted_mesh.flatten(
-             0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
-         shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
-     else:
-         shard_meshes = [sorted_mesh]
-     shard_placements = sorted_placements[num_replicates:]
-
-     # assume all shard placements are different
-     assert len(shard_placements) == len(set(shard_placements))
-
-     # 4. Construct ProcessGroups.
-     # Caution: all groups should be created in the same order in all processes,
-     # even though each process only needs its own group.
-
-     # To use a tensor as a dict key, convert it to a tuple.
-     def tensor_to_tuple(t):
-         if isinstance(t, torch.Tensor):
-             t = t.tolist()
-         if isinstance(t, list):
-             return tuple(tensor_to_tuple(x) for x in t)
-         return t
-
-     my_shard_mesh_as_tuple = None
-     for shard_mesh in shard_meshes:
-         assert isinstance(shard_mesh, torch.Tensor)
-         shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
-
-         if (my_rank == shard_mesh).any().item():
-             assert my_shard_mesh_as_tuple is None
-             my_shard_mesh_as_tuple = shard_mesh_as_tuple
-
-         # update the global cache
-         if shard_mesh_as_tuple not in _ranks_to_dist_cache:
-             shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
-             _ranks_to_dist_cache[shard_mesh_as_tuple] = (
-                 DeviceMesh(device_type="cuda", mesh=shard_mesh),
-                 shard_process_group,
-             )
-
-     my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
-         my_shard_mesh_as_tuple]
-
-     return my_shard_mesh, my_shard_process_group, shard_placements
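To make the slice arithmetic in get_slices_of_dtensor concrete, a hand-computed miniature (plain Python, no DeviceMesh or process groups; the shapes are made up): an 8 x 6 tensor with shard_placements (Shard(0), Shard(1)) over a 2 x 3 shard mesh gives the rank at mesh coordinates (1, 2) rows 4:8 and columns 4:6.

    shape = (8, 6)
    mesh_shape = (2, 3)    # shard mesh: 2 ranks along mesh dim 0, 3 along mesh dim 1
    rank_coords = (1, 2)   # this rank's coordinates in the shard mesh
    shard_dims = (0, 1)    # Shard(0), Shard(1)

    slices = [slice(0, s) for s in shape]
    for mesh_dim, (coord, dim) in enumerate(zip(rank_coords, shard_dims)):
        shard_size = shape[dim] // mesh_shape[mesh_dim]
        slices[dim] = slice(coord * shard_size, (coord + 1) * shard_size)

    print(slices)  # [slice(4, 8, None), slice(4, 6, None)]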
build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py DELETED
@@ -1,128 +0,0 @@
- # MIT License
- #
- # Copyright (c) 2025 Tianyang Lin
- #
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
-
- import torch
- import triton
- import triton.language as tl
-
-
- def get_autotune_config():
-     return [
-         triton.Config(
-             {
-                 'BLOCK_SIZE_M': blk_m,
-                 'BLOCK_SIZE_K': blk_k,
-                 'GROUP_SIZE_M': grp_sz
-             },
-             num_stages=n_stages,
-             num_warps=n_warps) for blk_m in [32, 64, 128]
-         for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
-         for n_warps in [4, 8]
-     ]
-
-
- @triton.autotune(
-     configs=get_autotune_config(),
-     key=['M', 'K'],
- )
- @triton.jit
- def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
-                BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-                GROUP_SIZE_M: tl.constexpr):
-     """
-     Core JIT kernel of matmul_transpose that computes y = x @ x.T.
-     The code is a simple adaptation of the Triton `matmul` tutorial:
-     https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
-     """
-     pid = tl.program_id(axis=0)
-     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-     num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
-     num_pid_in_group = GROUP_SIZE_M * num_pid_n
-     group_id = pid // num_pid_in_group
-     first_pid_m = group_id * GROUP_SIZE_M
-     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-     pid_n = (pid % num_pid_in_group) // group_size_m
-     if pid_m > pid_n:
-         return
-
-     offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-     offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-     offs_k = tl.arange(0, BLOCK_SIZE_K)
-     # we use a & b ptrs to denote different rows of x.
-     a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
-     b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
-
-     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
-
-     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-         a = tl.load(a_ptrs,
-                     mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
-                     other=0.0)
-         b = tl.load(b_ptrs,
-                     mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
-                     other=0.0)
-         accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
-         a_ptrs += BLOCK_SIZE_K * stride_xk
-         b_ptrs += BLOCK_SIZE_K * stride_xk
-     # use dtype.element_ty to accommodate different input datatypes as in cpp templates
-     # https://github.com/triton-lang/triton/issues/2252
-     c = accumulator.to(x.dtype.element_ty)
-
-     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-     offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-     c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
-     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
-     tl.store(c_ptrs, c, mask=c_mask)
-
-     # transpose and copy
-     if pid_m < pid_n:
-         ct_ptrs = y + stride_ym * offs_cn[:,
-                                           None] + stride_yn * offs_cm[None, :]
-         ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
-         tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
-
-
- def matmul_transpose_assign(d_in, d_out):
-     assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
-     assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
-     assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
-     assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
-     assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
-     assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
-     assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
-         "First dimension of `d_in` must match first and second dimension of `d_out`"
-
-     d_in = d_in.contiguous()
-     M, K = d_in.shape
-     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
-         M, META['BLOCK_SIZE_M']), )
-     with torch.cuda.device(d_in.device.index):
-         mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
-                          d_out.stride(0), d_out.stride(1))
-
-
- def matmul_transpose(d_in):
-     M, _ = d_in.shape
-     d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
-     matmul_transpose_assign(d_in, d_out)
-     return d_out
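A quick sanity check of the kernel against plain PyTorch (assumes a CUDA device with Triton available; the sizes are illustrative):

    import torch

    x = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
    y = matmul_transpose(x)  # y = x @ x.T computed by mmt_kernel
    print((y - x @ x.T).abs().max())  # small, up to bf16 rounding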
build/torch210-cxx11-cu128-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
build/torch210-cxx11-cu128-x86_64-linux/muon.py DELETED
@@ -1,1268 +0,0 @@
- import logging
- import math
- import types
- from collections import defaultdict
- from dataclasses import dataclass
- from typing import Any, cast
-
- import torch
- import torch.distributed as dist
- from torch.distributed import ProcessGroup
- from torch.distributed.device_mesh import DeviceMesh
- from torch.distributed.tensor import DTensor, Replicate
- from torch.distributed.tensor.placement_types import Placement
-
- from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
- from .matmul_transpose_triton import matmul_transpose_assign
-
- logger = logging.getLogger(__name__)
-
- COMM_DTYPE = torch.bfloat16
- DEFAULT_CHUNK_SIZE_RATIO = 4
-
-
- # This code snippet is a modified version adapted from the following GitHub repositories:
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
- # Muon's Newton-Schulz iteration causes high variance in singular values.
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
- # matmul_transpose_assign from: https://github.com/nil0x9/flash-muon
- @torch.no_grad()
- def _zeropower_via_newtonschulz5(G, steps):
-     """
-     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-     zero even beyond the point where the iteration no longer converges all the way to one everywhere
-     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-     where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-     performance at all relative to UV^T, where USV^T = G is the SVD.
-     """
-     assert len(G.shape) == 2
-     assert G.dtype == COMM_DTYPE
-     X = G  # no manual typecast
-
-     if G.size(0) > G.size(1):
-         X = X.T
-     # Ensure spectral norm is at most 1
-     X = X / (X.norm() + 1e-7)
-     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-     # Perform the NS iterations
-     for a, b, c in [
-         (4.0848, -6.8946, 2.9270),
-         (3.9505, -6.3029, 2.6377),
-         (3.7418, -5.5913, 2.3037),
-         (2.8769, -3.1427, 1.2046),
-         (2.8366, -3.0525, 1.2012),
-     ]:
-         matmul_transpose_assign(X, buf1)
-         matmul_transpose_assign(buf1, buf2)
-         buf1.mul_(b).add_(buf2, alpha=c)
-         X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
-
-     if G.size(0) > G.size(1):
-         X = X.T
-     return X
-
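A small check of the claim in the docstring (assumes a CUDA device, since matmul_transpose_assign backs the matmuls): the singular values of the output should land roughly in the [0.5, 1.5] band rather than exactly at 1.

    import torch

    G = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
    U = _zeropower_via_newtonschulz5(G, steps=5)
    S = torch.linalg.svdvals(U.float())
    print(S.min().item(), S.max().item())  # roughly within [0.5, 1.5]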
- @dataclass
- class _muon_state:
-     # TODO: use Optional
-     worker_rank: int
-     process_group: ProcessGroup
-     shard_mesh: DeviceMesh
-     shard_placements: tuple[Placement, ...]
-     name: str
-     qk_clip_state: torch.Tensor | None = None
-     gathered_grad: torch.Tensor | None = None
-     scattered_u: DTensor | None = None
-     computed_u: torch.Tensor | None = None
-     gather_event: torch.cuda.Event | None = None
-     compute_event: torch.cuda.Event | None = None
-     scatter_event: torch.cuda.Event | None = None
-
-
- def numel_for_rank(
-     param: DTensor,
-     local_rank: int,
-     state: _muon_state,
- ) -> int:
-     slices = get_slices_of_dtensor(
-         param,
-         local_rank,
-         state.shard_mesh,
-         state.shard_placements,
-     )
-
-     numel = 1
-     for s, dim in zip(slices, param.shape):
-         start, stop, step = s.indices(dim)
-         length = max(0, (stop - start + (step - 1)) // step)
-         numel *= length
-
-     return numel
-
-
- @torch.no_grad()
- def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
-     """
-     Pre-allocate the gathered_grad buffer on compute_stream
-     before launching the all2all gather.
-     """
-     with torch.cuda.stream(compute_stream):
-         for p in params:
-             state = param_to_state[id(p)]
-             if rank == state.worker_rank:
-                 state.gathered_grad = torch.empty(p.shape,
-                                                   dtype=COMM_DTYPE,
-                                                   device="cuda")
-             else:
-                 state.gathered_grad = None
-
-     alloc_event = torch.cuda.Event()
-     alloc_event.record(compute_stream)
-     return alloc_event
-
-
- @torch.no_grad()
- def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
-                     alloc_event):
-     """
-     All2all-gathers shards so each owner rank reconstructs its full gradient.
-     """
-     with torch.cuda.stream(comm_stream):
-         process_group = param_to_state[id(params[0])].process_group
-         num_ranks = dist.get_world_size(group=process_group)
-
-         # Construct sending buffers
-         per_dst = [[] for _ in range(num_ranks)]
-         send_counts = [0] * num_ranks
-
-         for p in params:
-             state = param_to_state[id(p)]
-             dst = state.worker_rank
-             assert dst < num_ranks
-             shard_elems = numel_for_rank(p, rank, state)
-             g = p.grad
-             g = g.to_local().to(COMM_DTYPE).contiguous()
-             assert g.numel() == shard_elems
-             per_dst[dst].append(g.view(-1))
-             send_counts[dst] += shard_elems
-
-         assert any(
-             len(v) > 0 for v in per_dst
-         ), "At least one destination rank must receive a sharded tensor"
-         # list[list[Tensor]] -> list[Tensor]
-         per_dst = [t for dst in per_dst for t in dst]
-
-         send_buf = torch.cat(per_dst, dim=0)
-
-         owned_params = [
-             p for p in params if param_to_state[id(p)].worker_rank == rank
-         ]
-
-         # Compute receive sizes and allocate receiving buffers
-         recv_counts = [0] * num_ranks
-
-         for src in range(num_ranks):
-             total = 0
-             for p in owned_params:
-                 state = param_to_state[id(p)]
-                 assert state.worker_rank == rank
-                 total += numel_for_rank(p, src, state)
-             recv_counts[src] = total
-
-         recv_total = sum(recv_counts)
-         recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-         # All2All
-         logger.debug(f"send_buf size: {send_buf.numel()}, "
-                      f"recv_buf size: {recv_buf.numel()}, "
-                      f"recv_counts: {recv_counts}, "
-                      f"send_counts: {send_counts}, "
-                      f"process_group: {str(process_group)}")
-         dist.all_to_all_single(
-             recv_buf,
-             send_buf,
-             output_split_sizes=recv_counts,
-             input_split_sizes=send_counts,
-             group=process_group,
-         )
-
-         # Reconstruct the gathered grad from the received buffer
-         #
-         # recv_buf (num ranks = 3)
-         #
-         #   From rank 0          From rank 1          From rank 2
-         # | p1_0, p2_0, p3_0 |  p1_1, p2_1, p3_1  |  p1_2, p2_2, p3_2 |
-         #
-         # Outer loop:
-         #   rank 0 -> rank 1 -> rank 2
-         #
-         # Inner loop:
-         #   p1_n -> p2_n -> p3_n
-
-         comm_stream.wait_event(alloc_event)
-
-         off = 0
-         for src in range(num_ranks):
-             if recv_counts[src] == 0:
-                 continue
-
-             block = recv_counts[src]
-             inner_off = 0
-             for p in owned_params:
-                 state = param_to_state[id(p)]
-                 assert state.worker_rank == rank
-
-                 # get the slice of the full dtensor corresponding to rank src.
-                 slices = get_slices_of_dtensor(state.gathered_grad, src,
-                                                state.shard_mesh,
-                                                state.shard_placements)
-
-                 dst = state.gathered_grad[slices]
-                 assert dst._base is state.gathered_grad
-
-                 n = dst.numel()
-                 assert n > 0
-
-                 sg = recv_buf.narrow(0, off + inner_off, n)
-                 sg = sg.reshape_as(dst)
-                 dst.copy_(sg)
-
-                 inner_off += n
-             off += block
-
-         for p in params:
-             state = param_to_state[id(p)]
-             if state.worker_rank == rank:
-                 state.gather_event = torch.cuda.Event()
-                 state.gather_event.record(comm_stream)
-             else:
-                 state.gathered_grad = None
-                 state.gather_event = None
-             if none_grad:
-                 p.grad = None
-
-
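The bookkeeping above follows the flattened-buffer contract of torch.distributed.all_to_all_single: input_split_sizes says how much of send_buf goes to each rank, and output_split_sizes how much of recv_buf arrives from each rank. A toy two-rank tally with made-up parameter sizes (not from the original file):

    # Rank 0 owns p1 (8 elements when gathered); rank 1 owns p2 (4 elements).
    # Both parameters are sharded evenly across the two ranks.
    # As seen from rank 0:
    send_counts = [4, 2]  # p1's local shard stays with rank 0; p2's local shard goes to rank 1
    recv_counts = [4, 4]  # p1's two shards arrive, one from each rank
    assert sum(send_counts) == 4 + 2 and sum(recv_counts) == 8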
- @torch.no_grad()
- def _compute_u(p, state, steps, rank, compute_stream):
-     """
-     On worker_rank, compute the orthogonalized update using the Newton-Schulz iteration.
-     """
-     with torch.cuda.stream(compute_stream):
-         if rank == state.worker_rank:
-             if state.gather_event is None:
-                 raise RuntimeError("Gather event must be set before compute.")
-             compute_stream.wait_event(state.gather_event)
-             u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
-             state.gathered_grad = None
-             state.computed_u = u
-             state.compute_event = torch.cuda.Event()
-             state.compute_event.record()
-         else:
-             state.computed_u = None
-             state.compute_event = None
-
-
- @torch.no_grad()
- def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
-     """
-     Pre-allocate the scattered_u buffer on compute_stream
-     before launching the all2all scatter.
-     """
-     with torch.cuda.stream(compute_stream):
-         for p in params:
-             state = param_to_state[id(p)]
-             state.scattered_u = torch.empty_like(p.to_local(),
-                                                  dtype=COMM_DTYPE)
-
-     alloc_event = torch.cuda.Event()
-     alloc_event.record(compute_stream)
-     return alloc_event
-
-
- def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
-     """
-     All2all-scatters the computed full updates back to each rank's shards.
-     """
-     with torch.cuda.stream(comm_stream):
-         process_group = param_to_state[id(params[0])].process_group
-         num_ranks = dist.get_world_size(group=process_group)
-         owned_params = [
-             p for p in params if param_to_state[id(p)].worker_rank == rank
-         ]
-
-         # Construct sending buffer
-         per_dst = [[] for _ in range(num_ranks)]
-         send_counts = [0] * num_ranks
-
-         if owned_params:
-             for p in owned_params:
-                 state = param_to_state[id(p)]
-                 if state.compute_event is None:
-                     raise RuntimeError(
-                         "Compute event must be set before scatter.")
-                 comm_stream.wait_event(state.compute_event)
-                 state.gathered_grad = None
-
-                 assert state.computed_u is not None
-
-                 u_full = state.computed_u.to(COMM_DTYPE).contiguous()
-
-                 offset = 0
-                 for dst in range(num_ranks):
-                     # get the slice of the full tensor corresponding to rank dst.
-                     slices = get_slices_of_dtensor(u_full, dst,
-                                                    state.shard_mesh,
-                                                    state.shard_placements)
-                     su = u_full[slices].flatten()
-
-                     n = su.numel()
-                     assert n > 0
-
-                     per_dst[dst].append(su)
-                     send_counts[dst] += n
-                     offset += n
-
-                 assert offset == u_full.numel()
-
-         lengths = [len(v) for v in per_dst]
-         if all(l > 0 for l in lengths):
-             assert all(
-                 l == lengths[0] for l in lengths
-             ), "All destination ranks must have the same number of sharded tensors"
-             # list[list[Tensor]] -> list[Tensor]
-             per_dst = [t for dst in per_dst for t in dst]
-             send_buf = torch.cat(per_dst, dim=0)
-         else:
-             # all_to_all requires participation from all ranks.
-             # Even non-owner ranks must join the collective call.
-             send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
-
-         # Compute receive sizes and allocate receiving buffers
-         recv_counts = [0] * num_ranks
-
-         for src in range(num_ranks):
-             total = 0
-             for p in params:
-                 state = param_to_state[id(p)]
-                 if state.worker_rank != src:
-                     continue
-                 total += numel_for_rank(p, rank, state)
-             recv_counts[src] = total
-
-         recv_total = sum(recv_counts)
-         assert recv_total > 0
-         recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-         # All2All
-         dist.all_to_all_single(
-             recv_buf,
-             send_buf,
-             output_split_sizes=recv_counts,
-             input_split_sizes=send_counts,
-             group=process_group,
-         )
-
-         # Copy to the pre-allocated scattered_u buffers from the received buffer
-         #
-         # recv_buf (num ranks = 3, local_rank = 0)
-         #
-         #   From rank 0          From rank 1   From rank 2
-         # | p1_0, p2_0, p3_0 |  p4_0        |  p5_0, p6_0 |
-         #
-         # Outer loop:
-         #   rank 0 -> rank 1 -> rank 2
-         #
-         # Inner loop:
-         #   src(0): p1_0 -> p2_0 -> p3_0
-         #   src(1): p4_0
-         #   src(2): p5_0 -> p6_0
-
-         comm_stream.wait_event(alloc_event)
-
-         off = 0
-         for src in range(num_ranks):
-             block = recv_counts[src]
-             if block == 0:
-                 continue
-
-             inner_off = 0
-             for p in params:
-                 state = param_to_state[id(p)]
-                 if state.worker_rank != src:
-                     continue
-                 n = numel_for_rank(p, rank, state)
-                 assert n > 0
-
-                 flat_local = recv_buf.narrow(0, off + inner_off,
-                                              n).view_as(p.to_local())
-                 state.scattered_u.copy_(flat_local)
-
-                 state.scatter_event = torch.cuda.Event()
-                 state.scatter_event.record(comm_stream)
-                 inner_off += n
-
-             assert inner_off == block
-             off += block
-
-
- def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
-                   compute_stream):
-     """
-     Update sharded parameter p with the scattered_u.
-     Only worker_rank frees computed_u.
-     """
-     with torch.cuda.stream(compute_stream):
-         if state.scatter_event is None:
-             raise RuntimeError("Scatter event must be set before update")
-         compute_stream.wait_event(state.scatter_event)
-         u_dtensor = DTensor.from_local(
-             state.scattered_u,
-             placements=p.placements,
-             device_mesh=p.device_mesh,
-         )
-
-         state.scattered_u = u_dtensor
-
-         if rank == state.worker_rank:
-             # Free computed_u
-             state.computed_u = None
-
-         Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
-         state.scattered_u = None
-         u_dtensor = None
-
-         scales_full = Muon._compute_scales(
-             p,
-             state.qk_clip_state) if state.qk_clip_state is not None else None
-         if scales_full is not None:
-             # Have to slice scales_full along dim 0
-             weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
-                                                   state.shard_placements)
-             ratio = p.shape[0] // scales_full.shape[0]
-             scales_slice = slice(
-                 None if weight_slices[0].start is None else
-                 weight_slices[0].start // ratio,
-                 None if weight_slices[0].stop is None else
-                 weight_slices[0].stop // ratio,
-                 None,
-             )
-
-             scales_local = scales_full[scales_slice]
-             scales_local = DTensor.from_local(
-                 scales_local,
-                 placements=p.placements,
-                 device_mesh=p.device_mesh,
-             )
-             Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
-
-
- def default_is_muon(name, x):
-     skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
-     return x.ndim >= 2 and not any(key in name for key in skip_keys)
-
-
- def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
-     muon_params, muon_names = [], []
-     non_muon_params = []
-
-     for n, p in model.named_parameters():
-         if not p.requires_grad:
-             continue
-         if is_muon_func(n, p):
-             muon_params.append(p)
-             muon_names.append(n)
-         else:
-             non_muon_params.append(p)
-
-     return [
-         {
-             "params": muon_params,
-             "names": muon_names,
-             "use_muon": True,
-         },
-         {
-             "params": non_muon_params,
-             "use_muon": False,
-         },
-     ]
-
-
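A short sketch of customizing the split above (the `model` variable and the name filters are assumptions):

    def my_is_muon(name, p):
        # Keep embeddings and the LM head on AdamW; Muon only orthogonalizes matrices.
        return p.ndim >= 2 and "embed" not in name and "lm_head" not in name

    params = get_default_muon_param_groups(model, is_muon_func=my_is_muon)
    optim = Muon(params, lr=0.02)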
- def parse_qk_layer(name: str) -> tuple[str | None, int]:
-     """
-     Parse a parameter name to check if it is a query/key projection layer
-     ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
-
-     Returns:
-         (kind, layer_idx) or (None, -1) if not matched.
-
-     Example:
-         'model.3.attn.wq.weight' -> ('wq', 3)
-         'model.5.attn.wk.weight' -> ('wk', 5)
-         'model.2.attn.q_proj.weight' -> ('q_proj', 2)
-         'model.7.attn.k_proj.weight' -> ('k_proj', 7)
-         'model.4.attn.v_proj.weight' -> (None, -1)
-     """
-     parts = name.split('.')
-     if len(parts) < 3:
-         return None, -1
-
-     kind = parts[-2]
-
-     layer_idx = -1
-     for part in reversed(parts):
-         if part.isdigit():
-             layer_idx = int(part)
-             break
-
-     if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
-         return kind, layer_idx
-
-     return None, -1
-
-
- @dataclass
- class QKClipInfo:
-     """Per-parameter dynamic info computed from config + runtime logits."""
-     kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
-     indices: list[int]  # which heads to consider for clipping
-     head_dim: int  # from config
-     threshold: float  # from config
-     logit: torch.Tensor | None
-
-
- class Muon(torch.optim.Optimizer):
-     """
-     Muon - MomentUm Orthogonalized by Newton-schulz
-
-     Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
-     processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
-     matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
-     the advantage that it can be stably run in bfloat16 on the GPU.
-
-     Some warnings:
-     - We believe this optimizer is unlikely to work well for training with small batch size.
-     - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
-
-     Arguments:
-         params: Parameter groups to optimize. Use get_default_muon_param_groups(model)
-             (optionally with a custom is_muon_func) to construct them.
-         lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
-         momentum: The momentum used by the internal SGD. (0.95 is a good default)
-         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
-         ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
-         weight_decay: The weight decay for Muon and AdamW. Parameters that are
-             {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
-         adamw_betas: The betas for the internal AdamW.
-         adamw_eps: The epsilon for the internal AdamW.
-         none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
-         debug: Whether to print debug information.
-         clip_config: Configuration for QK clipping. Expected keys:
-             - "q_indices" (list[int]): Indices of query heads to consider.
-             - "k_indices" (list[int]): Indices of key heads to consider.
-             - "head_dim" (int): Dimensionality of each attention head.
-             - "threshold" (float): Threshold value; heads whose QK logits exceed
-               this value will be scaled down.
-             Default is:
-                 {
-                     "q_indices": [],
-                     "k_indices": [],
-                     "head_dim": 128,
-                     "threshold": 100
-                 }
-         warmup_step: How many all2all gather/compute operations are launched in advance
-             before the corresponding all2all scatter steps begin.
-             A higher warmup_step increases memory usage but can improve
-             performance by overlapping communication.
-             Parallel Muon only.
-         chunk_size: Batch size of parameters to process in each
-             all2all gather/compute/scatter step.
-             Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
-         use_distributed_muon: Use distributed Muon by Liu et al. (2024).
-             For testing purposes only.
-         small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon.
-     """
-
-     def __init__(self,
-                  params,
-                  lr=1e-3,
-                  momentum=0.95,
-                  nesterov=True,
-                  ns_steps=5,
-                  weight_decay=0.1,
-                  adamw_betas=(0.9, 0.95),
-                  adamw_eps=1e-8,
-                  none_grad=True,
-                  debug=False,
-                  clip_config={
-                      "q_indices": [],
-                      "k_indices": [],
-                      "head_dim": 128,
-                      "threshold": 100
-                  },
-                  warmup_step=5,
-                  chunk_size=-1,
-                  use_distributed_muon=False,
-                  small_param_numel_threshold=65536):
-         defaults = dict(
-             lr=lr,
-             weight_decay=weight_decay,
-             momentum=momentum,
-             nesterov=nesterov,
-             ns_steps=ns_steps,
-             adamw_betas=adamw_betas,
-             adamw_eps=adamw_eps,
-             none_grad=none_grad,
-             use_muon=True,
-         )
-         error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
-         instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
-
-         if isinstance(params, types.GeneratorType):
-             raise ValueError(error_message.format(idx=0) + instruction_code)
-         for _idx, param_group in enumerate(params):
-             if param_group.get("use_muon", None) is None:
-                 raise ValueError(
-                     error_message.format(idx=_idx) + instruction_code)
-
-         super().__init__(params, defaults)
-
-         self.rank = None
-
-         self.comm_stream = torch.cuda.Stream()
-         self.compute_stream = torch.cuda.Stream()
-         self.debug = debug
-         self.clip_config = clip_config
-         self.warmup_step = warmup_step
-         self.chunk_size = chunk_size
-         self.use_distributed_muon = use_distributed_muon
-         self.small_param_numel_threshold = small_param_numel_threshold
-
-     def _calc_flops(self, G, steps):
-         assert len(G.shape) == 2
-         M, N = G.shape
-         if M > N:
-             M, N = N, M
-
-         return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
-
-     def adjust_lr_for_muon(self, lr, param_shape):
-         A, B = param_shape[:2]
-         # We adjust the learning rate and weight decay based on the size of the
-         # parameter matrix, as described in the paper
-         adjusted_ratio = 0.2 * math.sqrt(max(A, B))
-         adjusted_lr = lr * adjusted_ratio
-         return adjusted_lr
-
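Reading adjust_lr_for_muon with concrete numbers (the shape is illustrative): for a 1024 x 4096 weight the multiplier is 0.2 * sqrt(4096) = 12.8, so lr = 0.02 steps with an effective spectral step size of 0.256.

    import math

    lr, (A, B) = 0.02, (1024, 4096)
    adjusted_lr = lr * 0.2 * math.sqrt(max(A, B))
    print(adjusted_lr)  # 0.256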
-     def set_rank_once(self, rank):
-         if self.rank is None:
-             self.rank = rank
-         else:
-             assert self.rank == rank
-
-     def get_shard_mesh(self, p):
-         """
-         Get the shard mesh for a parameter p on the given rank.
-         """
-         assert isinstance(
-             p, DTensor), "Parallel Muon only supports DTensor parameters."
-
-         shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
-             p.placements, p.device_mesh)
-
-         # set rank with the local rank in the shard process group
-         self.set_rank_once(dist.get_rank(group=shard_pg))
-
-         return shard_mesh, shard_pg, shard_placements
-
- def init_state_and_assign_params(self, names, params, group, qk_logits):
682
- param_to_state = {}
683
- param_to_flops = {}
684
-
685
- total_flops = 0
686
- for p in params:
687
- g = p.grad
688
- if g is None:
689
- continue
690
- assert g.ndim == 2, "Muon only supports 2D parameters."
691
-
692
- flops = self._calc_flops(g, group["ns_steps"])
693
- param_to_flops[id(p)] = flops
694
- total_flops += flops
695
-
696
- if self.debug:
697
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
698
- flush=True)
699
-
700
- paired = list(zip(names, params))
701
-
702
- paired_sorted = sorted(paired,
703
- key=lambda x: param_to_flops[id(x[1])],
704
- reverse=True)
705
-
706
- names_sorted, params_sorted = zip(*paired_sorted)
707
- ordered_names = list(names_sorted)
708
- ordered_params = list(params_sorted)
709
-
710
- round_robin = 0
711
- mesh = ordered_params[0].device_mesh
712
- placements = ordered_params[0].placements
713
-
714
- shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
715
- ordered_params[0])
716
- shard_mesh_flattened = shard_mesh.mesh.flatten()
717
- num_ranks = dist.get_world_size(group=shard_pg)
718
-
719
- for n, p in zip(ordered_names, ordered_params):
720
- if mesh != p.device_mesh:
721
- raise ValueError("All parameters must be on the same mesh.")
722
- if placements != p.placements:
723
- raise ValueError("All parameters must have same placements.")
724
-
725
- worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
726
- round_robin = (round_robin + 1) % len(shard_mesh_flattened)
727
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
728
-
729
- param_to_state[id(p)] = _muon_state(
730
- worker_rank=worker_rank,
731
- process_group=shard_pg,
732
- shard_mesh=shard_mesh,
733
- shard_placements=shard_placements,
734
- name=n,
735
- qk_clip_state=qk_clip_state,
736
- )
737
-
738
- return param_to_state, ordered_params
739
-
740
- def base(self, names, params, group, lr, weight_decay, momentum,
741
- qk_logits):
742
- # generate weight updates in distributed fashion
743
- for n, p in zip(names, params):
744
- g = p.grad
745
- if g is None:
746
- continue
747
- if g.ndim > 2:
748
- g = g.view(g.size(0), -1)
749
- assert g is not None
750
-
751
- g = self._update_g(p, g, group, momentum)
752
-
753
- u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
754
- steps=group["ns_steps"])
755
-
756
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
757
- Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
758
-
759
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
760
-
761
- scales_full = self._compute_scales(
762
- p, qk_clip_state) if qk_clip_state is not None else None
763
- if scales_full is not None:
764
- Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
765
-
766
- def distributed_muon(
767
- self,
768
- names: list[str],
769
- params: list[torch.nn.Parameter],
770
- group: dict[str, Any],
771
- lr: float,
772
- weight_decay: float,
773
- momentum: float,
774
- qk_logits: list[torch.Tensor | DTensor] | None,
775
- ):
776
- """ Implementation of Distributed Muon by Liu et al. """
777
-
778
- for n, p in zip(names, params):
779
- g = p.grad
780
- if g is None:
781
- continue
782
- if g.ndim > 2:
783
- g = g.view(g.size(0), -1)
784
- assert g is not None
785
-
786
- g = self._update_g(p, g, group, momentum)
787
-
788
- # Gather G
789
- if isinstance(p.data, DTensor):
790
- g_full = g.full_tensor()
791
- p_full = p.data.full_tensor()
792
- else:
793
- g_full = g
794
- p_full = p
795
-
796
- u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
797
- steps=group["ns_steps"])
798
-
799
- adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
800
- Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
801
-
802
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
803
-
804
- scales_full = self._compute_scales(
805
- p_full, qk_clip_state) if qk_clip_state is not None else None
806
-
807
- if scales_full is not None:
808
- Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
809
-
810
- if isinstance(p.data, DTensor):
811
- ndims = len(p.device_mesh.mesh.shape)
812
- p_replicate = DTensor.from_local(
813
- p_full,
814
- device_mesh=p.device_mesh,
815
- placements=[Replicate() for _ in range(ndims)],
816
- )
817
-
818
- p_sharded = p_replicate.redistribute(
819
- device_mesh=p.device_mesh,
820
- placements=p.placements,
821
- )
822
-
823
- p.copy_(p_sharded)
824
-
825
- def _update_g(self, p, g, group, momentum):
826
- # calc update
827
- state = self.state[p]
828
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
829
- torch.add(g, buf, alpha=momentum, out=buf)
830
- if group["nesterov"]:
831
- g.add_(buf, alpha=momentum)
832
- return g
833
- return buf
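For reference, `_update_g` is standard SGD momentum with an optional Nesterov step: the buffer update is buf ← g + momentum·buf, and the returned direction is g + momentum·buf under Nesterov, else buf. A minimal self-contained sketch of one step (made-up values, independent of the optimizer state dict):

import torch

momentum = 0.95
g = torch.tensor([1.0, -2.0])               # current gradient
buf = torch.zeros_like(g)                   # fresh momentum buffer

torch.add(g, buf, alpha=momentum, out=buf)  # buf = g + momentum * buf
nesterov_dir = g + momentum * buf           # direction when nesterov=True
# on the first step buf == g, so nesterov_dir == 1.95 * g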
-
-     @staticmethod
-     def _update_p(p, u, lr, adjusted_lr, weight_decay):
-         if isinstance(p, torch.nn.Parameter):
-             # apply weight decay
-             p.data.mul_(1 - lr * weight_decay)
-             # apply update
-             p.data.add_(u, alpha=-adjusted_lr)
-         else:
-             p.mul_(1 - lr * weight_decay)
-             p.add_(u, alpha=-adjusted_lr)
-
-     def get_qk_clip_info(self, n, qk_logits):
-         if self.clip_config is None:
-             return None
-
-         head_dim = self.clip_config.get('head_dim')
-         threshold = self.clip_config.get('threshold')
-         kind, layer_idx = parse_qk_layer(n)
-
-         logit, indices = None, []
-         if qk_logits is not None and kind is not None:
-             logit = qk_logits[layer_idx]
-             indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-             indices = self.clip_config.get(indices_key, []) or []
-
-         if isinstance(logit, DTensor):
-             # In TP settings, qk_logits may be DTensor
-             # We convert it to full tensor here for simplicity
-             logit = logit.full_tensor()
-
-         return QKClipInfo(
-             kind=kind,
-             indices=indices,
-             head_dim=head_dim,
-             threshold=threshold,
-             logit=logit,
-         )
-
-     @staticmethod
-     def _compute_scales(p, qk_clip_state):
-         kind = qk_clip_state.kind
-         indices = qk_clip_state.indices
-         head_dim = qk_clip_state.head_dim
-         threshold = qk_clip_state.threshold
-         logit = qk_clip_state.logit
-
-         H_global = p.shape[0] // head_dim
-         scales_full = torch.ones(H_global, device=p.data.device)
-         scaling = 0
-
-         for logit_idx, head_idx in enumerate(indices):
-             v_ele = float(logit[logit_idx])
-             if v_ele > threshold:
-                 new_scale = math.sqrt(threshold / v_ele)
-                 if new_scale < scales_full[head_idx]:
-                     scales_full[head_idx] = new_scale
-                     logger.info(
-                         f"[{kind}] Head {head_idx} exceeded threshold "
-                         f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
-                     )
-                     scaling += 1
-
-         return scales_full if scaling > 0 else None
-
-     @staticmethod
-     def _qk_clip(p, scales, head_dim):
-         if isinstance(p, torch.nn.Parameter):
-             W = p.data.view(-1, head_dim, p.data.shape[1])
-             W.mul_(scales.view(-1, 1, 1))
-         else:
-             W = p.view(-1, head_dim, p.shape[1])
-             W.mul_(scales.view(-1, 1, 1))
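The per-head scaling in `_qk_clip` reshapes the projection weight to (heads, head_dim, in_features) and multiplies each head block by its scale. A minimal sketch of the reshape semantics with made-up sizes (2 heads, head_dim 4, hidden 8):

import torch

head_dim, hidden = 4, 8
W = torch.ones(2 * head_dim, hidden)   # stacked q/k projection weight
scales = torch.tensor([1.0, 0.5])      # per-head scales from _compute_scales

W.view(-1, head_dim, hidden).mul_(scales.view(-1, 1, 1))
# rows of head 0 are unchanged, rows of head 1 are halved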
-
-     def parallel(self, names, params, group, lr, weight_decay, momentum,
-                  qk_logits):
-         """
-         Perform a parallel optimization step using Muon.
-         """
-
-         for p in params:
-             g = p.grad
-             if g is None:
-                 continue
-             if g.ndim > 2:
-                 g = g.view(g.size(0), -1)
-
-             # Update g in the local rank
-             g = self._update_g(
-                 p,
-                 g,
-                 group,
-                 momentum=momentum,
-             )
-             p.grad = g
-
-         param_to_state, ordered_params = self.init_state_and_assign_params(
-             names, params, group, qk_logits)
-
-         assert self.rank is not None
-
-         def enqueue_all2all_gather(start_idx, chunk_size):
-             target_params = ordered_params[start_idx:start_idx + chunk_size]
-             if target_params:
-                 alloc_event = _alloc_gathered_grad(target_params,
-                                                    param_to_state, self.rank,
-                                                    self.compute_stream)
-                 _all2all_gather(target_params, param_to_state, self.rank,
-                                 self.comm_stream, group["none_grad"],
-                                 alloc_event)
-
-         def enqueue_computes(start_idx, chunk_size):
-             for p in ordered_params[start_idx:start_idx + chunk_size]:
-                 state = param_to_state[id(p)]
-                 _compute_u(p, state, group["ns_steps"], self.rank,
-                            self.compute_stream)
-
-         def enqueue_all2all_scatter(start_idx, chunk_size):
-             target_params = ordered_params[start_idx:start_idx + chunk_size]
-             if target_params:
-                 alloc_event = _alloc_scattered_u(target_params,
-                                                  param_to_state, self.rank,
-                                                  self.compute_stream)
-                 _all2all_scatter(target_params, param_to_state, self.rank,
-                                  self.comm_stream, alloc_event)
-
-         def enqueue_update_param(start_idx, chunk_size):
-             for p in ordered_params[start_idx:start_idx + chunk_size]:
-                 state = param_to_state[id(p)]
-                 adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-                 _update_param(p, state, lr, adjusted_lr, weight_decay,
-                               self.rank, self.compute_stream)
-
-         if self.chunk_size == -1:
-             shard_ranks = dist.get_world_size(param_to_state[id(
-                 params[0])].process_group)
-             chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
-         elif self.chunk_size > 0:
-             chunk_size = self.chunk_size
-         else:
-             raise ValueError("chunk_size must be -1 or a positive integer.")
-
-         # Wait for the gradient update
-         self.comm_stream.wait_stream(torch.cuda.current_stream())
-
-         warmup_step = self.warmup_step
-         for i in range(0, warmup_step):
-             enqueue_all2all_gather(i * chunk_size, chunk_size)
-             enqueue_computes(i * chunk_size, chunk_size)
-
-         for i in range(0, len(params) + chunk_size - 1, chunk_size):
-             enqueue_all2all_scatter(i, chunk_size)
-             enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
-             enqueue_update_param(i, chunk_size)
-             enqueue_computes(i + warmup_step * chunk_size, chunk_size)
-
-         # Wait for the last update_param to finish
-         torch.cuda.current_stream().wait_stream(self.compute_stream)
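The warmup and steady-state loops above form a software pipeline: `warmup_step` chunks are gathered and orthogonalized ahead of time, then each iteration scatters and applies chunk i while prefetching chunk i + warmup_step, overlapping communication with compute. A plain-Python sketch of the schedule (chunk indices only, no CUDA streams; 4 chunks, warmup of 2 are assumed values):

num_chunks, warmup = 4, 2
schedule = []
for i in range(warmup):
    schedule.append(f"gather+compute chunk {i}")
for i in range(num_chunks):
    step = f"scatter+update chunk {i}"
    if i + warmup < num_chunks:
        step += f", gather+compute chunk {i + warmup}"
    schedule.append(step)
print("\n".join(schedule))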
-
-     @staticmethod
-     def _fused_adamw(
-         params: list[torch.Tensor],
-         grads: list[torch.Tensor],
-         exp_avgs: list[torch.Tensor],
-         exp_avg_sqs: list[torch.Tensor],
-         max_exp_avg_sqs: list[torch.Tensor],
-         state_steps: list[torch.Tensor],
-         amsgrad: bool,
-         beta1: float,
-         beta2: float,
-         lr: float | torch.Tensor,
-         weight_decay: float,
-         eps: float,
-         maximize: bool,
-     ) -> None:
-         if not params:
-             return
-
-         # We only shuffle around the lr when it is a Tensor on CUDA; otherwise,
-         # we prefer treating it as a scalar.
-         lr_dict: DeviceDict | None = ({
-             lr.device: lr
-         } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
-                                       None)
-         grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
-             [
-                 params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
-                 state_steps
-             ]  # type: ignore[list-item]
-         )
-         for (device, _), (
-             (
-                 device_params_,
-                 device_grads_,
-                 device_exp_avgs_,
-                 device_exp_avg_sqs_,
-                 device_max_exp_avg_sqs,
-                 device_state_steps_,
-             ),
-             _,
-         ) in grouped_tensors.items():
-             device_params = cast(list[torch.Tensor], device_params_)
-             device_grads = cast(list[torch.Tensor], device_grads_)
-             device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
-             device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
-             device_state_steps = cast(list[torch.Tensor], device_state_steps_)
-
-             if lr_dict is not None and device not in lr_dict:
-                 lr_dict[device] = lr.to(
-                     device=device,
-                     non_blocking=True)  # type: ignore[union-attr]
-                 lr = lr_dict[device]
-             torch._foreach_add_(device_state_steps, 1)
-             func = torch._fused_adamw_
-             func(
-                 device_params,
-                 device_grads,
-                 device_exp_avgs,
-                 device_exp_avg_sqs,
-                 device_max_exp_avg_sqs,  # type: ignore[arg-type]
-                 device_state_steps,
-                 amsgrad=amsgrad,
-                 lr=lr,  # type: ignore[arg-type]
-                 beta1=beta1,
-                 beta2=beta2,
-                 weight_decay=weight_decay,
-                 eps=eps,
-                 maximize=maximize,
-             )
-
-     def _step_muon(self, group, qk_logits=None):
-         params = group["params"]
-         lr = group["lr"]
-         weight_decay = group["weight_decay"]
-         momentum = group["momentum"]
-         names = group["names"]
-
-         param_dtensors = []
-         name_dtensors = []
-
-         param_tensors = []
-         name_tensors = []
-
-         param_dtensors_small = []
-         name_dtensors_small = []
-
-         if self.use_distributed_muon:
-             self.distributed_muon(names=names,
-                                   params=params,
-                                   group=group,
-                                   lr=lr,
-                                   weight_decay=weight_decay,
-                                   momentum=momentum,
-                                   qk_logits=qk_logits)
-             return
-
-         # For simplicity, we use distributed Muon for small parameters
-         # whose number of elements is below a threshold.
-         for n, p in zip(names, params):
-             if p is None or p.grad is None:
-                 continue
-             if isinstance(p.data, DTensor):
-                 if all(
-                         isinstance(placement, Replicate)
-                         for placement in p.placements):
-                     param_tensors.append(p)
-                     name_tensors.append(n)
-                 elif p.data.numel() <= self.small_param_numel_threshold:
-                     param_dtensors_small.append(p)
-                     name_dtensors_small.append(n)
-                 else:
-                     param_dtensors.append(p)
-                     name_dtensors.append(n)
-             elif isinstance(p.data, torch.Tensor):
-                 param_tensors.append(p)
-                 name_tensors.append(n)
-             else:
-                 raise TypeError(f"Unsupported parameter type: {type(p.data)}")
-
-         logger.debug(
-             f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
-             f"{len(param_dtensors_small)} Small DTensors")
-
-         def group_dtensors(dtensors, names):
-             # To support different placements, we group parameters by placements
-             # and run parallel Muon on each group.
-             placement_to_params = defaultdict(lambda: ([], []))
-             # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
-
-             assert len(dtensors) == len(names)
-             for p, n in zip(dtensors, names):
-                 placement_to_params[tuple([p.placements,
-                                            p.device_mesh])][0].append(n)
-                 placement_to_params[tuple([p.placements,
-                                            p.device_mesh])][1].append(p)
-             return placement_to_params
-
-         if len(param_dtensors_small) > 0:
-             if not dist.is_initialized():
-                 raise RuntimeError(
-                     "Parallel Muon requires torch.distributed to be initialized."
-                 )
-
-             self.distributed_muon(
-                 params=param_dtensors_small,
-                 names=name_dtensors_small,
-                 group=group,
-                 lr=lr,
-                 weight_decay=weight_decay,
-                 momentum=momentum,
-                 qk_logits=qk_logits,
-             )
-
-         if len(param_dtensors) > 0:
-             if not dist.is_initialized():
-                 raise RuntimeError(
-                     "Parallel Muon requires torch.distributed to be initialized."
-                 )
-
-             dtensor_group = group_dtensors(param_dtensors, name_dtensors)
-             for _, (names, params) in dtensor_group.items():
-                 self.parallel(
-                     names,
-                     params,
-                     group,
-                     lr=lr,
-                     weight_decay=weight_decay,
-                     momentum=momentum,
-                     qk_logits=qk_logits,
-                 )
-
-         if len(param_tensors) > 0:
-             self.base(
-                 name_tensors,
-                 param_tensors,
-                 group,
-                 lr=lr,
-                 weight_decay=weight_decay,
-                 momentum=momentum,
-                 qk_logits=qk_logits,
-             )
-
-     def _step_adamw_params(self, params, group):
-         params_with_grads = []
-         grads = []
-         moment1 = []
-         moment2 = []
-         max_exp_avg_sqs = []
-         state_steps = []
-         lr = group["lr"]
-         beta1, beta2 = group["adamw_betas"]
-         eps = group["adamw_eps"]
-         weight_decay = group["weight_decay"]
-
-         for p in params:
-             g = p.grad
-             if g is None:
-                 continue
-             state = self.state[p]
-             params_with_grads.append(p)
-             grads.append(g)
-             if "step" not in state:
-                 state["step"] = (torch.zeros((),
-                                              dtype=torch.float32,
-                                              device=p.device))
-                 state["moment1"] = torch.zeros_like(g)
-                 state["moment2"] = torch.zeros_like(g)
-             moment1.append(state["moment1"])
-             moment2.append(state["moment2"])
-             if not isinstance(state["step"], torch.Tensor):
-                 step_tensor = torch.tensor(state["step"],
-                                            dtype=torch.float32,
-                                            device=p.device)
-             else:
-                 step_tensor = state["step"]
-             state_steps.append(step_tensor)
-
-         self._fused_adamw(
-             params_with_grads,
-             grads,
-             moment1,
-             moment2,
-             max_exp_avg_sqs,
-             state_steps,
-             amsgrad=False,
-             beta1=beta1,
-             beta2=beta2,
-             lr=lr,
-             weight_decay=weight_decay,
-             eps=eps,
-             maximize=False,
-         )
-
-     def _step_adamw(self, group):
-         params = group["params"]
-
-         # group params by their type and placement
-         placement_to_params: dict[tuple[Placement | type, DeviceMesh | None],
-                                   list] = defaultdict(list)
-         for p in params:
-             match p:
-                 case DTensor():
-                     placement_to_params[tuple([p.placements,
-                                                p.device_mesh])].append(p)
-                 case torch.Tensor():
-                     placement_to_params[tuple([torch.Tensor, None])].append(p)
-
-         for params in placement_to_params.values():
-             self._step_adamw_params(params, group)
-
-     @torch.no_grad
-     def step(self, closure=None, qk_logits=None):
-         """Perform a single optimization step.
-
-         Args:
-             closure (Callable, optional): A closure that reevaluates the model
-                 and returns the loss.
-             qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
-                 to 1D tensors of shape (num_heads,), representing the maximum
-                 QK logits across all tokens, computed as
-                 (1 / sqrt(head_dim)) * (Q @ K^T).
-         """
-         loss = None
-         if closure is not None:
-             with torch.enable_grad():
-                 loss = closure()
-
-         for group in self.param_groups:
-             if group["use_muon"]:
-                 self._step_muon(group, qk_logits=qk_logits)
-             else:
-                 self._step_adamw(group)
-
-         return loss
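For context, a typical invocation of this optimizer follows the snippet embedded in the `__init__` error message; `model`, `batch`, and the `kernels` loader import are assumptions here, not part of this file:

import torch
from kernels import get_kernel  # assumption: the HF kernels loader is installed

optimizer = get_kernel('motif-technologies/optimizer')
params = optimizer.muon.get_default_muon_param_groups(model)
optim = optimizer.Muon(params, lr=0.02, momentum=0.95)

loss = model(batch).loss
loss.backward()
optim.step()   # qk_logits may also be passed here to enable QK clipping
optim.zero_grad()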
build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib.util
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is; after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique, using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
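This shim re-exports the build variant's top-level `__init__.py` under a path-unique module name, so several variants can coexist in `sys.modules` without clobbering each other. A hedged illustration (the paths below are hypothetical):

from pathlib import Path

# Two build dirs load the same kind of file but get distinct sys.modules
# keys, because each key is derived from the absolute path's hash.
mod_a = _import_from_path(Path("build/variant-a/__init__.py"))
mod_b = _import_from_path(Path("build/variant-b/__init__.py"))
assert mod_a is not mod_b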
build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py DELETED
@@ -1,175 +0,0 @@
- import torch
- import torch.distributed as dist
- from torch.distributed import ProcessGroup
- from torch.distributed.device_mesh import DeviceMesh
- from torch.distributed.tensor import DTensor
- from torch.distributed.tensor.placement_types import (Placement, Shard,
-                                                       _StridedShard)
-
-
- def get_slices_of_dtensor(
-     target: DTensor | torch.Tensor,
-     local_rank: int,
-     shard_mesh: DeviceMesh,
-     shard_placements: tuple[Placement],
- ) -> tuple[slice]:
-     """
-     Get the slice of the local tensor for a given rank from a tensor.
-
-     Args:
-         target (DTensor | torch.Tensor): The target tensor.
-         local_rank (int): The local rank of the shard group.
-         shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
-         shard_placements (tuple[Placement]): The shard placements.
-     """
-
-     slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
-
-     # find the global rank of the local rank in the shard mesh
-     rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
-
-     rank_coords = (shard_mesh.mesh == rank).nonzero()
-
-     assert len(rank_coords) == 1
-     rank_coords = tuple(rank_coords[0].tolist())
-
-     assert len(rank_coords) == len(shard_placements)
-
-     # Caution: Assuming replicate-to-shard of the shard mesh goes with
-     # left-to-right sharding. This is ensured by the sorting logic of
-     # the construct_shard_mesh function.
-     for i, (rank_coord,
-             placement) in enumerate(zip(rank_coords, shard_placements)):
-         assert isinstance(placement, Shard)
-
-         num_ranks = shard_mesh.mesh.shape[i]
-
-         dim = placement.dim
-         dim_size = (slices[dim].stop - slices[dim].start)
-
-         if dim_size % num_ranks != 0:
-             raise NotImplementedError(
-                 f"Dimension size {dim_size} is not divisible "
-                 f"by number of ranks {num_ranks} for shard "
-                 f"placement on dim {dim}. (shape: {target.shape})")
-
-         shard_size = dim_size // num_ranks
-
-         start = slices[dim].start + rank_coord * shard_size
-         end = start + shard_size
-
-         assert start < end <= slices[dim].stop
-
-         slices[dim] = slice(start, end)
-
-     return tuple(slices)
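A minimal sketch of the slicing arithmetic the loop above performs (plain Python, no DeviceMesh construction): for a dimension of size 8 sharded as Shard(0) over 2 ranks, rank coordinate 0 owns rows 0-3 and rank coordinate 1 owns rows 4-7:

dim_size, num_ranks = 8, 2            # Shard(0) over 2 ranks
shard_size = dim_size // num_ranks
for rank_coord in range(num_ranks):
    start = rank_coord * shard_size
    print(f"rank {rank_coord}: rows {start}..{start + shard_size - 1}")
# rank 0: rows 0..3
# rank 1: rows 4..7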
-
-
- _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
-                                                   ProcessGroup]] = dict()
-
-
- def construct_shard_mesh(
-     placements: tuple[Placement],
-     mesh: DeviceMesh,
- ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement]]:
-     """
-     Construct Shard Mesh and Placements for unsharding.
-     It removes Replicate placements and constructs a new Mesh and ProcessGroup.
-     """
-     my_rank = dist.get_rank()
-
-     assert mesh.mesh.device.type == 'cpu'
-
-     # Copy mesh to avoid modifying the original mesh
-     mesh = mesh.mesh.clone()
-
-     # 1. Sort placements: Replicate first, then Shard by dim ascending.
-
-     # For Shard, a strided shard comes after a regular shard on the same dim
-     # to preserve the left-to-right order of replicate-to-shard.
-     # This is because a strided shard uses a stride to represent
-     # more fine-grained sharding on the same dim.
-     # Please check the URL below for _StridedShard.
-     # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
-
-     def placement_sort_key(
-         placement_with_index: tuple[int, Placement]
-     ) -> tuple[int, float, int]:  # (dim, split factor, original index)
-         index, placement = placement_with_index
-         is_replicate = placement.is_replicate()
-         is_shard = placement.is_shard()
-         is_partial = placement.is_partial()
-
-         assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
-         assert not is_partial, "Partial placement is not supported."
-
-         if is_replicate:
-             return (-1.0, 0, index)
-         elif is_shard:
-             if isinstance(placement, _StridedShard):
-                 return (placement.dim, 1 / placement.split_factor, index)
-             return (placement.dim, 0, index)
-         else:
-             raise TypeError(f"Unknown placement type: {type(placement)}")
-
-     placements_with_index: list[tuple[int,
-                                       Placement]] = list(enumerate(placements))
-     placements_with_index = sorted(placements_with_index,
-                                    key=placement_sort_key)
-
-     sorted_indices, sorted_placements = zip(*placements_with_index)
-
-     # 2. Permute mesh according to sorted placements.
-     sorted_mesh = mesh.permute(sorted_indices)
-
-     # 3. Collect the list of shard meshes by removing replicate dims.
-     # For example, for (2, 3, 4, 4) with placements [R, R, S(0), S(1)],
-     # shard_meshes should be a list of 2 * 3 = 6 shard meshes of shape (4, 4).
-     num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
-
-     # Merge replicate dims.
-     # shard_meshes becomes a list of shard meshes whose length is the replicate degree.
-     if num_replicates > 0:
-         sorted_mesh = sorted_mesh.flatten(
-             0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
-         shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
-     else:
-         shard_meshes = [sorted_mesh]
-     shard_placements = sorted_placements[num_replicates:]
-
-     # assume all shard placements are different
-     assert len(shard_placements) == len(set(shard_placements))
-
-     # 4. Construct ProcessGroups.
-     # Caution: all groups should be created in the same order in all processes,
-     # even though each process only needs its own group.
-
-     # To use a tensor as a dict key, convert it to a tuple.
-     def tensor_to_tuple(t):
-         if isinstance(t, torch.Tensor):
-             t = t.tolist()
-         if isinstance(t, list):
-             return tuple(tensor_to_tuple(x) for x in t)
-         return t
-
-     my_shard_mesh_as_tuple = None
-     for shard_mesh in shard_meshes:
-         assert isinstance(shard_mesh, torch.Tensor)
-         shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
-
-         if (my_rank == shard_mesh).any().item():
-             assert my_shard_mesh_as_tuple is None
-             my_shard_mesh_as_tuple = shard_mesh_as_tuple
-
-         # update the global cache
-         if shard_mesh_as_tuple not in _ranks_to_dist_cache:
-             shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
-             _ranks_to_dist_cache[shard_mesh_as_tuple] = (
-                 DeviceMesh(device_type="cuda", mesh=shard_mesh),
-                 shard_process_group,
-             )
-
-     my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
-         my_shard_mesh_as_tuple]
-
-     return my_shard_mesh, my_shard_process_group, shard_placements
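As a concrete illustration of the sorting rule (a sketch using a simplified key that ignores `_StridedShard`; real usage goes through DTensor placements): `[Shard(1), Replicate(), Shard(0)]` sorts to `[Replicate(), Shard(0), Shard(1)]`, so replicate dims are stripped first and shard dims apply left to right:

from torch.distributed.tensor import Replicate, Shard

placements = [Shard(1), Replicate(), Shard(0)]
order = sorted(range(len(placements)),
               key=lambda i: (-1 if placements[i].is_replicate()
                              else placements[i].dim, i))
print([type(placements[i]).__name__ for i in order])
# ['Replicate', 'Shard', 'Shard']  (Shard(0) before Shard(1))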
build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py DELETED
@@ -1,128 +0,0 @@
- # MIT License
- #
- # Copyright (c) 2025 Tianyang Lin
- #
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
-
- import torch
- import triton
- import triton.language as tl
-
-
- def get_autotune_config():
-     return [
-         triton.Config(
-             {
-                 'BLOCK_SIZE_M': blk_m,
-                 'BLOCK_SIZE_K': blk_k,
-                 'GROUP_SIZE_M': grp_sz
-             },
-             num_stages=n_stages,
-             num_warps=n_warps) for blk_m in [32, 64, 128]
-         for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
-         for n_warps in [4, 8]
-     ]
-
-
- @triton.autotune(
-     configs=get_autotune_config(),
-     key=['M', 'K'],
- )
- @triton.jit
- def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
-                BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-                GROUP_SIZE_M: tl.constexpr):
-     """
-     Core jitted kernel of matmul_transpose, computing y = x @ x.T.
-     The code is a simple adaptation of the triton `matmul` tutorial:
-     https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
-     """
-     pid = tl.program_id(axis=0)
-     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-     num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
-     num_pid_in_group = GROUP_SIZE_M * num_pid_n
-     group_id = pid // num_pid_in_group
-     first_pid_m = group_id * GROUP_SIZE_M
-     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-     pid_n = (pid % num_pid_in_group) // group_size_m
-     # the output is symmetric, so only the upper-triangle blocks are computed
-     if pid_m > pid_n:
-         return
-
-     offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-     offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-     offs_k = tl.arange(0, BLOCK_SIZE_K)
-     # we use a & b ptrs to denote different rows of x.
-     a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
-     b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
-
-     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
-
-     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-         a = tl.load(a_ptrs,
-                     mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
-                     other=0.0)
-         b = tl.load(b_ptrs,
-                     mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
-                     other=0.0)
-         accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
-         a_ptrs += BLOCK_SIZE_K * stride_xk
-         b_ptrs += BLOCK_SIZE_K * stride_xk
-     # use dtype.element_ty to accommodate different input datatypes as in cpp templates
-     # https://github.com/triton-lang/triton/issues/2252
-     c = accumulator.to(x.dtype.element_ty)
-
-     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-     offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-     c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
-     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
-     tl.store(c_ptrs, c, mask=c_mask)
-
-     # transpose and copy to the mirrored lower-triangle block
-     if pid_m < pid_n:
-         ct_ptrs = y + stride_ym * offs_cn[:, None] + stride_yn * offs_cm[None, :]
-         ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
-         tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
-
-
- def matmul_transpose_assign(d_in, d_out):
-     assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
-     assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
-     assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
-     assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
-     assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
-     assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
-     assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
-         "First dimension of `d_in` must match first and second dimension of `d_out`"
-
-     d_in = d_in.contiguous()
-     M, K = d_in.shape
-     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
-         M, META['BLOCK_SIZE_M']), )
-     with torch.cuda.device(d_in.device.index):
-         mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
-                          d_out.stride(0), d_out.stride(1))
-
-
- def matmul_transpose(d_in):
-     M, _ = d_in.shape
-     d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
-     matmul_transpose_assign(d_in, d_out)
-     return d_out
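A quick sanity check of the kernel against the reference product (requires a CUDA device; shapes and the bf16 tolerance below are illustrative assumptions):

import torch

x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
y = matmul_transpose(x)      # y = x @ x.T via the triton kernel
ref = x @ x.T
assert y.shape == (256, 256)
print(torch.allclose(y.float(), ref.float(), atol=1e-1))  # loose bf16 tolerance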
build/torch210-cxx11-cu130-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
build/torch210-cxx11-cu130-x86_64-linux/muon.py DELETED
@@ -1,1268 +0,0 @@
- import logging
- import math
- import types
- from collections import defaultdict
- from dataclasses import dataclass
- from typing import Any, cast
-
- import torch
- import torch.distributed as dist
- from torch.distributed import ProcessGroup
- from torch.distributed.device_mesh import DeviceMesh
- from torch.distributed.tensor import DTensor, Replicate
- from torch.distributed.tensor.placement_types import Placement
-
- from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
- from .matmul_transpose_triton import matmul_transpose_assign
-
- logger = logging.getLogger(__name__)
-
- COMM_DTYPE = torch.bfloat16
- DEFAULT_CHUNK_SIZE_RATIO = 4
-
-
- # This code snippet is a modified version adapted from the following GitHub repository:
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
- # Muon's Newton–Schulz iteration causes high variance in singular values.
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
- @torch.no_grad()
- # matmul_transpose_assign from: https://github.com/nil0x9/flash-muon
- def _zeropower_via_newtonschulz5(G, steps):
-     """
-     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-     zero even beyond the point where the iteration no longer converges all the way to one everywhere
-     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-     where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-     performance at all relative to UV^T, where USV^T = G is the SVD.
-     """
-     assert len(G.shape) == 2
-     assert G.dtype == COMM_DTYPE
-     X = G  # no manual typecast
-
-     if G.size(0) > G.size(1):
-         X = X.T
-     # Ensure spectral norm is at most 1
-     X = X / (X.norm() + 1e-7)
-     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-     # Perform the NS iterations
-     for a, b, c in [
-         (4.0848, -6.8946, 2.9270),
-         (3.9505, -6.3029, 2.6377),
-         (3.7418, -5.5913, 2.3037),
-         (2.8769, -3.1427, 1.2046),
-         (2.8366, -3.0525, 1.2012),
-     ]:
-         matmul_transpose_assign(X, buf1)
-         matmul_transpose_assign(buf1, buf2)
-         buf1.mul_(b).add_(buf2, alpha=c)
-         X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
-
-     if G.size(0) > G.size(1):
-         X = X.T
-     return X
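A hedged sanity check of the orthogonalization (CUDA required, since the triton kernel backs `matmul_transpose_assign`): after the iteration, the singular values of the output should cluster around 1, within the docstring's Uniform(0.5, 1.5) caveat. Note that this variant hard-codes one coefficient triple per iteration, so the loop length is fixed at 5 regardless of `steps`:

import torch

G = torch.randn(512, 1024, device="cuda").to(torch.bfloat16)
U = _zeropower_via_newtonschulz5(G, steps=5)
S = torch.linalg.svdvals(U.float())
print(S.min().item(), S.max().item())  # expect values roughly around 1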
-
-
- @dataclass
- class _muon_state:
-     # TODO: use Optional
-     worker_rank: int
-     process_group: ProcessGroup
-     shard_mesh: DeviceMesh
-     shard_placements: tuple[Placement, ...]
-     name: str
-     qk_clip_state: torch.Tensor | None = None
-     gathered_grad: torch.Tensor | None = None
-     scattered_u: DTensor | None = None
-     computed_u: torch.Tensor | None = None
-     gather_event: torch.cuda.Event | None = None
-     compute_event: torch.cuda.Event | None = None
-     scatter_event: torch.cuda.Event | None = None
-
-
- def numel_for_rank(
-     param: DTensor,
-     local_rank: int,
-     state: _muon_state,
- ) -> int:
-     slices = get_slices_of_dtensor(
-         param,
-         local_rank,
-         state.shard_mesh,
-         state.shard_placements,
-     )
-
-     numel = 1
-     for s, dim in zip(slices, param.shape):
-         start, stop, step = s.indices(dim)
-         length = max(0, (stop - start + (step - 1)) // step)
-         numel *= length
-
-     return numel
-
-
- @torch.no_grad()
- def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
-     """
-     Pre-allocate the gathered_grad buffer on compute_stream
-     before launching the all2all gather.
-     """
-     with torch.cuda.stream(compute_stream):
-         for p in params:
-             state = param_to_state[id(p)]
-             if rank == state.worker_rank:
-                 state.gathered_grad = torch.empty(p.shape,
-                                                   dtype=COMM_DTYPE,
-                                                   device="cuda")
-             else:
-                 state.gathered_grad = None
-
-         alloc_event = torch.cuda.Event()
-         alloc_event.record(compute_stream)
-         return alloc_event
-
-
- @torch.no_grad()
- def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
-                     alloc_event):
-     """
-     All2all-gathers shards so each owner rank reconstructs its full gradient.
-     """
-     with torch.cuda.stream(comm_stream):
-         process_group = param_to_state[id(params[0])].process_group
-         num_ranks = dist.get_world_size(group=process_group)
-
-         # Construct sending buffers
-         per_dst = [[] for _ in range(num_ranks)]
-         send_counts = [0] * num_ranks
-
-         for p in params:
-             state = param_to_state[id(p)]
-             dst = state.worker_rank
-             assert dst < num_ranks
-             shard_elems = numel_for_rank(p, rank, state)
-             g = p.grad
-             g = g.to_local().to(COMM_DTYPE).contiguous()
-             assert g.numel() == shard_elems
-             per_dst[dst].append(g.view(-1))
-             send_counts[dst] += shard_elems
-
-         assert any(
-             len(v) > 0 for v in per_dst
-         ), "At least one destination rank must receive a sharded tensor"
-         # list[list[Tensor]] -> list[Tensor]
-         per_dst = [t for dst in per_dst for t in dst]
-
-         send_buf = torch.cat(per_dst, dim=0)
-
-         owned_params = [
-             p for p in params if param_to_state[id(p)].worker_rank == rank
-         ]
-
-         # Compute receive sizes and allocate receiving buffers
-         recv_counts = [0] * num_ranks
-
-         for src in range(num_ranks):
-             total = 0
-             for p in owned_params:
-                 state = param_to_state[id(p)]
-                 assert state.worker_rank == rank
-                 total += numel_for_rank(p, src, state)
-             recv_counts[src] = total
-
-         recv_total = sum(recv_counts)
-         recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-         # All2All
-         logger.debug(f"send_buf size: {send_buf.numel()}, "
-                      f"recv_buf size: {recv_buf.numel()}, "
-                      f"recv_counts: {recv_counts}, "
-                      f"send_counts: {send_counts}, "
-                      f"process_group: {str(process_group)}")
-         dist.all_to_all_single(
-             recv_buf,
-             send_buf,
-             output_split_sizes=recv_counts,
-             input_split_sizes=send_counts,
-             group=process_group,
-         )
-
-         # Reconstruct the gathered grad from the received buffer
-         #
-         # recv_buf (num ranks = 3)
-         #
-         # From rank 0          From rank 1          From rank 2
-         # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
-         #
-         # Outer loop:
-         #   rank 0 -> rank 1 -> rank 2
-         #
-         # Inner loop:
-         #   p1_n -> p2_n -> p3_n
-
-         comm_stream.wait_event(alloc_event)
-
-         off = 0
-         for src in range(num_ranks):
-             if recv_counts[src] == 0:
-                 continue
-
-             block = recv_counts[src]
-             inner_off = 0
-             for p in owned_params:
-                 state = param_to_state[id(p)]
-                 assert state.worker_rank == rank
-
-                 # get the slice of the full dtensor corresponding to rank src.
-                 slices = get_slices_of_dtensor(state.gathered_grad, src,
-                                                state.shard_mesh,
-                                                state.shard_placements)
-
-                 dst = state.gathered_grad[slices]
-                 assert dst._base is state.gathered_grad
-
-                 n = dst.numel()
-                 assert n > 0
-
-                 sg = recv_buf.narrow(0, off + inner_off, n)
-                 sg = sg.reshape_as(dst)
-                 dst.copy_(sg)
-
-                 inner_off += n
-             off += block
-
-         for p in params:
-             state = param_to_state[id(p)]
-             if state.worker_rank == rank:
-                 state.gather_event = torch.cuda.Event()
-                 state.gather_event.record(comm_stream)
-             else:
-                 state.gathered_grad = None
-                 state.gather_event = None
-             if none_grad:
-                 p.grad = None
-
-
- @torch.no_grad()
- def _compute_u(p, state, steps, rank, compute_stream):
-     """
-     On worker_rank, compute the orthogonalized update using the Newton-Schulz iteration.
-     """
-     with torch.cuda.stream(compute_stream):
-         if rank == state.worker_rank:
-             if state.gather_event is None:
-                 raise RuntimeError("Gather event must be set before compute.")
-             compute_stream.wait_event(state.gather_event)
-             u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
-             state.gathered_grad = None
-             state.computed_u = u
-             state.compute_event = torch.cuda.Event()
-             state.compute_event.record()
-         else:
-             state.computed_u = None
-             state.compute_event = None
-
-
- @torch.no_grad()
- def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
-     """
-     Pre-allocate the scattered_u buffer on compute_stream
-     before launching the all2all scatter.
-     """
-     with torch.cuda.stream(compute_stream):
-         for p in params:
-             state = param_to_state[id(p)]
-             state.scattered_u = torch.empty_like(p.to_local(),
-                                                  dtype=COMM_DTYPE)
-
-         alloc_event = torch.cuda.Event()
-         alloc_event.record(compute_stream)
-         return alloc_event
-
-
- def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
-     """
-     All2all-scatters the computed updates back to all ranks.
-     """
-     with torch.cuda.stream(comm_stream):
-         process_group = param_to_state[id(params[0])].process_group
-         num_ranks = dist.get_world_size(group=process_group)
-         owned_params = [
-             p for p in params if param_to_state[id(p)].worker_rank == rank
-         ]
-
-         # Construct sending buffer
-         per_dst = [[] for _ in range(num_ranks)]
-         send_counts = [0] * num_ranks
-
-         if owned_params:
-             for p in owned_params:
-                 state = param_to_state[id(p)]
-                 if state.compute_event is None:
-                     raise RuntimeError(
-                         "Compute event must be set before scatter.")
-                 comm_stream.wait_event(state.compute_event)
-                 state.gathered_grad = None
-
-                 assert state.computed_u is not None
-
-                 u_full = state.computed_u.to(COMM_DTYPE).contiguous()
-
-                 offset = 0
-                 for dst in range(num_ranks):
-                     # get the slice of the full tensor corresponding to rank dst.
-                     slices = get_slices_of_dtensor(u_full, dst,
-                                                    state.shard_mesh,
-                                                    state.shard_placements)
-                     su = u_full[slices].flatten()
-
-                     n = su.numel()
-                     assert n > 0
-
-                     per_dst[dst].append(su)
-                     send_counts[dst] += n
-                     offset += n
-
-                 assert offset == u_full.numel()
-
-         lengths = [len(v) for v in per_dst]
-         if all(l > 0 for l in lengths):
-             assert all(
-                 l == lengths[0] for l in lengths
-             ), "All destination ranks must have the same number of sharded tensors"
-             # list[list[Tensor]] -> list[Tensor]
-             per_dst = [t for dst in per_dst for t in dst]
-             send_buf = torch.cat(per_dst, dim=0)
-         else:
-             # all_to_all requires participation from all ranks.
-             # Even non-owner ranks must join the collective call.
-             send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
-
-         # Compute receive sizes and allocate receiving buffers
-         recv_counts = [0] * num_ranks
-
-         for src in range(num_ranks):
-             total = 0
-             for p in params:
-                 state = param_to_state[id(p)]
-                 if state.worker_rank != src:
-                     continue
-                 total += numel_for_rank(p, rank, state)
-             recv_counts[src] = total
-
-         recv_total = sum(recv_counts)
-         assert recv_total > 0
-         recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-         # All2All
-         dist.all_to_all_single(
-             recv_buf,
-             send_buf,
-             output_split_sizes=recv_counts,
-             input_split_sizes=send_counts,
-             group=process_group,
-         )
-
-         # Copy from the received buffer into the pre-allocated scattered_u buffers
-         #
-         # recv_buf (num ranks = 3, local_rank = 0)
-         #
-         # From rank 0          From rank 1   From rank 2
-         # | p1_0, p2_0, p3_0 | p4_0        | p5_0, p6_0 |
-         #
-         # Outer loop:
-         #   rank 0 -> rank 1 -> rank 2
-         #
-         # Inner loop:
-         #   src(0): p1_0 -> p2_0 -> p3_0
-         #   src(1): p4_0
-         #   src(2): p5_0 -> p6_0
-
-         comm_stream.wait_event(alloc_event)
-
-         off = 0
-         for src in range(num_ranks):
-             block = recv_counts[src]
-             if block == 0:
-                 continue
-
-             inner_off = 0
-             for p in params:
-                 state = param_to_state[id(p)]
-                 if state.worker_rank != src:
-                     continue
-                 n = numel_for_rank(p, rank, state)
-                 assert n > 0
-
-                 flat_local = recv_buf.narrow(0, off + inner_off,
-                                              n).view_as(p.to_local())
-                 state.scattered_u.copy_(flat_local)
-
-                 state.scatter_event = torch.cuda.Event()
-                 state.scatter_event.record(comm_stream)
-                 inner_off += n
-
-             assert inner_off == block
-             off += block
-
-
- def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
-                   compute_stream):
-     """
-     Update the sharded parameter p with the scattered_u.
-     Only worker_rank frees computed_u.
-     """
-     with torch.cuda.stream(compute_stream):
-         if state.scatter_event is None:
-             raise RuntimeError("Scatter event must be set before update")
-         compute_stream.wait_event(state.scatter_event)
-         u_dtensor = DTensor.from_local(
-             state.scattered_u,
-             placements=p.placements,
-             device_mesh=p.device_mesh,
-         )
-
-         state.scattered_u = u_dtensor
-
-         if rank == state.worker_rank:
-             # Free computed_u
-             state.computed_u = None
-
-         Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
-         state.scattered_u = None
-         u_dtensor = None
-
-         scales_full = Muon._compute_scales(
-             p,
-             state.qk_clip_state) if state.qk_clip_state is not None else None
-         if scales_full is not None:
-             # Have to slice scales_full along dim 0
-             weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
-                                                   state.shard_placements)
-             ratio = p.shape[0] // scales_full.shape[0]
-             scales_slice = slice(
-                 None if weight_slices[0].start is None else
-                 weight_slices[0].start // ratio,
-                 None if weight_slices[0].stop is None else
-                 weight_slices[0].stop // ratio,
-                 None,
-             )
-
-             scales_local = scales_full[scales_slice]
-             scales_local = DTensor.from_local(
-                 scales_local,
-                 placements=p.placements,
-                 device_mesh=p.device_mesh,
-             )
-             Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
-
-
- def default_is_muon(name, x):
-     skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
-     return x.ndim >= 2 and not any(key in name for key in skip_keys)
-
-
- def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
-     muon_params, muon_names = [], []
-     non_muon_params = []
-
-     for n, p in model.named_parameters():
-         if not p.requires_grad:
-             continue
-         if is_muon_func(n, p):
-             muon_params.append(p)
-             muon_names.append(n)
-         else:
-             non_muon_params.append(p)
-
-     return [
-         {
-             "params": muon_params,
-             "names": muon_names,
-             "use_muon": True,
-         },
-         {
-             "params": non_muon_params,
-             "use_muon": False,
-         },
-     ]
-
-
- def parse_qk_layer(name: str) -> tuple[str | None, int]:
-     """
-     Parse a parameter name to check if it is a query/key projection layer
-     ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
-
-     Returns:
-         (kind, layer_idx) or (None, -1) if not matched.
-
-     Example:
-         'model.3.attn.wq.weight' -> ('wq', 3)
-         'model.5.attn.wk.weight' -> ('wk', 5)
-         'model.2.attn.q_proj.weight' -> ('q_proj', 2)
-         'model.7.attn.k_proj.weight' -> ('k_proj', 7)
-         'model.4.attn.v_proj.weight' -> (None, -1)
-     """
-     parts = name.split('.')
-     if len(parts) < 3:
-         return None, -1
-
-     kind = parts[-2]
-
-     layer_idx = -1
-     for part in reversed(parts):
-         if part.isdigit():
-             layer_idx = int(part)
-             break
-
-     if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
-         return kind, layer_idx
-
-     return None, -1
-
-
- @dataclass
- class QKClipInfo:
-     """Per-parameter dynamic info computed from config + runtime logits."""
-     kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
-     indices: list[int]  # which heads to consider for clipping
-     head_dim: int  # from config
-     threshold: float  # from config
-     logit: torch.Tensor | None
-
-
- class Muon(torch.optim.Optimizer):
-     """
-     Muon - MomentUm Orthogonalized by Newton-schulz
-
-     Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
-     processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
-     matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
-     the advantage that it can be stably run in bfloat16 on the GPU.
-
-     Some warnings:
-     - We believe this optimizer is unlikely to work well for training with small batch size.
-     - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
-
-     Arguments:
-         model: The model to be optimized by Muon.
-         is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
-         lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
-         momentum: The momentum used by the internal SGD. (0.95 is a good default)
-         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
-         ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
-         weight_decay: The weight decay for both Muon and AdamW. Parameters that are
-             {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW.
-         adamw_lr: The learning rate for the internal AdamW.
-         adamw_betas: The betas for the internal AdamW.
-         adamw_eps: The epsilon for the internal AdamW.
-         none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
-         debug: Whether to print debug information.
-         clip_config: Configuration for QK clipping. Expected keys:
-             - "q_indices" (list[int]): Indices of query heads to consider.
-             - "k_indices" (list[int]): Indices of key heads to consider.
-             - "head_dim" (int): Dimensionality of each attention head.
-             - "threshold" (float): Threshold value; heads whose QK logits exceed
-               this value will be scaled down.
-             Default is:
-             {
-                 "q_indices": [],
-                 "k_indices": [],
-                 "head_dim": 128,
-                 "threshold": 100
-             }
-         warmup_step: How many all2all gather and compute operations are launched in advance
-             before the corresponding all2all scatter steps begin.
-             A higher warmup_step increases memory usage but can improve
-             performance by overlapping communication.
-             Parallel Muon only.
-         chunk_size: Batch size of parameters to process in each
-             all2all gather/compute/scatter step.
-             Uses shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
-         use_distributed_muon: Use distributed Muon by Liu et al. (2024).
-             For testing purposes only.
-         small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon.
-     """
-
-     def __init__(self,
-                  params,
-                  lr=1e-3,
-                  momentum=0.95,
-                  nesterov=True,
-                  ns_steps=5,
-                  weight_decay=0.1,
-                  adamw_betas=(0.9, 0.95),
-                  adamw_eps=1e-8,
-                  none_grad=True,
-                  debug=False,
-                  clip_config={
-                      "q_indices": [],
-                      "k_indices": [],
-                      "head_dim": 128,
-                      "threshold": 100
-                  },
-                  warmup_step=5,
-                  chunk_size=-1,
-                  use_distributed_muon=False,
-                  small_param_numel_threshold=65536):
-         defaults = dict(
-             lr=lr,
-             weight_decay=weight_decay,
-             momentum=momentum,
-             nesterov=nesterov,
-             ns_steps=ns_steps,
-             adamw_betas=adamw_betas,
-             adamw_eps=adamw_eps,
-             none_grad=none_grad,
-             use_muon=True,
-         )
-         error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
-         instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
-
-         if isinstance(params, types.GeneratorType):
-             raise ValueError(error_message.format(idx=0) + instruction_code)
-         for _idx, param_group in enumerate(params):
-             if param_group.get("use_muon", None) is None:
-                 raise ValueError(
-                     error_message.format(idx=_idx) + instruction_code)
-
-         super().__init__(params, defaults)
-
-         self.rank = None
-
-         self.comm_stream = torch.cuda.Stream()
-         self.compute_stream = torch.cuda.Stream()
-         self.debug = debug
-         self.clip_config = clip_config
-         self.warmup_step = warmup_step
-         self.chunk_size = chunk_size
-         self.use_distributed_muon = use_distributed_muon
-         self.small_param_numel_threshold = small_param_numel_threshold
-
-     def _calc_flops(self, G, steps):
-         assert len(G.shape) == 2
-         M, N = G.shape
-         if M > N:
-             M, N = N, M
-
-         return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
-
-     def adjust_lr_for_muon(self, lr, param_shape):
-         A, B = param_shape[:2]
-         # We adjust the learning rate and weight decay based on the size of the parameter matrix
-         # as described in the paper
-         adjusted_ratio = 0.2 * math.sqrt(max(A, B))
-         adjusted_lr = lr * adjusted_ratio
-         return adjusted_lr
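Concretely, for a 4096x1024 weight with lr=0.02 the adjustment factor is 0.2 * sqrt(4096) = 12.8, so the applied step size is 0.256 (illustrative values only):

import math

lr, shape = 0.02, (4096, 1024)
adjusted_lr = lr * 0.2 * math.sqrt(max(shape))  # 0.02 * 0.2 * 64 = 0.256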
-
-     def set_rank_once(self, rank):
-         if self.rank is None:
-             self.rank = rank
-         else:
-             assert self.rank == rank
-
-     def get_shard_mesh(self, p):
-         """
-         Get the shard mesh for a parameter p on the given rank.
-         """
-         assert isinstance(
-             p, DTensor), "Parallel Muon only supports DTensor parameters."
-
-         shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
-             p.placements, p.device_mesh)
-
-         # set rank with the local rank in the shard process group
-         self.set_rank_once(dist.get_rank(group=shard_pg))
-
-         return shard_mesh, shard_pg, shard_placements
-
-     def init_state_and_assign_params(self, names, params, group, qk_logits):
-         param_to_state = {}
-         param_to_flops = {}
-
-         total_flops = 0
-         for p in params:
-             g = p.grad
-             if g is None:
-                 continue
-             assert g.ndim == 2, "Muon only supports 2D parameters."
-
-             flops = self._calc_flops(g, group["ns_steps"])
-             param_to_flops[id(p)] = flops
-             total_flops += flops
-
-         if self.debug:
-             print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
-                   flush=True)
-
-         paired = list(zip(names, params))
-
-         paired_sorted = sorted(paired,
-                                key=lambda x: param_to_flops[id(x[1])],
-                                reverse=True)
-
-         names_sorted, params_sorted = zip(*paired_sorted)
-         ordered_names = list(names_sorted)
-         ordered_params = list(params_sorted)
-
-         round_robin = 0
-         mesh = ordered_params[0].device_mesh
-         placements = ordered_params[0].placements
-
-         shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
-             ordered_params[0])
-         shard_mesh_flattened = shard_mesh.mesh.flatten()
-         num_ranks = dist.get_world_size(group=shard_pg)
-
-         for n, p in zip(ordered_names, ordered_params):
-             if mesh != p.device_mesh:
-                 raise ValueError("All parameters must be on the same mesh.")
-             if placements != p.placements:
-                 raise ValueError("All parameters must have same placements.")
-
-             worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
-             round_robin = (round_robin + 1) % len(shard_mesh_flattened)
-             qk_clip_state = self.get_qk_clip_info(n, qk_logits)
-
-             param_to_state[id(p)] = _muon_state(
-                 worker_rank=worker_rank,
-                 process_group=shard_pg,
-                 shard_mesh=shard_mesh,
-                 shard_placements=shard_placements,
-                 name=n,
-                 qk_clip_state=qk_clip_state,
-             )
-
-         return param_to_state, ordered_params
-
-     def base(self, names, params, group, lr, weight_decay, momentum,
-              qk_logits):
-         # generate weight updates locally (non-distributed fallback)
-         for n, p in zip(names, params):
-             g = p.grad
-             if g is None:
-                 continue
-             if g.ndim > 2:
-                 g = g.view(g.size(0), -1)
-             assert g is not None
-
-             g = self._update_g(p, g, group, momentum)
-
-             u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
-                                              steps=group["ns_steps"])
-
-             adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-             Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
-
-             qk_clip_state = self.get_qk_clip_info(n, qk_logits)
-
-             scales_full = self._compute_scales(
-                 p, qk_clip_state) if qk_clip_state is not None else None
-             if scales_full is not None:
-                 Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
-
-     def distributed_muon(
-         self,
-         names: list[str],
-         params: list[torch.nn.Parameter],
-         group: dict[str, Any],
-         lr: float,
-         weight_decay: float,
-         momentum: float,
-         qk_logits: list[torch.Tensor | DTensor] | None,
-     ):
-         """Implementation of Distributed Muon by Liu et al."""
-
-         for n, p in zip(names, params):
-             g = p.grad
-             if g is None:
-                 continue
-             if g.ndim > 2:
-                 g = g.view(g.size(0), -1)
-             assert g is not None
-
-             g = self._update_g(p, g, group, momentum)
-
-             # Gather G
-             if isinstance(p.data, DTensor):
-                 g_full = g.full_tensor()
-                 p_full = p.data.full_tensor()
-             else:
-                 g_full = g
-                 p_full = p
-
-             u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
-                                                   steps=group["ns_steps"])
-
-             adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
-             Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
-
-             qk_clip_state = self.get_qk_clip_info(n, qk_logits)
-
-             scales_full = self._compute_scales(
-                 p_full, qk_clip_state) if qk_clip_state is not None else None
-
-             if scales_full is not None:
-                 Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
-
-             if isinstance(p.data, DTensor):
-                 ndims = len(p.device_mesh.mesh.shape)
-                 p_replicate = DTensor.from_local(
-                     p_full,
-                     device_mesh=p.device_mesh,
-                     placements=[Replicate() for _ in range(ndims)],
-                 )
-
-                 p_sharded = p_replicate.redistribute(
-                     device_mesh=p.device_mesh,
-                     placements=p.placements,
-                 )
-
-                 p.copy_(p_sharded)
-
-     def _update_g(self, p, g, group, momentum):
-         # calc update
-         state = self.state[p]
-         buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
-         torch.add(g, buf, alpha=momentum, out=buf)
-         if group["nesterov"]:
-             g.add_(buf, alpha=momentum)
-             return g
-         return buf
-
-     @staticmethod
-     def _update_p(p, u, lr, adjusted_lr, weight_decay):
-         if isinstance(p, torch.nn.Parameter):
-             # apply weight decay
-             p.data.mul_(1 - lr * weight_decay)
-             # apply update
-             p.data.add_(u, alpha=-adjusted_lr)
-         else:
-             p.mul_(1 - lr * weight_decay)
-             p.add_(u, alpha=-adjusted_lr)
-
-     def get_qk_clip_info(self, n, qk_logits):
-         if self.clip_config is None:
-             return None
-
-         head_dim = self.clip_config.get('head_dim')
-         threshold = self.clip_config.get('threshold')
-         kind, layer_idx = parse_qk_layer(n)
-
-         logit, indices = None, []
-         if qk_logits is not None and kind is not None:
-             logit = qk_logits[layer_idx]
-             indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-             indices = self.clip_config.get(indices_key, []) or []
-
-         if isinstance(logit, DTensor):
-             # In TP settings, qk_logits may be DTensor
-             # We convert it to full tensor here for simplicity
-             logit = logit.full_tensor()
-
-         return QKClipInfo(
-             kind=kind,
-             indices=indices,
-             head_dim=head_dim,
-             threshold=threshold,
-             logit=logit,
-         )
-
-     @staticmethod
-     def _compute_scales(p, qk_clip_state):
-         kind = qk_clip_state.kind
-         indices = qk_clip_state.indices
877
- head_dim = qk_clip_state.head_dim
878
- threshold = qk_clip_state.threshold
879
- logit = qk_clip_state.logit
880
-
881
- H_global = p.shape[0] // head_dim
882
- scales_full = torch.ones(H_global, device=p.data.device)
883
- scaling = 0
884
-
885
- for logit_idx, head_idx in enumerate(indices):
886
- v_ele = float(logit[logit_idx])
887
- if v_ele > threshold:
888
- new_scale = math.sqrt(threshold / v_ele)
889
- if new_scale < scales_full[head_idx]:
890
- scales_full[head_idx] = new_scale
891
- logger.info(
892
- f"[{kind}] Head {head_idx} exceeded threshold "
893
- f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
894
- )
895
- scaling += 1
896
-
897
- return scales_full if scaling > 0 else None
898
-
899
- @staticmethod
900
- def _qk_clip(p, scales, head_dim):
901
- if isinstance(p, torch.nn.Parameter):
902
- W = p.data.view(-1, head_dim, p.data.shape[1])
903
- W.mul_(scales.view(-1, 1, 1))
904
- else:
905
- W = p.view(-1, head_dim, p.shape[1])
906
- W.mul_(scales.view(-1, 1, 1))
907
-
908
- def parallel(self, names, params, group, lr, weight_decay, momentum,
909
- qk_logits):
910
- """
911
- Perform a parallel optimization step using Muon.
912
- """
913
-
914
- for p in params:
915
- g = p.grad
916
- if g is None:
917
- continue
918
- if g.ndim > 2:
919
- g = g.view(g.size(0), -1)
920
-
921
- # Update g in the local rank
922
- g = self._update_g(
923
- p,
924
- g,
925
- group,
926
- momentum=momentum,
927
- )
928
- p.grad = g
929
-
930
- param_to_state, ordered_params = self.init_state_and_assign_params(
931
- names, params, group, qk_logits)
932
-
933
- assert self.rank is not None
934
-
935
- def enqueue_all2all_gather(start_idx, chunk_size):
936
- target_params = ordered_params[start_idx:start_idx + chunk_size]
937
- if target_params:
938
- alloc_event = _alloc_gathered_grad(target_params,
939
- param_to_state, self.rank,
940
- self.compute_stream)
941
- _all2all_gather(target_params, param_to_state, self.rank,
942
- self.comm_stream, group["none_grad"],
943
- alloc_event)
944
-
945
- def enqueue_computes(start_idx, chunk_size):
946
- for p in ordered_params[start_idx:start_idx + chunk_size]:
947
- state = param_to_state[id(p)]
948
- _compute_u(p, state, group["ns_steps"], self.rank,
949
- self.compute_stream)
950
-
951
- def enqueue_all2all_scatter(start_idx, chunk_size):
952
- target_params = ordered_params[start_idx:start_idx + chunk_size]
953
- if target_params:
954
- alloc_event = _alloc_scattered_u(target_params, param_to_state,
955
- self.rank,
956
- self.compute_stream)
957
- _all2all_scatter(target_params, param_to_state, self.rank,
958
- self.comm_stream, alloc_event)
959
-
960
- def enqueue_update_param(start_idx, chunk_size):
961
- for p in ordered_params[start_idx:start_idx + chunk_size]:
962
- state = param_to_state[id(p)]
963
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
964
- _update_param(p, state, lr, adjusted_lr, weight_decay,
965
- self.rank, self.compute_stream)
966
-
967
- if self.chunk_size == -1:
968
- shard_ranks = dist.get_world_size(param_to_state[id(
969
- params[0])].process_group)
970
- chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
971
- elif self.chunk_size > 0:
972
- chunk_size = self.chunk_size
973
- else:
974
- raise ValueError("chunk_size must be -1 or a positive integer.")
975
-
976
- # Wait grad update
977
- self.comm_stream.wait_stream(torch.cuda.current_stream())
978
-
979
- warmup_step = self.warmup_step
980
- for i in range(0, warmup_step):
981
- enqueue_all2all_gather(i * chunk_size, chunk_size)
982
- enqueue_computes(i * chunk_size, chunk_size)
983
-
984
- for i in range(0, len(params) + chunk_size - 1, chunk_size):
985
- enqueue_all2all_scatter(i, chunk_size)
986
- enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
987
- enqueue_update_param(i, chunk_size)
988
- enqueue_computes(i + warmup_step * chunk_size, chunk_size)
989
-
990
- # Wait the last update_param to finish
991
- torch.cuda.current_stream().wait_stream(self.compute_stream)
992
-
993
- @staticmethod
994
- def _fused_adamw(
995
- params: list[torch.Tensor],
996
- grads: list[torch.Tensor],
997
- exp_avgs: list[torch.Tensor],
998
- exp_avg_sqs: list[torch.Tensor],
999
- max_exp_avg_sqs: list[torch.Tensor],
1000
- state_steps: list[torch.Tensor],
1001
- amsgrad: bool,
1002
- beta1: float,
1003
- beta2: float,
1004
- lr: float | torch.Tensor,
1005
- weight_decay: float,
1006
- eps: float,
1007
- maximize: bool,
1008
- ) -> None:
1009
- if not params:
1010
- return
1011
-
1012
- # We only shuffle around the lr when it is a Tensor and on CUDA; otherwise,
1013
- # we treat it as a scalar.
1014
- lr_dict: DeviceDict | None = ({
1015
- lr.device: lr
1016
- } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1017
- None)
1018
- grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1019
- [
1020
- params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
1021
- state_steps
1022
- ] # type: ignore[list-item]
1023
- )
1024
- for (device, _), (
1025
- (
1026
- device_params_,
1027
- device_grads_,
1028
- device_exp_avgs_,
1029
- device_exp_avg_sqs_,
1030
- device_max_exp_avg_sqs,
1031
- device_state_steps_,
1032
- ),
1033
- _,
1034
- ) in grouped_tensors.items():
1035
- device_params = cast(list[torch.Tensor], device_params_)
1036
- device_grads = cast(list[torch.Tensor], device_grads_)
1037
- device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
1038
- device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
1039
- device_state_steps = cast(list[torch.Tensor], device_state_steps_)
1040
-
1041
- if lr_dict is not None and device not in lr_dict:
1042
- lr_dict[device] = lr.to(
1043
- device=device,
1044
- non_blocking=True) # type: ignore[union-attr]
1045
- lr = lr_dict[device]
1046
- torch._foreach_add_(device_state_steps, 1)
1047
- func = torch._fused_adamw_
1048
- func(
1049
- device_params,
1050
- device_grads,
1051
- device_exp_avgs,
1052
- device_exp_avg_sqs,
1053
- device_max_exp_avg_sqs, # type: ignore[arg-type]
1054
- device_state_steps,
1055
- amsgrad=amsgrad,
1056
- lr=lr, # type: ignore[arg-type]
1057
- beta1=beta1,
1058
- beta2=beta2,
1059
- weight_decay=weight_decay,
1060
- eps=eps,
1061
- maximize=maximize,
1062
- )
1063
-
1064
- def _step_muon(self, group, qk_logits=None):
1065
- params = group["params"]
1066
- lr = group["lr"]
1067
- weight_decay = group["weight_decay"]
1068
- momentum = group["momentum"]
1069
- names = group["names"]
1070
-
1071
- param_dtensors = []
1072
- name_dtensors = []
1073
-
1074
- param_tensors = []
1075
- name_tensors = []
1076
-
1077
- param_dtensors_small = []
1078
- name_dtensors_small = []
1079
-
1080
- if self.use_distributed_muon:
1081
- self.distributed_muon(names=names,
1082
- params=params,
1083
- group=group,
1084
- lr=lr,
1085
- weight_decay=weight_decay,
1086
- momentum=momentum,
1087
- qk_logits=qk_logits)
1088
- return
1089
-
1090
- # For simplicity, we use distributed Muon for small parameters
1091
- # whose number of elements is below a threshold.
1092
- for n, p in zip(names, params):
1093
- if p is None or p.grad is None:
1094
- continue
1095
- if isinstance(p.data, DTensor):
1096
- if all(
1097
- isinstance(placement, Replicate)
1098
- for placement in p.placements):
1099
- param_tensors.append(p)
1100
- name_tensors.append(n)
1101
- elif p.data.numel() <= self.small_param_numel_threshold:
1102
- param_dtensors_small.append(p)
1103
- name_dtensors_small.append(n)
1104
- else:
1105
- param_dtensors.append(p)
1106
- name_dtensors.append(n)
1107
- elif isinstance(p.data, torch.Tensor):
1108
- param_tensors.append(p)
1109
- name_tensors.append(n)
1110
- else:
1111
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1112
-
1113
- logger.debug(
1114
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
1115
- f"{len(param_dtensors_small)} Small DTensors")
1116
-
1117
- def group_dtensors(dtensors, names):
1118
- # To support different placements, we group parameters by placements
1119
- # and run parallel Muon on each group.
1120
-
1121
- placement_to_params = defaultdict(lambda: ([], []))
1122
- # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1123
-
1124
- assert len(dtensors) == len(names)
1125
- for p, n in zip(dtensors, names):
1126
- placement_to_params[tuple([p.placements,
1127
- p.device_mesh])][0].append(n)
1128
- placement_to_params[tuple([p.placements,
1129
- p.device_mesh])][1].append(p)
1130
- return placement_to_params
1131
-
1132
- if len(param_dtensors_small) > 0:
1133
- if not dist.is_initialized():
1134
- raise RuntimeError(
1135
- "Parallel Muon requires torch.distributed to be initialized."
1136
- )
1137
-
1138
- self.distributed_muon(
1139
- params=param_dtensors_small,
1140
- names=name_dtensors_small,
1141
- group=group,
1142
- lr=lr,
1143
- weight_decay=weight_decay,
1144
- momentum=momentum,
1145
- qk_logits=qk_logits,
1146
- )
1147
-
1148
- if len(param_dtensors) > 0:
1149
- if not dist.is_initialized():
1150
- raise RuntimeError(
1151
- "Parallel Muon requires torch.distributed to be initialized."
1152
- )
1153
-
1154
- dtensor_group = group_dtensors(param_dtensors, name_dtensors)
1155
- for _, (names, params) in dtensor_group.items():
1156
- self.parallel(
1157
- names,
1158
- params,
1159
- group,
1160
- lr=lr,
1161
- weight_decay=weight_decay,
1162
- momentum=momentum,
1163
- qk_logits=qk_logits,
1164
- )
1165
-
1166
- if len(param_tensors) > 0:
1167
- self.base(
1168
- name_tensors,
1169
- param_tensors,
1170
- group,
1171
- lr=lr,
1172
- weight_decay=weight_decay,
1173
- momentum=momentum,
1174
- qk_logits=qk_logits,
1175
- )
1176
-
1177
- def _step_adamw_params(self, params, group):
1178
- params_with_grads = []
1179
- grads = []
1180
- moment1 = []
1181
- moment2 = []
1182
- max_exp_avg_sqs = []
1183
- state_steps = []
1184
- lr = group["lr"]
1185
- beta1, beta2 = group["adamw_betas"]
1186
- eps = group["adamw_eps"]
1187
- weight_decay = group["weight_decay"]
1188
-
1189
- for p in params:
1190
- g = p.grad
1191
- if g is None:
1192
- continue
1193
- state = self.state[p]
1194
- params_with_grads.append(p)
1195
- grads.append(g)
1196
- if "step" not in state:
1197
- state["step"] = (torch.zeros((),
1198
- dtype=torch.float32,
1199
- device=p.device))
1200
- state["moment1"] = torch.zeros_like(g)
1201
- state["moment2"] = torch.zeros_like(g)
1202
- moment1.append(state["moment1"])
1203
- moment2.append(state["moment2"])
1204
- if not isinstance(state["step"], torch.Tensor):
1205
- step_tensor = torch.tensor(state["step"],
1206
- dtype=torch.float32,
1207
- device=p.device)
1208
- else:
1209
- step_tensor = state["step"]
1210
- state_steps.append(step_tensor)
1211
-
1212
- self._fused_adamw(
1213
- params_with_grads,
1214
- grads,
1215
- moment1,
1216
- moment2,
1217
- max_exp_avg_sqs,
1218
- state_steps,
1219
- amsgrad=False,
1220
- beta1=beta1,
1221
- beta2=beta2,
1222
- lr=lr,
1223
- weight_decay=weight_decay,
1224
- eps=eps,
1225
- maximize=False,
1226
- )
1227
-
1228
- def _step_adamw(self, group):
1229
- params = group["params"]
1230
-
1231
- # group params by their type and placement
1232
- placement_to_params: dict[tuple[Placement | type,
1233
- DeviceMesh | None], list] = defaultdict(list)
1234
- for p in params:
1235
- match p:
1236
- case DTensor():
1237
- placement_to_params[tuple([p.placements,
1238
- p.device_mesh])].append(p)
1239
- case torch.Tensor():
1240
- placement_to_params[tuple([torch.Tensor, None])].append(p)
1241
-
1242
- for params in placement_to_params.values():
1243
- self._step_adamw_params(params, group)
1244
-
1245
- @torch.no_grad
1246
- def step(self, closure=None, qk_logits=None):
1247
- """Perform a single optimization step.
1248
-
1249
- Args:
1250
- closure (Callable, optional): A closure that reevaluates the model
1251
- and returns the loss.
1252
- qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
1253
- to 1D tensors of shape (num_heads,), representing the maximum
1254
- QK logits across all tokens, computed as
1255
- (1 / sqrt(head_dim)) * (Q @ K^T).
1256
- """
1257
- loss = None
1258
- if closure is not None:
1259
- with torch.enable_grad():
1260
- loss = closure()
1261
-
1262
- for group in self.param_groups:
1263
- if group["use_muon"]:
1264
- self._step_muon(group, qk_logits=qk_logits)
1265
- else:
1266
- self._step_adamw(group)
1267
-
1268
- return loss
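
For reference, a minimal usage sketch of the optimizer deleted above. The toy model, device, and hyperparameters are illustrative; `get_default_muon_param_groups` and `Muon` are the entry points defined in this file, and a CUDA device is assumed since the Newton-Schulz step runs a Triton kernel:

```python
import torch
import torch.nn as nn

# Toy 2-layer model: Muon takes the 2-D weights, AdamW the 1-D biases.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256)).cuda()

# Muon.__init__ requires every param group to carry an explicit "use_muon"
# flag, which this helper sets up.
param_groups = get_default_muon_param_groups(model)
opt = Muon(param_groups, lr=0.02, momentum=0.95)

loss = model(torch.randn(32, 256, device="cuda")).square().mean()
loss.backward()
opt.step()  # pass qk_logits=... only when QK clipping is in use
opt.zero_grad()
```
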
build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- import ctypes
2
- import sys
3
-
4
- import importlib.util
5
- from pathlib import Path
6
- from types import ModuleType
7
-
8
- def _import_from_path(file_path: Path) -> ModuleType:
9
- # We cannot use the module name as-is, after adding it to `sys.modules`,
10
- # it would also be used for other imports. So, we make a module name that
11
- # depends on the path for it to be unique using the hex-encoded hash of
12
- # the path.
13
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
- module_name = path_hash
15
- spec = importlib.util.spec_from_file_location(module_name, file_path)
16
- if spec is None:
17
- raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
- module = importlib.util.module_from_spec(spec)
19
- if module is None:
20
- raise ImportError(f"Cannot load module {module_name} from spec")
21
- sys.modules[module_name] = module
22
- spec.loader.exec_module(module) # type: ignore
23
- return module
24
-
25
-
26
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
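
The hex-encoded hash above exists only to give each loaded copy a unique key in `sys.modules`. A standalone sketch of that naming trick (the paths are placeholders; the files need not exist, since only the path string is hashed):

```python
import ctypes
from pathlib import Path

# Two different absolute paths map to two different module keys, so one
# loaded build variant can never shadow another in sys.modules.
for p in [Path("/tmp/variant_a/__init__.py"), Path("/tmp/variant_b/__init__.py")]:
    module_name = "{:x}".format(ctypes.c_size_t(hash(p.absolute())).value)
    print(p, "->", module_name)
```
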
build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py DELETED
@@ -1,175 +0,0 @@
1
- import torch
2
- import torch.distributed as dist
3
- from torch.distributed import ProcessGroup
4
- from torch.distributed.device_mesh import DeviceMesh
5
- from torch.distributed.tensor import DTensor
6
- from torch.distributed.tensor.placement_types import (Placement, Shard,
7
- _StridedShard)
8
-
9
-
10
- def get_slices_of_dtensor(
11
- target: DTensor | torch.Tensor,
12
- local_rank: int,
13
- shard_mesh: DeviceMesh,
14
- shard_placements: tuple[Placement],
15
- ) -> tuple[slice]:
16
- """
17
- Get the slice of local tensor for a given rank from a tensor.
18
- Args:
19
- target (DTensor | torch.Tensor): The target tensor.
20
- local_rank (int): The local rank within the shard group.
21
- shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
- shard_placements (tuple[Placement]): The shard placements.
23
- """
24
-
25
- slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
-
27
- # find the global rank of the local rank in the shard mesh
28
- rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
-
30
- rank_coords = (shard_mesh.mesh == rank).nonzero()
31
-
32
- assert len(rank_coords) == 1
33
- rank_coords = tuple(rank_coords[0].tolist())
34
-
35
- assert len(rank_coords) == len(shard_placements)
36
-
37
- # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
- # left-to-right sharding. This is ensured by the sorting logic of
39
- # construct_shard_mesh function.
40
- for i, (rank_coord,
41
- placement) in enumerate(zip(rank_coords, shard_placements)):
42
- assert isinstance(placement, Shard)
43
-
44
- num_ranks = shard_mesh.mesh.shape[i]
45
-
46
- dim = placement.dim
47
- dim_size = (slices[dim].stop - slices[dim].start)
48
-
49
- if dim_size % num_ranks != 0:
50
- raise NotImplementedError(
51
- f"Dimension size {dim_size} is not divisible "
52
- f"by number of ranks {num_ranks} for shard "
53
- f"placement on dim {dim}. (shape: {target.shape})")
54
-
55
- shard_size = dim_size // num_ranks
56
-
57
- start = slices[dim].start + rank_coord * shard_size
58
- end = start + shard_size
59
-
60
- assert start < end <= slices[dim].stop
61
-
62
- slices[dim] = slice(start, end)
63
-
64
- return tuple(slices)
65
-
66
-
67
- _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
68
- ProcessGroup]] = dict()
69
-
70
-
71
- def construct_shard_mesh(
72
- placements: tuple[Placement],
73
- mesh: DeviceMesh,
74
- ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
75
- """
76
- Construct Shard Mesh and Placements for unsharding.
77
- It removes Replicate placements and constructs a new Mesh and ProcessGroup.
78
- """
79
- my_rank = dist.get_rank()
80
-
81
- assert mesh.mesh.device.type == 'cpu'
82
-
83
- # Copy mesh to avoid modifying the original mesh
84
- mesh = mesh.mesh.clone()
85
-
86
- # 1. Sort placements. Replicate first, then Shard by dim ascending.
87
-
88
- # For Shard, strided shard comes after regular shard on the same dim
89
- # to preserve left-to-right order of replicate-to-shard.
90
- # This is because strided shard uses stride to represent
91
- # more fine-grained sharding on the same dim.
92
- # Please check the URL below for _StridedShard.
93
- # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
94
-
95
- def placement_sort_key(
96
- placement_with_index: tuple[int, Placement]
97
- ) -> tuple[int, float, int]: # (dim, split factor, original index)
98
- index, placement = placement_with_index
99
- is_replicate = placement.is_replicate()
100
- is_shard = placement.is_shard()
101
- is_partial = placement.is_partial()
102
-
103
- assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
104
- assert not is_partial, "Partial placement is not supported."
105
-
106
- if is_replicate:
107
- return (-1.0, 0, index)
108
- elif is_shard:
109
- if isinstance(placement, _StridedShard):
110
- return (placement.dim, 1 / placement.split_factor, index)
111
- return (placement.dim, 0, index)
112
- else:
113
- raise TypeError(f"Unknown placement type: {type(placement)}")
114
-
115
- placements_with_index: list[tuple[int,
116
- Placement]] = list(enumerate(placements))
117
- placements_with_index = sorted(placements_with_index,
118
- key=placement_sort_key)
119
-
120
- sorted_indices, sorted_placements = zip(*placements_with_index)
121
-
122
- # 2. Permute mesh according to sorted placements.
123
- sorted_mesh = mesh.permute(sorted_indices)
124
-
125
- # 3. Collect list of shard meshes by removing replicate dims
126
- # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
127
- # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
128
- num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
129
-
130
- # merge replicate dims
131
- # shard_meshes became a list of shard meshes with a length of replicate degree
132
- if num_replicates > 0:
133
- sorted_mesh = sorted_mesh.flatten(
134
- 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
135
- shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
136
- else:
137
- shard_meshes = [sorted_mesh]
138
- shard_placements = sorted_placements[num_replicates:]
139
-
140
- # assume all shard placements are different
141
- assert len(shard_placements) == len(set(shard_placements))
142
-
143
- # 4. Construct ProcessGroups
144
- # Caution: all groups should be created in the same order in all processes,
145
- # even though each process only needs its own group.
146
-
147
- # To use tensor as dict key, convert it to tuple
148
- def tensor_to_tuple(t):
149
- if isinstance(t, torch.Tensor):
150
- t = t.tolist()
151
- if isinstance(t, list):
152
- return tuple(tensor_to_tuple(x) for x in t)
153
- return t
154
-
155
- my_shard_mesh_as_tuple = None
156
- for shard_mesh in shard_meshes:
157
- assert isinstance(shard_mesh, torch.Tensor)
158
- shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
159
-
160
- if (my_rank == shard_mesh).any().item():
161
- assert my_shard_mesh_as_tuple is None
162
- my_shard_mesh_as_tuple = shard_mesh_as_tuple
163
-
164
- # update global cache
165
- if shard_mesh_as_tuple not in _ranks_to_dist_cache:
166
- shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
167
- _ranks_to_dist_cache[shard_mesh_as_tuple] = (
168
- DeviceMesh(device_type="cuda", mesh=shard_mesh),
169
- shard_process_group,
170
- )
171
-
172
- my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
173
- my_shard_mesh_as_tuple]
174
-
175
- return my_shard_mesh, my_shard_process_group, shard_placements
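
As a standalone illustration of the slicing arithmetic in `get_slices_of_dtensor`, here is a hypothetical helper that mirrors only the single-`Shard(dim)` case and needs no initialized process group:

```python
def shard_slice(dim_size: int, num_ranks: int, rank_coord: int) -> slice:
    # Mirrors the shard_size/start/end computation above; the real helper
    # raises NotImplementedError when dim_size is not divisible by num_ranks.
    assert dim_size % num_ranks == 0
    shard_size = dim_size // num_ranks
    start = rank_coord * shard_size
    return slice(start, start + shard_size)

# A (1024, 512) weight with Shard(0) over 4 ranks: the rank at mesh
# coordinate 2 owns rows 512..767.
print(shard_slice(1024, 4, 2))  # slice(512, 768, None)
```
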
build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py DELETED
@@ -1,128 +0,0 @@
1
- # MIT License
2
- #
3
- # Copyright (c) 2025 Tianyang Lin
4
- #
5
- # Permission is hereby granted, free of charge, to any person obtaining a copy
6
- # of this software and associated documentation files (the "Software"), to deal
7
- # in the Software without restriction, including without limitation the rights
8
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- # copies of the Software, and to permit persons to whom the Software is
10
- # furnished to do so, subject to the following conditions:
11
- #
12
- # The above copyright notice and this permission notice shall be included in all
13
- # copies or substantial portions of the Software.
14
- #
15
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- # SOFTWARE.
22
-
23
- import torch
24
- import triton
25
- import triton.language as tl
26
-
27
-
28
- def get_autotune_config():
29
- return [
30
- triton.Config(
31
- {
32
- 'BLOCK_SIZE_M': blk_m,
33
- 'BLOCK_SIZE_K': blk_k,
34
- 'GROUP_SIZE_M': grp_sz
35
- },
36
- num_stages=n_stages,
37
- num_warps=n_warps) for blk_m in [32, 64, 128]
38
- for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
- for n_warps in [4, 8]
40
- ]
41
-
42
-
43
- @triton.autotune(
44
- configs=get_autotune_config(),
45
- key=['M', 'K'],
46
- )
47
- @triton.jit
48
- def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
49
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
50
- GROUP_SIZE_M: tl.constexpr):
51
- """
52
- Core kernel jit function of matmul_transpose that computes y = x @ x.T
53
- The code is a simple adaptation from the triton `matmul` tutorial:
54
- https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
55
- """
56
- pid = tl.program_id(axis=0)
57
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
58
- num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
59
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
60
- group_id = pid // num_pid_in_group
61
- first_pid_m = group_id * GROUP_SIZE_M
62
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
63
- pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
64
- pid_n = (pid % num_pid_in_group) // group_size_m
65
- if pid_m > pid_n:
66
- return
67
-
68
- offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
69
- offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
- offs_k = tl.arange(0, BLOCK_SIZE_K)
71
- # we use a & b ptrs to denote different rows of x.
72
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
73
- b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
-
75
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
76
-
77
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
78
- a = tl.load(a_ptrs,
79
- mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
80
- other=0.0)
81
- b = tl.load(b_ptrs,
82
- mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
83
- other=0.0)
84
- accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
85
- a_ptrs += BLOCK_SIZE_K * stride_xk
86
- b_ptrs += BLOCK_SIZE_K * stride_xk
87
- # use dtype.element_ty to accommodate different input datatypes as in cpp templates
88
- # https://github.com/triton-lang/triton/issues/2252
89
- c = accumulator.to(x.dtype.element_ty)
90
-
91
- offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
92
- offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
- c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
94
- c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
95
- tl.store(c_ptrs, c, mask=c_mask)
96
-
97
- # transpose and copy
98
- if pid_m < pid_n:
99
- ct_ptrs = y + stride_ym * offs_cn[:,
100
- None] + stride_yn * offs_cm[None, :]
101
- ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
102
- tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
-
104
-
105
- def matmul_transpose_assign(d_in, d_out):
106
- assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
- assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
- assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
- assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
- assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
- assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
- assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
113
- "First dimension of `d_in` must match first and second dimension of `d_out`"
114
-
115
- d_in = d_in.contiguous()
116
- M, K = d_in.shape
117
- grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
118
- M, META['BLOCK_SIZE_M']), )
119
- with torch.cuda.device(d_in.device.index):
120
- mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
- d_out.stride(0), d_out.stride(1))
122
-
123
-
124
- def matmul_transpose(d_in):
125
- M, _ = d_in.shape
126
- d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
- matmul_transpose_assign(d_in, d_out)
128
- return d_out
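
A quick sanity check for the kernel above, assuming a CUDA device with Triton available (the tolerances are illustrative; the fp32-accumulated result is cast back to the input dtype, so it differs slightly from cuBLAS):

```python
import torch

x = torch.randn(256, 512, device="cuda", dtype=torch.float16)
y = matmul_transpose(x)  # y = x @ x.T via the autotuned Triton kernel
torch.testing.assert_close(y, x @ x.T, rtol=2e-2, atol=2e-2)
```
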
build/torch210-cxx11-rocm70-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
1
- {"python-depends":[]}
build/torch210-cxx11-rocm70-x86_64-linux/muon.py DELETED
@@ -1,1268 +0,0 @@
1
- import logging
2
- import math
3
- import types
4
- from collections import defaultdict
5
- from dataclasses import dataclass
6
- from typing import Any, cast
7
-
8
- import torch
9
- import torch.distributed as dist
10
- from torch.distributed import ProcessGroup
11
- from torch.distributed.device_mesh import DeviceMesh
12
- from torch.distributed.tensor import DTensor, Replicate
13
- from torch.distributed.tensor.placement_types import Placement
14
-
15
- from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
- from .matmul_transpose_triton import matmul_transpose_assign
17
-
18
- logger = logging.getLogger(__name__)
19
-
20
- COMM_DTYPE = torch.bfloat16
21
- DEFAULT_CHUNK_SIZE_RATIO = 4
22
-
23
-
24
- # This code snippet is a modified version adapted from the following GitHub repositories:
25
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
26
- # Muon's Newton–Schulz iteration causes high variance in singular values
27
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
28
- @torch.no_grad()
29
- # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
30
- def _zeropower_via_newtonschulz5(G, steps):
31
- """
32
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
33
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
34
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
35
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
36
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
37
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
38
- performance at all relative to UV^T, where USV^T = G is the SVD.
39
- """
40
- assert len(G.shape) == 2
41
- assert G.dtype == COMM_DTYPE
42
- X = G # no manual typecast
43
-
44
- if G.size(0) > G.size(1):
45
- X = X.T
46
- # Ensure spectral norm is at most 1
47
- X = X / (X.norm() + 1e-7)
48
- buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
49
- buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
50
- # Perform the NS iterations
51
- for a, b, c in [
52
- (4.0848, -6.8946, 2.9270),
53
- (3.9505, -6.3029, 2.6377),
54
- (3.7418, -5.5913, 2.3037),
55
- (2.8769, -3.1427, 1.2046),
56
- (2.8366, -3.0525, 1.2012),
57
- ]:
58
- matmul_transpose_assign(X, buf1)
59
- matmul_transpose_assign(buf1, buf2)
60
- buf1.mul_(b).add_(buf2, alpha=c)
61
- X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
62
-
63
- if G.size(0) > G.size(1):
64
- X = X.T
65
- return X
66
-
67
-
68
- @dataclass
69
- class _muon_state:
70
- # TODO: use Optional
71
- worker_rank: int
72
- process_group: ProcessGroup
73
- shard_mesh: DeviceMesh
74
- shard_placements: tuple[Placement, ...]
75
- name: str
76
- qk_clip_state: torch.Tensor | None = None
77
- gathered_grad: torch.Tensor | None = None
78
- scattered_u: DTensor | None = None
79
- computed_u: torch.Tensor | None = None
80
- gather_event: torch.cuda.Event | None = None
81
- compute_event: torch.cuda.Event | None = None
82
- scatter_event: torch.cuda.Event | None = None
83
-
84
-
85
- def numel_for_rank(
86
- param: DTensor,
87
- local_rank: int,
88
- state: _muon_state,
89
- ) -> int:
90
- slices = get_slices_of_dtensor(
91
- param,
92
- local_rank,
93
- state.shard_mesh,
94
- state.shard_placements,
95
- )
96
-
97
- numel = 1
98
- for s, dim in zip(slices, param.shape):
99
- start, stop, step = s.indices(dim)
100
- length = max(0, (stop - start + (step - 1)) // step)
101
- numel *= length
102
-
103
- return numel
104
-
105
-
106
- @torch.no_grad()
107
- def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
108
- """
109
- Pre-allocate gathered_grad buffer on compute_stream
110
- before launching all2all gather
111
- """
112
- with torch.cuda.stream(compute_stream):
113
- for p in params:
114
- state = param_to_state[id(p)]
115
- if rank == state.worker_rank:
116
- state.gathered_grad = torch.empty(p.shape,
117
- dtype=COMM_DTYPE,
118
- device="cuda")
119
- else:
120
- state.gathered_grad = None
121
-
122
- alloc_event = torch.cuda.Event()
123
- alloc_event.record(compute_stream)
124
- return alloc_event
125
-
126
-
127
- @torch.no_grad()
128
- def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
129
- alloc_event):
130
- """
131
- All2all gathers shards so each owner rank reconstructs its full gradient
132
- """
133
- with torch.cuda.stream(comm_stream):
134
- process_group = param_to_state[id(params[0])].process_group
135
- num_ranks = dist.get_world_size(group=process_group)
136
-
137
- # Construct sending buffers
138
- per_dst = [[] for _ in range(num_ranks)]
139
- send_counts = [0] * num_ranks
140
-
141
- for p in params:
142
- state = param_to_state[id(p)]
143
- dst = state.worker_rank
144
- assert dst < num_ranks
145
- shard_elems = numel_for_rank(p, rank, state)
146
- g = p.grad
147
- g = g.to_local().to(COMM_DTYPE).contiguous()
148
- assert g.numel() == shard_elems
149
- per_dst[dst].append(g.view(-1))
150
- send_counts[dst] += shard_elems
151
-
152
- assert any(
153
- len(v) > 0 for v in per_dst
154
- ), "At least one destination rank must receive a sharded tensor"
155
- # list[list[Tensor]] -> list[Tensor]
156
- per_dst = [t for dst in per_dst for t in dst]
157
-
158
- send_buf = torch.cat(per_dst, dim=0)
159
-
160
- owned_params = [
161
- p for p in params if param_to_state[id(p)].worker_rank == rank
162
- ]
163
-
164
- # Compute receive sizes and allocate receiving buffers
165
- recv_counts = [0] * num_ranks
166
-
167
- for src in range(num_ranks):
168
- total = 0
169
- for p in owned_params:
170
- state = param_to_state[id(p)]
171
- assert state.worker_rank == rank
172
- total += numel_for_rank(p, src, state)
173
- recv_counts[src] = total
174
-
175
- recv_total = sum(recv_counts)
176
- recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
-
178
- #All2All
179
- logger.debug(f"send_buf size: {send_buf.numel()}, "
180
- f"recv_buf size: {recv_buf.numel()}, "
181
- f"recv_counts: {recv_counts}, "
182
- f"send_counts: {send_counts}, "
183
- f"process_group: {str(process_group)}")
184
- dist.all_to_all_single(
185
- recv_buf,
186
- send_buf,
187
- output_split_sizes=recv_counts,
188
- input_split_sizes=send_counts,
189
- group=process_group,
190
- )
191
-
192
- # Reconstructs gathered grad from the received buffer
193
- #
194
- # recv_buf (num ranks = 3)
195
- #
196
- # From rank 0 From rank 1 From rank 2
197
- # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
198
- #
199
- # Outer loop:
200
- # rank 0 -> rank 1 -> rank 2
201
- #
202
- # Inner loop:
203
- # p1_n -> p2_n -> p3_n
204
-
205
- comm_stream.wait_event(alloc_event)
206
-
207
- off = 0
208
- for src in range(num_ranks):
209
- if recv_counts[src] == 0:
210
- continue
211
-
212
- block = recv_counts[src]
213
- inner_off = 0
214
- for p in owned_params:
215
- state = param_to_state[id(p)]
216
- assert state.worker_rank == rank
217
-
218
- # get the slice of the full dtensor corresponding to rank src.
219
- slices = get_slices_of_dtensor(state.gathered_grad, src,
220
- state.shard_mesh,
221
- state.shard_placements)
222
-
223
- dst = state.gathered_grad[slices]
224
- assert dst._base is state.gathered_grad
225
-
226
- n = dst.numel()
227
- assert n > 0
228
-
229
- sg = recv_buf.narrow(0, off + inner_off, n)
230
- sg = sg.reshape_as(dst)
231
- dst.copy_(sg)
232
-
233
- inner_off += n
234
- off += block
235
-
236
- for p in params:
237
- state = param_to_state[id(p)]
238
- if state.worker_rank == rank:
239
- state.gather_event = torch.cuda.Event()
240
- state.gather_event.record(comm_stream)
241
- else:
242
- state.gathered_grad = None
243
- state.gather_event = None
244
- if none_grad:
245
- p.grad = None
246
-
247
-
248
- @torch.no_grad()
249
- def _compute_u(p, state, steps, rank, compute_stream):
250
- """
251
- On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
252
- """
253
- with torch.cuda.stream(compute_stream):
254
- if rank == state.worker_rank:
255
- if state.gather_event is None:
256
- raise RuntimeError("Gather event must be set before compute.")
257
- compute_stream.wait_event(state.gather_event)
258
- u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
259
- state.gathered_grad = None
260
- state.computed_u = u
261
- state.compute_event = torch.cuda.Event()
262
- state.compute_event.record()
263
- else:
264
- state.computed_u = None
265
- state.compute_event = None
266
-
267
-
268
- @torch.no_grad()
269
- def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
270
- """
271
- Pre-allocate scattered_u buffer on compute_stream
272
- before launching all2all gather
273
- """
274
- with torch.cuda.stream(compute_stream):
275
- for p in params:
276
- state = param_to_state[id(p)]
277
- state.scattered_u = torch.empty_like(p.to_local(),
278
- dtype=COMM_DTYPE)
279
-
280
- alloc_event = torch.cuda.Event()
281
- alloc_event.record(compute_stream)
282
- return alloc_event
283
-
284
-
285
- def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
286
- """
287
- All2all scatters the computed orthogonalized updates back to all ranks
288
- """
289
- with torch.cuda.stream(comm_stream):
290
- process_group = param_to_state[id(params[0])].process_group
291
- num_ranks = dist.get_world_size(group=process_group)
292
- owned_params = [
293
- p for p in params if param_to_state[id(p)].worker_rank == rank
294
- ]
295
-
296
- # Construct sending buffer
297
- per_dst = [[] for _ in range(num_ranks)]
298
- send_counts = [0] * num_ranks
299
-
300
- if owned_params:
301
- for p in owned_params:
302
- state = param_to_state[id(p)]
303
- if state.compute_event is None:
304
- raise RuntimeError(
305
- "Compute event must be set before scatter.")
306
- comm_stream.wait_event(state.compute_event)
307
- state.gathered_grad = None
308
-
309
- assert state.computed_u is not None
310
-
311
- u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
-
313
- offset = 0
314
- for dst in range(num_ranks):
315
- # get the slice of the full tensor corresponding to rank dst.
316
- slices = get_slices_of_dtensor(u_full, dst,
317
- state.shard_mesh,
318
- state.shard_placements)
319
- su = u_full[slices].flatten()
320
-
321
- n = su.numel()
322
- assert n > 0
323
-
324
- per_dst[dst].append(su)
325
- send_counts[dst] += n
326
- offset += n
327
-
328
- assert offset == u_full.numel()
329
-
330
- lengths = [len(v) for v in per_dst]
331
- if all(l > 0 for l in lengths):
332
- assert all(
333
- l == lengths[0] for l in lengths
334
- ), "All destination ranks must have the same number of sharded tensor"
335
- # list[list[Tensor]] -> list[Tensor]
336
- per_dst = [t for dst in per_dst for t in dst]
337
- send_buf = torch.cat(per_dst, dim=0)
338
- else:
339
- # all_to_all requires participation from all ranks
340
- # Even non-owner ranks must join the collective call
341
- send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
342
-
343
- # Compute receive sizes and allocate receiving buffers
344
- recv_counts = [0] * num_ranks
345
-
346
- for src in range(num_ranks):
347
- total = 0
348
- for p in params:
349
- state = param_to_state[id(p)]
350
- if state.worker_rank != src:
351
- continue
352
- total += numel_for_rank(p, rank, state)
353
- recv_counts[src] = total
354
-
355
- recv_total = sum(recv_counts)
356
- assert recv_total > 0
357
- recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
358
-
359
- #All2All
360
- dist.all_to_all_single(
361
- recv_buf,
362
- send_buf,
363
- output_split_sizes=recv_counts,
364
- input_split_sizes=send_counts,
365
- group=process_group,
366
- )
367
-
368
- # Copy to pre-allocated scattered_u buffer from the received buffer
369
- #
370
- # recv_buf (num ranks = 3, local_rank = 0)
371
- #
372
- # From rank 0 From rank 1 From rank 2
373
- # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
374
- #
375
- # Outer loop:
376
- # rank 0 -> rank 1 -> rank 2
377
- #
378
- # Inner loop:
379
- # src(0) : p1_0 -> p2_0 -> p3_0
380
- # src(1) : p4_0
381
- # src(2) : p5_0 -> p6_0
382
-
383
- comm_stream.wait_event(alloc_event)
384
-
385
- off = 0
386
- for src in range(num_ranks):
387
- block = recv_counts[src]
388
- if block == 0:
389
- continue
390
-
391
- inner_off = 0
392
- for p in params:
393
- state = param_to_state[id(p)]
394
- if state.worker_rank != src:
395
- continue
396
- n = numel_for_rank(p, rank, state)
397
- assert n > 0
398
-
399
- flat_local = recv_buf.narrow(0, off + inner_off,
400
- n).view_as(p.to_local())
401
- state.scattered_u.copy_(flat_local)
402
-
403
- state.scatter_event = torch.cuda.Event()
404
- state.scatter_event.record(comm_stream)
405
- inner_off += n
406
-
407
- assert inner_off == block
408
- off += block
409
-
410
-
411
- def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
412
- compute_stream):
413
- """
414
- Update sharded parameter p with the scattered_u.
415
- Only worker_rank frees computed_u.
416
- """
417
- with torch.cuda.stream(compute_stream):
418
- if state.scatter_event is None:
419
- raise RuntimeError("Scatter event must be set before update")
420
- compute_stream.wait_event(state.scatter_event)
421
- u_dtensor = DTensor.from_local(
422
- state.scattered_u,
423
- placements=p.placements,
424
- device_mesh=p.device_mesh,
425
- )
426
-
427
- state.scattered_u = u_dtensor
428
-
429
- if rank == state.worker_rank:
430
- # Free computed_u
431
- state.computed_u = None
432
-
433
- Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
434
- state.scattered_u = None
435
- u_dtensor = None
436
-
437
- scales_full = Muon._compute_scales(
438
- p,
439
- state.qk_clip_state) if state.qk_clip_state is not None else None
440
- if scales_full is not None:
441
- # Have to slice scales_full among dim 0
442
- weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
- state.shard_placements)
444
- ratio = p.shape[0] // scales_full.shape[0]
445
- scales_slice = slice(
446
- None if weight_slices[0].start is None else
447
- weight_slices[0].start // ratio,
448
- None if weight_slices[0].stop is None else
449
- weight_slices[0].stop // ratio,
450
- None,
451
- )
452
-
453
- scales_local = scales_full[scales_slice]
454
- scales_local = DTensor.from_local(
455
- scales_local,
456
- placements=p.placements,
457
- device_mesh=p.device_mesh,
458
- )
459
- Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
460
-
461
-
462
- def default_is_muon(name, x):
463
- skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
464
- return x.ndim >= 2 and not any(key in name for key in skip_keys)
465
-
466
-
467
- def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
468
- muon_params, muon_names = [], []
469
- non_muon_params = []
470
-
471
- for n, p in model.named_parameters():
472
- if not p.requires_grad:
473
- continue
474
- if is_muon_func(n, p):
475
- muon_params.append(p)
476
- muon_names.append(n)
477
- else:
478
- non_muon_params.append(p)
479
-
480
- return [
481
- {
482
- "params": muon_params,
483
- "names": muon_names,
484
- "use_muon": True,
485
- },
486
- {
487
- "params": non_muon_params,
488
- "use_muon": False,
489
- },
490
- ]
491
-
492
-
493
- def parse_qk_layer(name: str) -> tuple[str | None, int]:
494
- """
495
- Parse a parameter name to check if it is a query/key projection layer
496
- ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
497
-
498
- Returns:
499
- (kind, layer_idx) or (None, -1) if not matched.
500
-
501
- Example:
502
- 'model.3.attn.wq.weight' -> ('wq', 3)
503
- 'model.5.attn.wk.weight' -> ('wk', 5)
504
- 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
505
- 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
506
- 'model.4.attn.v_proj.weight' -> (None, -1)
507
- """
508
- parts = name.split('.')
509
- if len(parts) < 3:
510
- return None, -1
511
-
512
- kind = parts[-2]
513
-
514
- layer_idx = -1
515
- for part in reversed(parts):
516
- if part.isdigit():
517
- layer_idx = int(part)
518
- break
519
-
520
- if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
521
- return kind, layer_idx
522
-
523
- return None, -1
524
-
525
-
526
- @dataclass
527
- class QKClipInfo:
528
- """Per-parameter dynamic info computed from config + runtime logits."""
529
- kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
- indices: list[int] # which heads to consider for clipping
531
- head_dim: int # from config
532
- threshold: float # from config
533
- logit: torch.Tensor | None
534
-
535
-
536
- class Muon(torch.optim.Optimizer):
537
- """
538
- Muon - MomentUm Orthogonalized by Newton-schulz
539
-
540
- Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
541
- processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
542
- matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
543
- the advantage that it can be stably run in bfloat16 on the GPU.
544
-
545
- Some warnings:
546
- - We believe this optimizer is unlikely to work well for training with small batch size.
547
- - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
548
-
549
- Arguments:
550
- model: The model to be optimized by Muon.
551
- is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
552
- lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
553
- momentum: The momentum used by the internal SGD. (0.95 is a good default)
554
- nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
555
- ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
556
- weight_decay: The weight decay for Muon and AdamW.
557
- Parameters that are {0, 1}-D or detected as the embed or lm_head will be optimized by AdamW as well.
558
- adamw_lr: The learning rate for the internal AdamW.
559
- adamw_betas: The betas for the internal AdamW.
560
- adamw_eps: The epsilon for the internal AdamW.
561
- none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
562
- debug: Whether to print debug information.
563
- clip_info : Configuration for QK clipping. Expected keys:
564
- - "q_indices" (list[int]): Indices of query heads to consider.
565
- - "k_indices" (list[int]): Indices of key heads to consider.
566
- - "head_dim" (int): Dimensionality of each attention head.
567
- - "threshold" (float): Threshold value; heads whose QK logits exceed
568
- this value will be scaled down.
569
- Default is:
570
- {
571
- "q_indices": [],
572
- "k_indices": [],
573
- "head_dim": 128,
574
- "threshold": 100
575
- }
576
- warmup_step : How many all2all gather and compute operations are launched in advance
577
- before the corresponding all2all scatter steps begin.
578
- A higher warmup_step increases memory usage but can improve
579
- performance by overlapping communication.
580
- Parallel muon only.
581
- chunk_size : Batch size of parameters to process in each
582
- all2all gather/compute/scatter step.
583
- Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
- use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
- For testing purpose only.
586
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
587
- """
588
-
589
- def __init__(self,
590
- params,
591
- lr=1e-3,
592
- momentum=0.95,
593
- nesterov=True,
594
- ns_steps=5,
595
- weight_decay=0.1,
596
- adamw_betas=(0.9, 0.95),
597
- adamw_eps=1e-8,
598
- none_grad=True,
599
- debug=False,
600
- clip_config={
601
- "q_indices": [],
602
- "k_indices": [],
603
- "head_dim": 128,
604
- "threshold": 100
605
- },
606
- warmup_step=5,
607
- chunk_size=-1,
608
- use_distributed_muon=False,
609
- small_param_numel_threshold=65536):
610
- defaults = dict(
611
- lr=lr,
612
- weight_decay=weight_decay,
613
- momentum=momentum,
614
- nesterov=nesterov,
615
- ns_steps=ns_steps,
616
- adamw_betas=adamw_betas,
617
- adamw_eps=adamw_eps,
618
- none_grad=none_grad,
619
- use_muon=True,
620
- )
621
- error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
622
- instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
623
-
624
- if isinstance(params, types.GeneratorType):
625
- raise ValueError(error_message.format(idx=0) + instruction_code)
626
- for _idx, param_group in enumerate(params):
627
- if param_group.get("use_muon", None) is None:
628
- raise ValueError(
629
- error_message.format(idx=_idx) + instruction_code)
630
-
631
- super().__init__(params, defaults)
632
-
633
- self.rank = None
634
-
635
- self.comm_stream = torch.cuda.Stream()
636
- self.compute_stream = torch.cuda.Stream()
637
- self.debug = debug
638
- self.clip_config = clip_config
639
- self.warmup_step = warmup_step
640
- self.chunk_size = chunk_size
641
- self.use_distributed_muon = use_distributed_muon
642
- self.small_param_numel_threshold = small_param_numel_threshold
643
-
644
- def _calc_flops(self, G, steps):
645
- assert len(G.shape) == 2
646
- M, N = G.shape
647
- if M > N:
648
- M, N = N, M
649
-
650
- return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
651
-
652
- def adjust_lr_for_muon(self, lr, param_shape):
653
- A, B = param_shape[:2]
654
- # We adjust the learning rate and weight decay based on the size of the parameter matrix
655
- # as described in the paper
656
- adjusted_ratio = 0.2 * math.sqrt(max(A, B))
657
- adjusted_lr = lr * adjusted_ratio
658
- return adjusted_lr
659
-
660
- def set_rank_once(self, rank):
661
- if self.rank is None:
662
- self.rank = rank
663
- else:
664
- assert self.rank == rank
665
-
666
- def get_shard_mesh(self, p):
667
- """
668
- Get the shard mesh for a parameter p on the given rank.
669
- """
670
- assert isinstance(
671
- p, DTensor), "Parallel Muon only supports DTensor parameters."
672
-
673
- shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
674
- p.placements, p.device_mesh)
675
-
676
- # set rank with the local rank in the shard process group
677
- self.set_rank_once(dist.get_rank(group=shard_pg))
678
-
679
- return shard_mesh, shard_pg, shard_placements
680
-
681
- def init_state_and_assign_params(self, names, params, group, qk_logits):
682
- param_to_state = {}
683
- param_to_flops = {}
684
-
685
- total_flops = 0
686
- for p in params:
687
- g = p.grad
688
- if g is None:
689
- continue
690
- assert g.ndim == 2, "Muon only supports 2D parameters."
691
-
692
- flops = self._calc_flops(g, group["ns_steps"])
693
- param_to_flops[id(p)] = flops
694
- total_flops += flops
695
-
696
- if self.debug:
697
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
698
- flush=True)
699
-
700
- paired = list(zip(names, params))
701
-
702
- paired_sorted = sorted(paired,
703
- key=lambda x: param_to_flops[id(x[1])],
704
- reverse=True)
705
-
706
- names_sorted, params_sorted = zip(*paired_sorted)
707
- ordered_names = list(names_sorted)
708
- ordered_params = list(params_sorted)
709
-
710
- round_robin = 0
711
- mesh = ordered_params[0].device_mesh
712
- placements = ordered_params[0].placements
713
-
714
- shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
715
- ordered_params[0])
716
- shard_mesh_flattened = shard_mesh.mesh.flatten()
717
- num_ranks = dist.get_world_size(group=shard_pg)
718
-
719
- for n, p in zip(ordered_names, ordered_params):
720
- if mesh != p.device_mesh:
721
- raise ValueError("All parameters must be on the same mesh.")
722
- if placements != p.placements:
723
- raise ValueError("All parameters must have same placements.")
724
-
725
- worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
726
- round_robin = (round_robin + 1) % len(shard_mesh_flattened)
727
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
728
-
729
- param_to_state[id(p)] = _muon_state(
730
- worker_rank=worker_rank,
731
- process_group=shard_pg,
732
- shard_mesh=shard_mesh,
733
- shard_placements=shard_placements,
734
- name=n,
735
- qk_clip_state=qk_clip_state,
736
- )
737
-
738
- return param_to_state, ordered_params
739
-
740
- def base(self, names, params, group, lr, weight_decay, momentum,
741
- qk_logits):
742
- # generate weight updates locally (non-distributed path)
743
- for n, p in zip(names, params):
744
- g = p.grad
745
- if g is None:
746
- continue
747
- if g.ndim > 2:
748
- g = g.view(g.size(0), -1)
749
- assert g is not None
750
-
751
- g = self._update_g(p, g, group, momentum)
752
-
753
- u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
754
- steps=group["ns_steps"])
755
-
756
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
757
- Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
758
-
759
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
760
-
761
- scales_full = self._compute_scales(
762
- p, qk_clip_state) if qk_clip_state is not None else None
763
- if scales_full is not None:
764
- Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
765
-
766
- def distributed_muon(
767
- self,
768
- names: list[str],
769
- params: list[torch.nn.Parameter],
770
- group: dict[str, Any],
771
- lr: float,
772
- weight_decay: float,
773
- momentum: float,
774
- qk_logits: list[torch.Tensor | DTensor] | None,
775
- ):
776
- """ Implementation of Distributed Muon by Liu et al. """
777
-
778
- for n, p in zip(names, params):
779
- g = p.grad
780
- if g is None:
781
- continue
782
- if g.ndim > 2:
783
- g = g.view(g.size(0), -1)
784
- assert g is not None
785
-
786
- g = self._update_g(p, g, group, momentum)
787
-
788
- # Gather G
789
- if isinstance(p.data, DTensor):
790
- g_full = g.full_tensor()
791
- p_full = p.data.full_tensor()
792
- else:
793
- g_full = g
794
- p_full = p
795
-
796
- u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
797
- steps=group["ns_steps"])
798
-
799
- adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
800
- Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
801
-
802
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
803
-
804
- scales_full = self._compute_scales(
805
- p_full, qk_clip_state) if qk_clip_state is not None else None
806
-
807
- if scales_full is not None:
808
- Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
809
-
810
- if isinstance(p.data, DTensor):
811
- ndims = len(p.device_mesh.mesh.shape)
812
- p_replicate = DTensor.from_local(
813
- p_full,
814
- device_mesh=p.device_mesh,
815
- placements=[Replicate() for _ in range(ndims)],
816
- )
817
-
818
- p_sharded = p_replicate.redistribute(
819
- device_mesh=p.device_mesh,
820
- placements=p.placements,
821
- )
822
-
823
- p.copy_(p_sharded)
824
-
825
- def _update_g(self, p, g, group, momentum):
826
- # calc update
827
- state = self.state[p]
828
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
829
- torch.add(g, buf, alpha=momentum, out=buf)
830
- if group["nesterov"]:
831
- g.add_(buf, alpha=momentum)
832
- return g
833
- return buf
834
-
835
- @staticmethod
836
- def _update_p(p, u, lr, adjusted_lr, weight_decay):
837
- if isinstance(p, torch.nn.Parameter):
838
- # apply weight decay
839
- p.data.mul_(1 - lr * weight_decay)
840
- # apply update
841
- p.data.add_(u, alpha=-adjusted_lr)
842
- else:
843
- p.mul_(1 - lr * weight_decay)
844
- p.add_(u, alpha=-adjusted_lr)
845
-
846
- def get_qk_clip_info(self, n, qk_logits):
847
- if self.clip_config is None:
848
- return None
849
-
850
- head_dim = self.clip_config.get('head_dim')
851
- threshold = self.clip_config.get('threshold')
852
- kind, layer_idx = parse_qk_layer(n)
853
-
854
- logit, indices = None, []
855
- if qk_logits is not None and kind is not None:
856
- logit = qk_logits[layer_idx]
857
- indices_key = 'q_indices' if 'q' in kind else 'k_indices'
858
- indices = self.clip_config.get(indices_key, []) or []
859
-
860
- if isinstance(logit, DTensor):
861
- # In TP settings, qk_logits may be DTensor
862
- # We convert it to full tensor here for simplicity
863
- logit = logit.full_tensor()
864
-
865
- return QKClipInfo(
866
- kind=kind,
867
- indices=indices,
868
- head_dim=head_dim,
869
- threshold=threshold,
870
- logit=logit,
871
- )
872
-
873
- @staticmethod
874
- def _compute_scales(p, qk_clip_state):
875
- kind = qk_clip_state.kind
876
- indices = qk_clip_state.indices
877
- head_dim = qk_clip_state.head_dim
878
- threshold = qk_clip_state.threshold
879
- logit = qk_clip_state.logit
880
-
881
- H_global = p.shape[0] // head_dim
882
- scales_full = torch.ones(H_global, device=p.data.device)
883
- scaling = 0
884
-
885
- for logit_idx, head_idx in enumerate(indices):
886
- v_ele = float(logit[logit_idx])
887
- if v_ele > threshold:
888
- new_scale = math.sqrt(threshold / v_ele)
889
- if new_scale < scales_full[head_idx]:
890
- scales_full[head_idx] = new_scale
891
- logger.info(
892
- f"[{kind}] Head {head_idx} exceeded threshold "
893
- f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
894
- )
895
- scaling += 1
896
-
897
- return scales_full if scaling > 0 else None
898
-
899
- @staticmethod
900
- def _qk_clip(p, scales, head_dim):
901
- if isinstance(p, torch.nn.Parameter):
902
- W = p.data.view(-1, head_dim, p.data.shape[1])
903
- W.mul_(scales.view(-1, 1, 1))
904
- else:
905
- W = p.view(-1, head_dim, p.shape[1])
906
- W.mul_(scales.view(-1, 1, 1))
907
-
908
- def parallel(self, names, params, group, lr, weight_decay, momentum,
909
- qk_logits):
910
- """
911
- Perform a parallel optimization step using Muon.
912
- """
913
-
914
- for p in params:
915
- g = p.grad
916
- if g is None:
917
- continue
918
- if g.ndim > 2:
919
- g = g.view(g.size(0), -1)
920
-
921
- # Update g in the local rank
922
- g = self._update_g(
923
- p,
924
- g,
925
- group,
926
- momentum=momentum,
927
- )
928
- p.grad = g
929
-
930
- param_to_state, ordered_params = self.init_state_and_assign_params(
931
- names, params, group, qk_logits)
932
-
933
- assert self.rank is not None
934
-
935
- def enqueue_all2all_gather(start_idx, chunk_size):
936
- target_params = ordered_params[start_idx:start_idx + chunk_size]
937
- if target_params:
938
- alloc_event = _alloc_gathered_grad(target_params,
939
- param_to_state, self.rank,
940
- self.compute_stream)
941
- _all2all_gather(target_params, param_to_state, self.rank,
942
- self.comm_stream, group["none_grad"],
943
- alloc_event)
944
-
945
- def enqueue_computes(start_idx, chunk_size):
946
- for p in ordered_params[start_idx:start_idx + chunk_size]:
947
- state = param_to_state[id(p)]
948
- _compute_u(p, state, group["ns_steps"], self.rank,
949
- self.compute_stream)
950
-
951
- def enqueue_all2all_scatter(start_idx, chunk_size):
952
- target_params = ordered_params[start_idx:start_idx + chunk_size]
953
- if target_params:
954
- alloc_event = _alloc_scattered_u(target_params, param_to_state,
955
- self.rank,
956
- self.compute_stream)
957
- _all2all_scatter(target_params, param_to_state, self.rank,
958
- self.comm_stream, alloc_event)
959
-
960
- def enqueue_update_param(start_idx, chunk_size):
961
- for p in ordered_params[start_idx:start_idx + chunk_size]:
962
- state = param_to_state[id(p)]
963
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
964
- _update_param(p, state, lr, adjusted_lr, weight_decay,
965
- self.rank, self.compute_stream)
966
-
967
- if self.chunk_size == -1:
968
- shard_ranks = dist.get_world_size(param_to_state[id(
969
- params[0])].process_group)
970
- chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
971
- elif self.chunk_size > 0:
972
- chunk_size = self.chunk_size
973
- else:
974
- raise ValueError("chunk_size must be -1 or a positive integer.")
975
-
976
- # Wait for gradient updates on the current stream
977
- self.comm_stream.wait_stream(torch.cuda.current_stream())
978
-
979
- warmup_step = self.warmup_step
980
- for i in range(0, warmup_step):
981
- enqueue_all2all_gather(i * chunk_size, chunk_size)
982
- enqueue_computes(i * chunk_size, chunk_size)
983
-
984
- for i in range(0, len(params) + chunk_size - 1, chunk_size):
985
- enqueue_all2all_scatter(i, chunk_size)
986
- enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
987
- enqueue_update_param(i, chunk_size)
988
- enqueue_computes(i + warmup_step * chunk_size, chunk_size)
989
-
990
- # Wait for the last update_param to finish
991
- torch.cuda.current_stream().wait_stream(self.compute_stream)
992
-
993
- @staticmethod
994
- def _fused_adamw(
995
- params: list[torch.Tensor],
996
- grads: list[torch.Tensor],
997
- exp_avgs: list[torch.Tensor],
998
- exp_avg_sqs: list[torch.Tensor],
999
- max_exp_avg_sqs: list[torch.Tensor],
1000
- state_steps: list[torch.Tensor],
1001
- amsgrad: bool,
1002
- beta1: float,
1003
- beta2: float,
1004
- lr: float | torch.Tensor,
1005
- weight_decay: float,
1006
- eps: float,
1007
- maximize: bool,
1008
- ) -> None:
1009
- if not params:
1010
- return
1011
-
1012
- # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1013
- # treating it as a scalar.
1014
- lr_dict: DeviceDict | None = ({
1015
- lr.device: lr
1016
- } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1017
- None)
1018
- grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1019
- [
1020
- params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
1021
- state_steps
1022
- ] # type: ignore[list-item]
1023
- )
1024
- for (device, _), (
1025
- (
1026
- device_params_,
1027
- device_grads_,
1028
- device_exp_avgs_,
1029
- device_exp_avg_sqs_,
1030
- device_max_exp_avg_sqs,
1031
- device_state_steps_,
1032
- ),
1033
- _,
1034
- ) in grouped_tensors.items():
1035
- device_params = cast(list[torch.Tensor], device_params_)
1036
- device_grads = cast(list[torch.Tensor], device_grads_)
1037
- device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
1038
- device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
1039
- device_state_steps = cast(list[torch.Tensor], device_state_steps_)
1040
-
1041
- if lr_dict is not None and device not in lr_dict:
1042
- lr_dict[device] = lr.to(
1043
- device=device,
1044
- non_blocking=True) # type: ignore[union-attr]
1045
- lr = lr_dict[device]
1046
- torch._foreach_add_(device_state_steps, 1)
1047
- func = torch._fused_adamw_
1048
- func(
1049
- device_params,
1050
- device_grads,
1051
- device_exp_avgs,
1052
- device_exp_avg_sqs,
1053
- device_max_exp_avg_sqs, # type: ignore[arg-type]
1054
- device_state_steps,
1055
- amsgrad=amsgrad,
1056
- lr=lr, # type: ignore[arg-type]
1057
- beta1=beta1,
1058
- beta2=beta2,
1059
- weight_decay=weight_decay,
1060
- eps=eps,
1061
- maximize=maximize,
1062
- )
1063
-
1064
- def _step_muon(self, group, qk_logits=None):
1065
- params = group["params"]
1066
- lr = group["lr"]
1067
- weight_decay = group["weight_decay"]
1068
- momentum = group["momentum"]
1069
- names = group["names"]
1070
-
1071
- param_dtensors = []
1072
- name_dtensors = []
1073
-
1074
- param_tensors = []
1075
- name_tensors = []
1076
-
1077
- param_dtensors_small = []
1078
- name_dtensors_small = []
1079
-
1080
- if self.use_distributed_muon:
1081
- self.distributed_muon(names=names,
1082
- params=params,
1083
- group=group,
1084
- lr=lr,
1085
- weight_decay=weight_decay,
1086
- momentum=momentum,
1087
- qk_logits=qk_logits)
1088
- return
1089
-
1090
- # For simplicity, we use distributed Muon for small parameters
1091
- # whose number of elements is below a threshold.
1092
- for n, p in zip(names, params):
1093
- if p is None or p.grad is None:
1094
- continue
1095
- if isinstance(p.data, DTensor):
1096
- if all(
1097
- isinstance(placement, Replicate)
1098
- for placement in p.placements):
1099
- param_tensors.append(p)
1100
- name_tensors.append(n)
1101
- elif p.data.numel() <= self.small_param_numel_threshold:
1102
- param_dtensors_small.append(p)
1103
- name_dtensors_small.append(n)
1104
- else:
1105
- param_dtensors.append(p)
1106
- name_dtensors.append(n)
1107
- elif isinstance(p.data, torch.Tensor):
1108
- param_tensors.append(p)
1109
- name_tensors.append(n)
1110
- else:
1111
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1112
-
1113
- logger.debug(
1114
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
1115
- f"{len(param_dtensors_small)} Small DTensors")
1116
-
1117
- def group_dtensors(dtensors, names):
1118
- # To support different placements, we group parameters by placements
1119
- # and run parallel Muon on each group.
1120
-
1121
- placement_to_params = defaultdict(lambda: ([], []))
1122
- # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1123
-
1124
- assert len(dtensors) == len(names)
1125
- for p, n in zip(dtensors, names):
1126
- placement_to_params[tuple([p.placements,
1127
- p.device_mesh])][0].append(n)
1128
- placement_to_params[tuple([p.placements,
1129
- p.device_mesh])][1].append(p)
1130
- return placement_to_params
1131
-
1132
- if len(param_dtensors_small) > 0:
1133
- if not dist.is_initialized():
1134
- raise RuntimeError(
1135
- "Parallel Muon requires torch.distributed to be initialized."
1136
- )
1137
-
1138
- self.distributed_muon(
1139
- params=param_dtensors_small,
1140
- names=name_dtensors_small,
1141
- group=group,
1142
- lr=lr,
1143
- weight_decay=weight_decay,
1144
- momentum=momentum,
1145
- qk_logits=qk_logits,
1146
- )
1147
-
1148
- if len(param_dtensors) > 0:
1149
- if not dist.is_initialized():
1150
- raise RuntimeError(
1151
- "Parallel Muon requires torch.distributed to be initialized."
1152
- )
1153
-
1154
- dtensor_group = group_dtensors(param_dtensors, name_dtensors)
1155
- for _, (names, params) in dtensor_group.items():
1156
- self.parallel(
1157
- names,
1158
- params,
1159
- group,
1160
- lr=lr,
1161
- weight_decay=weight_decay,
1162
- momentum=momentum,
1163
- qk_logits=qk_logits,
1164
- )
1165
-
1166
- if len(param_tensors) > 0:
1167
- self.base(
1168
- name_tensors,
1169
- param_tensors,
1170
- group,
1171
- lr=lr,
1172
- weight_decay=weight_decay,
1173
- momentum=momentum,
1174
- qk_logits=qk_logits,
1175
- )
1176
-
1177
- def _step_adamw_params(self, params, group):
1178
- params_with_grads = []
1179
- grads = []
1180
- moment1 = []
1181
- moment2 = []
1182
- max_exp_avg_sqs = []
1183
- state_steps = []
1184
- lr = group["lr"]
1185
- beta1, beta2 = group["adamw_betas"]
1186
- eps = group["adamw_eps"]
1187
- weight_decay = group["weight_decay"]
1188
-
1189
- for p in params:
1190
- g = p.grad
1191
- if g is None:
1192
- continue
1193
- state = self.state[p]
1194
- params_with_grads.append(p)
1195
- grads.append(g)
1196
- if "step" not in state:
1197
- state["step"] = (torch.zeros((),
1198
- dtype=torch.float32,
1199
- device=p.device))
1200
- state["moment1"] = torch.zeros_like(g)
1201
- state["moment2"] = torch.zeros_like(g)
1202
- moment1.append(state["moment1"])
1203
- moment2.append(state["moment2"])
1204
- if not isinstance(state["step"], torch.Tensor):
1205
- step_tensor = torch.tensor(state["step"],
1206
- dtype=torch.float32,
1207
- device=p.device)
1208
- else:
1209
- step_tensor = state["step"]
1210
- state_steps.append(step_tensor)
1211
-
1212
- self._fused_adamw(
1213
- params_with_grads,
1214
- grads,
1215
- moment1,
1216
- moment2,
1217
- max_exp_avg_sqs,
1218
- state_steps,
1219
- amsgrad=False,
1220
- beta1=beta1,
1221
- beta2=beta2,
1222
- lr=lr,
1223
- weight_decay=weight_decay,
1224
- eps=eps,
1225
- maximize=False,
1226
- )
1227
-
1228
- def _step_adamw(self, group):
1229
- params = group["params"]
1230
-
1231
- # group params by their type and placement
1232
- placement_to_params: dict[tuple[Placement | type,
1233
- DeviceMesh | None], list] = defaultdict(list)
1234
- for p in params:
1235
- match p:
1236
- case DTensor():
1237
- placement_to_params[tuple([p.placements,
1238
- p.device_mesh])].append(p)
1239
- case torch.Tensor():
1240
- placement_to_params[tuple([torch.Tensor, None])].append(p)
1241
-
1242
- for params in placement_to_params.values():
1243
- self._step_adamw_params(params, group)
1244
-
1245
- @torch.no_grad
1246
- def step(self, closure=None, qk_logits=None):
1247
- """Perform a single optimization step.
1248
-
1249
- Args:
1250
- closure (Callable, optional): A closure that reevaluates the model
1251
- and returns the loss.
1252
- qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
1253
- to 1D tensors of shape (num_heads,), representing the maximum
1254
- QK logits across all tokens, computed as
1255
- (1 / sqrt(head_dim)) * (Q @ K^T).
1256
- """
1257
- loss = None
1258
- if closure is not None:
1259
- with torch.enable_grad():
1260
- loss = closure()
1261
-
1262
- for group in self.param_groups:
1263
- if group["use_muon"]:
1264
- self._step_muon(group, qk_logits=qk_logits)
1265
- else:
1266
- self._step_adamw(group)
1267
-
1268
- return loss
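For quick reference, here is a minimal usage sketch of the `muon.py` API deleted above, assuming the Hugging Face `kernels` loader (`from kernels import get_kernel`); the toy model and hyperparameters are illustrative, not from this diff:

```
import torch
from kernels import get_kernel  # assumed hub-kernel loader

optimizer_mod = get_kernel("motif-technologies/optimizer")

model = torch.nn.Linear(64, 64).cuda()  # toy model with one 2D weight

# Build param groups tagged with the required `use_muon` key; the 2D weight
# goes to the Muon group, the 1D bias falls back to the internal AdamW.
params = optimizer_mod.muon.get_default_muon_param_groups(model)
optim = optimizer_mod.Muon(params, lr=1e-3, momentum=0.95, weight_decay=0.1)

loss = model(torch.randn(8, 64, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()
```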
build/torch210-cxx11-rocm70-x86_64-linux/optimizer/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- import ctypes
2
- import sys
3
-
4
- import importlib.util
5
- from pathlib import Path
6
- from types import ModuleType
7
-
8
- def _import_from_path(file_path: Path) -> ModuleType:
9
- # We cannot use the module name as-is, after adding it to `sys.modules`,
10
- # it would also be used for other imports. So, we make a module name that
11
- # depends on the path for it to be unique using the hex-encoded hash of
12
- # the path.
13
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
- module_name = path_hash
15
- spec = importlib.util.spec_from_file_location(module_name, file_path)
16
- if spec is None:
17
- raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
- module = importlib.util.module_from_spec(spec)
19
- if module is None:
20
- raise ImportError(f"Cannot load module {module_name} from spec")
21
- sys.modules[module_name] = module
22
- spec.loader.exec_module(module) # type: ignore
23
- return module
24
-
25
-
26
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
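As the comment in the shim above notes, the `sys.modules` key is derived from the file path so that several build variants can coexist. A self-contained sketch of that naming scheme (the paths here are illustrative):

```
import ctypes
from pathlib import Path

# Sketch: the hex-encoded, path-derived module name used by the shim above.
def module_name_for(path: Path) -> str:
    return "{:x}".format(ctypes.c_size_t(hash(path.absolute())).value)

# Two variant directories hash to different names, so their imports stay
# isolated in sys.modules (barring a hash collision).
print(module_name_for(Path("variant-a/__init__.py"))
      != module_name_for(Path("variant-b/__init__.py")))  # True
```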
build/torch210-cxx11-rocm71-x86_64-linux/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _optimizer_06a260a_dirty
3
- ops = torch.ops._optimizer_06a260a_dirty
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_optimizer_06a260a_dirty::{op_name}"
build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d804ba4d3ed9716c80e9819ba16a2bef300fb23fa4c456c550f4a96167a2eb00
3
- size 1866112
build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py DELETED
@@ -1,175 +0,0 @@
1
- import torch
2
- import torch.distributed as dist
3
- from torch.distributed import ProcessGroup
4
- from torch.distributed.device_mesh import DeviceMesh
5
- from torch.distributed.tensor import DTensor
6
- from torch.distributed.tensor.placement_types import (Placement, Shard,
7
- _StridedShard)
8
-
9
-
10
- def get_slices_of_dtensor(
11
- target: DTensor | torch.Tensor,
12
- local_rank: int,
13
- shard_mesh: DeviceMesh,
14
- shard_placements: tuple[Placement],
15
- ) -> tuple[slice]:
16
- """
17
- Get the slice of local tensor for a given rank from a tensor.
18
- Args:
19
- target (DTensor | torch.Tensor): The target tensor.
20
- local_rank (int): The local rank within the shard group.
21
- shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
- shard_placements (tuple[Placement]): The shard placements.
23
- """
24
-
25
- slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
-
27
- # find the global rank of the local rank in the shard mesh
28
- rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
-
30
- rank_coords = (shard_mesh.mesh == rank).nonzero()
31
-
32
- assert len(rank_coords) == 1
33
- rank_coords = tuple(rank_coords[0].tolist())
34
-
35
- assert len(rank_coords) == len(shard_placements)
36
-
37
- # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
- # left-to-right sharding. This is ensured by the sorting logic of
39
- # construct_shard_mesh function.
40
- for i, (rank_coord,
41
- placement) in enumerate(zip(rank_coords, shard_placements)):
42
- assert isinstance(placement, Shard)
43
-
44
- num_ranks = shard_mesh.mesh.shape[i]
45
-
46
- dim = placement.dim
47
- dim_size = (slices[dim].stop - slices[dim].start)
48
-
49
- if dim_size % num_ranks != 0:
50
- raise NotImplementedError(
51
- f"Dimension size {dim_size} is not divisible "
52
- f"by number of ranks {num_ranks} for shard "
53
- f"placement on dim {dim}. (shape: {target.shape})")
54
-
55
- shard_size = dim_size // num_ranks
56
-
57
- start = slices[dim].start + rank_coord * shard_size
58
- end = start + shard_size
59
-
60
- assert start < end <= slices[dim].stop
61
-
62
- slices[dim] = slice(start, end)
63
-
64
- return tuple(slices)
65
-
66
-
67
- _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
68
- ProcessGroup]] = dict()
69
-
70
-
71
- def construct_shard_mesh(
72
- placements: tuple[Placement],
73
- mesh: DeviceMesh,
74
- ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
75
- """
76
- Construct Shard Mesh and Placements for unsharding.
77
- It removes Replicate placements and constructs a new Mesh and ProcessGroup.
78
- """
79
- my_rank = dist.get_rank()
80
-
81
- assert mesh.mesh.device.type == 'cpu'
82
-
83
- # Copy mesh to avoid modifying the original mesh
84
- mesh = mesh.mesh.clone()
85
-
86
- # 1. Sort placements. Replicate first, then Shard by dim ascending.
87
-
88
- # For Shard, strided shard comes after regular shard on the same dim
89
- # to preserve left-to-right order of replicate-to-shard.
90
- # This is because a strided shard uses its stride to represent
91
- # more fine-grained sharding on the same dim.
92
- # Please check the URL below for _StridedShard.
93
- # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
94
-
95
- def placement_sort_key(
96
- placement_with_index: tuple[int, Placement]
97
- ) -> tuple[int, float, int]: # (dim, split factor, original index)
98
- index, placement = placement_with_index
99
- is_replicate = placement.is_replicate()
100
- is_shard = placement.is_shard()
101
- is_partial = placement.is_partial()
102
-
103
- assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
104
- assert not is_partial, "Partial placement is not supported."
105
-
106
- if is_replicate:
107
- return (-1.0, 0, index)
108
- elif is_shard:
109
- if isinstance(placement, _StridedShard):
110
- return (placement.dim, 1 / placement.split_factor, index)
111
- return (placement.dim, 0, index)
112
- else:
113
- raise TypeError(f"Unknown placement type: {type(placement)}")
114
-
115
- placements_with_index: list[tuple[int,
116
- Placement]] = list(enumerate(placements))
117
- placements_with_index = sorted(placements_with_index,
118
- key=placement_sort_key)
119
-
120
- sorted_indices, sorted_placements = zip(*placements_with_index)
121
-
122
- # 2. Permute mesh according to sorted placements.
123
- sorted_mesh = mesh.permute(sorted_indices)
124
-
125
- # 3. Collect list of shard meshes by removing replicate dims
126
- # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
127
- # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
128
- num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
129
-
130
- # merge replicate dims
131
- # shard_meshes becomes a list of shard meshes whose length equals the replicate degree
132
- if num_replicates > 0:
133
- sorted_mesh = sorted_mesh.flatten(
134
- 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
135
- shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
136
- else:
137
- shard_meshes = [sorted_mesh]
138
- shard_placements = sorted_placements[num_replicates:]
139
-
140
- # assume all shard placements are different
141
- assert len(shard_placements) == len(set(shard_placements))
142
-
143
- # 4. Construct ProcessGroups
144
- # Caution: all groups should be created in the same order in all processes,
145
- # even though each process only needs its own group.
146
-
147
- # To use tensor as dict key, convert it to tuple
148
- def tensor_to_tuple(t):
149
- if isinstance(t, torch.Tensor):
150
- t = t.tolist()
151
- if isinstance(t, list):
152
- return tuple(tensor_to_tuple(x) for x in t)
153
- return t
154
-
155
- my_shard_mesh_as_tuple = None
156
- for shard_mesh in shard_meshes:
157
- assert isinstance(shard_mesh, torch.Tensor)
158
- shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
159
-
160
- if (my_rank == shard_mesh).any().item():
161
- assert my_shard_mesh_as_tuple is None
162
- my_shard_mesh_as_tuple = shard_mesh_as_tuple
163
-
164
- # update global cache
165
- if shard_mesh_as_tuple not in _ranks_to_dist_cache:
166
- shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
167
- _ranks_to_dist_cache[shard_mesh_as_tuple] = (
168
- DeviceMesh(device_type="cuda", mesh=shard_mesh),
169
- shard_process_group,
170
- )
171
-
172
- my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
173
- my_shard_mesh_as_tuple]
174
-
175
- return my_shard_mesh, my_shard_process_group, shard_placements
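A toy, torch-free sketch of the slice arithmetic that `get_slices_of_dtensor` applies per `Shard` placement above (evenly divisible dims only, matching the NotImplementedError guard):

```
# Sketch: slicing dim 0 of an 8-row tensor across 4 ranks, i.e. one Shard(0)
# placement on a flat 4-rank shard mesh.
def shard_slice(dim_size: int, num_ranks: int, rank_coord: int) -> slice:
    assert dim_size % num_ranks == 0, "uneven shards are rejected by the real code"
    shard = dim_size // num_ranks
    return slice(rank_coord * shard, (rank_coord + 1) * shard)

print([shard_slice(8, 4, r) for r in range(4)])
# [slice(0, 2, None), slice(2, 4, None), slice(4, 6, None), slice(6, 8, None)]
```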
build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py DELETED
@@ -1,128 +0,0 @@
1
- # MIT License
2
- #
3
- # Copyright (c) 2025 Tianyang Lin
4
- #
5
- # Permission is hereby granted, free of charge, to any person obtaining a copy
6
- # of this software and associated documentation files (the "Software"), to deal
7
- # in the Software without restriction, including without limitation the rights
8
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- # copies of the Software, and to permit persons to whom the Software is
10
- # furnished to do so, subject to the following conditions:
11
- #
12
- # The above copyright notice and this permission notice shall be included in all
13
- # copies or substantial portions of the Software.
14
- #
15
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- # SOFTWARE.
22
-
23
- import torch
24
- import triton
25
- import triton.language as tl
26
-
27
-
28
- def get_autotune_config():
29
- return [
30
- triton.Config(
31
- {
32
- 'BLOCK_SIZE_M': blk_m,
33
- 'BLOCK_SIZE_K': blk_k,
34
- 'GROUP_SIZE_M': grp_sz
35
- },
36
- num_stages=n_stages,
37
- num_warps=n_warps) for blk_m in [32, 64, 128]
38
- for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
- for n_warps in [4, 8]
40
- ]
41
-
42
-
43
- @triton.autotune(
44
- configs=get_autotune_config(),
45
- key=['M', 'K'],
46
- )
47
- @triton.jit
48
- def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
49
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
50
- GROUP_SIZE_M: tl.constexpr):
51
- """
52
- Core kernel jit function of matmul_transpose that computes y = x @ x.T
53
- The code is a simple adaptation from the triton `matmul` tutorial:
54
- https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
55
- """
56
- pid = tl.program_id(axis=0)
57
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
58
- num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
59
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
60
- group_id = pid // num_pid_in_group
61
- first_pid_m = group_id * GROUP_SIZE_M
62
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
63
- pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
64
- pid_n = (pid % num_pid_in_group) // group_size_m
65
- if pid_m > pid_n:
66
- return
67
-
68
- offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
69
- offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
- offs_k = tl.arange(0, BLOCK_SIZE_K)
71
- # we use a & b ptrs to denote different rows of x.
72
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
73
- b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
-
75
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
76
-
77
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
78
- a = tl.load(a_ptrs,
79
- mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
80
- other=0.0)
81
- b = tl.load(b_ptrs,
82
- mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
83
- other=0.0)
84
- accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
85
- a_ptrs += BLOCK_SIZE_K * stride_xk
86
- b_ptrs += BLOCK_SIZE_K * stride_xk
87
- # use dtype.element_ty to accommodate different input datatypes as in cpp templates
88
- # https://github.com/triton-lang/triton/issues/2252
89
- c = accumulator.to(x.dtype.element_ty)
90
-
91
- offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
92
- offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
- c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
94
- c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
95
- tl.store(c_ptrs, c, mask=c_mask)
96
-
97
- # transpose and copy
98
- if pid_m < pid_n:
99
- ct_ptrs = y + stride_ym * offs_cn[:,
100
- None] + stride_yn * offs_cm[None, :]
101
- ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
102
- tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
-
104
-
105
- def matmul_transpose_assign(d_in, d_out):
106
- assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
- assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
- assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
- assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
- assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
- assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
- assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
113
- "First dimension of `d_in` must match first and second dimension of `d_out`"
114
-
115
- d_in = d_in.contiguous()
116
- M, K = d_in.shape
117
- grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
118
- M, META['BLOCK_SIZE_M']), )
119
- with torch.cuda.device(d_in.device.index):
120
- mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
- d_out.stride(0), d_out.stride(1))
122
-
123
-
124
- def matmul_transpose(d_in):
125
- M, _ = d_in.shape
126
- d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
- matmul_transpose_assign(d_in, d_out)
128
- return d_out
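A quick correctness check for the Triton kernel above, assuming a CUDA device with Triton installed and `matmul_transpose` imported from this module:

```
import torch

# Sketch: compare the fused x @ x.T kernel against the eager result.
x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
y = matmul_transpose(x)  # from the module above
torch.testing.assert_close(y, x @ x.T, rtol=2e-2, atol=2e-2)
```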
build/torch210-cxx11-rocm71-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
1
- {"python-depends":[]}
build/torch210-cxx11-rocm71-x86_64-linux/muon.py DELETED
@@ -1,1268 +0,0 @@
1
- import logging
2
- import math
3
- import types
4
- from collections import defaultdict
5
- from dataclasses import dataclass
6
- from typing import Any, cast
7
-
8
- import torch
9
- import torch.distributed as dist
10
- from torch.distributed import ProcessGroup
11
- from torch.distributed.device_mesh import DeviceMesh
12
- from torch.distributed.tensor import DTensor, Replicate
13
- from torch.distributed.tensor.placement_types import Placement
14
-
15
- from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
- from .matmul_transpose_triton import matmul_transpose_assign
17
-
18
- logger = logging.getLogger(__name__)
19
-
20
- COMM_DTYPE = torch.bfloat16
21
- DEFAULT_CHUNK_SIZE_RATIO = 4
22
-
23
-
24
- # This code snippet is a modified version adapted from the following GitHub repositories:
25
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
26
- # Muon's Newton–Schulz iteration causes high variance in singular values
27
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
28
- @torch.no_grad()
29
- # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
30
- def _zeropower_via_newtonschulz5(G, steps):
31
- """
32
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
33
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
34
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
35
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
36
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
37
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
38
- performance at all relative to UV^T, where USV^T = G is the SVD.
39
- """
40
- assert len(G.shape) == 2
41
- assert G.dtype == COMM_DTYPE
42
- X = G # no manual typecast
43
-
44
- if G.size(0) > G.size(1):
45
- X = X.T
46
- # Ensure spectral norm is at most 1
47
- X = X / (X.norm() + 1e-7)
48
- buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
49
- buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
50
- # Perform the NS iterations
51
- for a, b, c in [
52
- (4.0848, -6.8946, 2.9270),
53
- (3.9505, -6.3029, 2.6377),
54
- (3.7418, -5.5913, 2.3037),
55
- (2.8769, -3.1427, 1.2046),
56
- (2.8366, -3.0525, 1.2012),
57
- ]:
58
- matmul_transpose_assign(X, buf1)
59
- matmul_transpose_assign(buf1, buf2)
60
- buf1.mul_(b).add_(buf2, alpha=c)
61
- X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
62
-
63
- if G.size(0) > G.size(1):
64
- X = X.T
65
- return X
66
-
67
-
68
- @dataclass
69
- class _muon_state:
70
- # TODO: use Optional
71
- worker_rank: int
72
- process_group: ProcessGroup
73
- shard_mesh: DeviceMesh
74
- shard_placements: tuple[Placement, ...]
75
- name: str
76
- qk_clip_state: torch.Tensor | None = None
77
- gathered_grad: torch.Tensor | None = None
78
- scattered_u: DTensor | None = None
79
- computed_u: torch.Tensor | None = None
80
- gather_event: torch.cuda.Event | None = None
81
- compute_event: torch.cuda.Event | None = None
82
- scatter_event: torch.cuda.Event | None = None
83
-
84
-
85
- def numel_for_rank(
86
- param: DTensor,
87
- local_rank: int,
88
- state: _muon_state,
89
- ) -> int:
90
- slices = get_slices_of_dtensor(
91
- param,
92
- local_rank,
93
- state.shard_mesh,
94
- state.shard_placements,
95
- )
96
-
97
- numel = 1
98
- for s, dim in zip(slices, param.shape):
99
- start, stop, step = s.indices(dim)
100
- length = max(0, (stop - start + (step - 1)) // step)
101
- numel *= length
102
-
103
- return numel
104
-
105
-
106
- @torch.no_grad()
107
- def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
108
- """
109
- Pre-allocate gathered_grad buffer on compute_stream
110
- before launching all2all gather
111
- """
112
- with torch.cuda.stream(compute_stream):
113
- for p in params:
114
- state = param_to_state[id(p)]
115
- if rank == state.worker_rank:
116
- state.gathered_grad = torch.empty(p.shape,
117
- dtype=COMM_DTYPE,
118
- device="cuda")
119
- else:
120
- state.gathered_grad = None
121
-
122
- alloc_event = torch.cuda.Event()
123
- alloc_event.record(compute_stream)
124
- return alloc_event
125
-
126
-
127
- @torch.no_grad()
128
- def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
129
- alloc_event):
130
- """
131
- All2all gathers shards so each owner rank reconstructs its full gradient
132
- """
133
- with torch.cuda.stream(comm_stream):
134
- process_group = param_to_state[id(params[0])].process_group
135
- num_ranks = dist.get_world_size(group=process_group)
136
-
137
- # Construct sending buffers
138
- per_dst = [[] for _ in range(num_ranks)]
139
- send_counts = [0] * num_ranks
140
-
141
- for p in params:
142
- state = param_to_state[id(p)]
143
- dst = state.worker_rank
144
- assert dst < num_ranks
145
- shard_elems = numel_for_rank(p, rank, state)
146
- g = p.grad
147
- g = g.to_local().to(COMM_DTYPE).contiguous()
148
- assert g.numel() == shard_elems
149
- per_dst[dst].append(g.view(-1))
150
- send_counts[dst] += shard_elems
151
-
152
- assert any(
153
- len(v) > 0 for v in per_dst
154
- ), "At least one destination rank must receive a sharded tensor"
155
- # list[list[Tensor]] -> list[Tensor]
156
- per_dst = [t for dst in per_dst for t in dst]
157
-
158
- send_buf = torch.cat(per_dst, dim=0)
159
-
160
- owned_params = [
161
- p for p in params if param_to_state[id(p)].worker_rank == rank
162
- ]
163
-
164
- # Compute receive sizes and allocate receiving buffers
165
- recv_counts = [0] * num_ranks
166
-
167
- for src in range(num_ranks):
168
- total = 0
169
- for p in owned_params:
170
- state = param_to_state[id(p)]
171
- assert state.worker_rank == rank
172
- total += numel_for_rank(p, src, state)
173
- recv_counts[src] = total
174
-
175
- recv_total = sum(recv_counts)
176
- recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
-
178
- #All2All
179
- logger.debug(f"send_buf size: {send_buf.numel()}, "
180
- f"recv_buf size: {recv_buf.numel()}, "
181
- f"recv_counts: {recv_counts}, "
182
- f"send_counts: {send_counts}, "
183
- f"process_group: {str(process_group)}")
184
- dist.all_to_all_single(
185
- recv_buf,
186
- send_buf,
187
- output_split_sizes=recv_counts,
188
- input_split_sizes=send_counts,
189
- group=process_group,
190
- )
191
-
192
- # Reconstructs gathered grad from the received buffer
193
- #
194
- # recv_buf (num ranks = 3)
195
- #
196
- # From rank 0 From rank 1 From rank 2
197
- # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
198
- #
199
- # Outer loop:
200
- # rank 0 -> rank 1 -> rank 2
201
- #
202
- # Inner loop:
203
- # p1_n -> p2_n -> p3_n
204
-
205
- comm_stream.wait_event(alloc_event)
206
-
207
- off = 0
208
- for src in range(num_ranks):
209
- if recv_counts[src] == 0:
210
- continue
211
-
212
- block = recv_counts[src]
213
- inner_off = 0
214
- for p in owned_params:
215
- state = param_to_state[id(p)]
216
- assert state.worker_rank == rank
217
-
218
- # get the slice of the full dtensor corresponding to rank src.
219
- slices = get_slices_of_dtensor(state.gathered_grad, src,
220
- state.shard_mesh,
221
- state.shard_placements)
222
-
223
- dst = state.gathered_grad[slices]
224
- assert dst._base is state.gathered_grad
225
-
226
- n = dst.numel()
227
- assert n > 0
228
-
229
- sg = recv_buf.narrow(0, off + inner_off, n)
230
- sg = sg.reshape_as(dst)
231
- dst.copy_(sg)
232
-
233
- inner_off += n
234
- off += block
235
-
236
- for p in params:
237
- state = param_to_state[id(p)]
238
- if state.worker_rank == rank:
239
- state.gather_event = torch.cuda.Event()
240
- state.gather_event.record(comm_stream)
241
- else:
242
- state.gathered_grad = None
243
- state.gather_event = None
244
- if none_grad:
245
- p.grad = None
246
-
247
-
248
- @torch.no_grad()
249
- def _compute_u(p, state, steps, rank, compute_stream):
250
- """
251
- On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
252
- """
253
- with torch.cuda.stream(compute_stream):
254
- if rank == state.worker_rank:
255
- if state.gather_event is None:
256
- raise RuntimeError("Gather event must be set before compute.")
257
- compute_stream.wait_event(state.gather_event)
258
- u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
259
- state.gathered_grad = None
260
- state.computed_u = u
261
- state.compute_event = torch.cuda.Event()
262
- state.compute_event.record()
263
- else:
264
- state.computed_u = None
265
- state.compute_event = None
266
-
267
-
268
- @torch.no_grad()
269
- def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
270
- """
271
- Pre-allocate scattered_u buffer on compute_stream
272
- before launching all2all gather
273
- """
274
- with torch.cuda.stream(compute_stream):
275
- for p in params:
276
- state = param_to_state[id(p)]
277
- state.scattered_u = torch.empty_like(p.to_local(),
278
- dtype=COMM_DTYPE)
279
-
280
- alloc_event = torch.cuda.Event()
281
- alloc_event.record(compute_stream)
282
- return alloc_event
283
-
284
-
285
- def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
286
- """
287
- All2all scatters full gradients to all ranks
288
- """
289
- with torch.cuda.stream(comm_stream):
290
- process_group = param_to_state[id(params[0])].process_group
291
- num_ranks = dist.get_world_size(group=process_group)
292
- owned_params = [
293
- p for p in params if param_to_state[id(p)].worker_rank == rank
294
- ]
295
-
296
- # Construct sending buffer
297
- per_dst = [[] for _ in range(num_ranks)]
298
- send_counts = [0] * num_ranks
299
-
300
- if owned_params:
301
- for p in owned_params:
302
- state = param_to_state[id(p)]
303
- if state.compute_event is None:
304
- raise RuntimeError(
305
- "Compute event must be set before scatter.")
306
- comm_stream.wait_event(state.compute_event)
307
- state.gathered_grad = None
308
-
309
- assert state.computed_u is not None
310
-
311
- u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
-
313
- offset = 0
314
- for dst in range(num_ranks):
315
- # get the slice of the full tensor corresponding to rank dst.
316
- slices = get_slices_of_dtensor(u_full, dst,
317
- state.shard_mesh,
318
- state.shard_placements)
319
- su = u_full[slices].flatten()
320
-
321
- n = su.numel()
322
- assert n > 0
323
-
324
- per_dst[dst].append(su)
325
- send_counts[dst] += n
326
- offset += n
327
-
328
- assert offset == u_full.numel()
329
-
330
- lengths = [len(v) for v in per_dst]
331
- if all(l > 0 for l in lengths):
332
- assert all(
333
- l == lengths[0] for l in lengths
334
- ), "All destination ranks must have the same number of sharded tensor"
335
- # list[list[Tensor]] -> list[Tensor]
336
- per_dst = [t for dst in per_dst for t in dst]
337
- send_buf = torch.cat(per_dst, dim=0)
338
- else:
339
- # all_to_all requires participation from all ranks
340
- # Even non-owner ranks must join the collective call
341
- send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
342
-
343
- # Compute receive sizes and allocate receiving buffers
344
- recv_counts = [0] * num_ranks
345
-
346
- for src in range(num_ranks):
347
- total = 0
348
- for p in params:
349
- state = param_to_state[id(p)]
350
- if state.worker_rank != src:
351
- continue
352
- total += numel_for_rank(p, rank, state)
353
- recv_counts[src] = total
354
-
355
- recv_total = sum(recv_counts)
356
- assert recv_total > 0
357
- recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
358
-
359
- #All2All
360
- dist.all_to_all_single(
361
- recv_buf,
362
- send_buf,
363
- output_split_sizes=recv_counts,
364
- input_split_sizes=send_counts,
365
- group=process_group,
366
- )
367
-
368
- # Copy to pre-allocated scattered_u buffer from the received buffer
369
- #
370
- # recv_buf (num ranks = 3, local_rank = 0)
371
- #
372
- # From rank 0 From rank 1 From rank 2
373
- # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
374
- #
375
- # Outer loop:
376
- # rank 0 -> rank 1 -> rank 2
377
- #
378
- # Inner loop:
379
- # src(0) : p1_0 -> p2_0 -> p3_0
380
- # src(1) : p4_0
381
- # src(2) : p5_0 -> p6_0
382
-
383
- comm_stream.wait_event(alloc_event)
384
-
385
- off = 0
386
- for src in range(num_ranks):
387
- block = recv_counts[src]
388
- if block == 0:
389
- continue
390
-
391
- inner_off = 0
392
- for p in params:
393
- state = param_to_state[id(p)]
394
- if state.worker_rank != src:
395
- continue
396
- n = numel_for_rank(p, rank, state)
397
- assert n > 0
398
-
399
- flat_local = recv_buf.narrow(0, off + inner_off,
400
- n).view_as(p.to_local())
401
- state.scattered_u.copy_(flat_local)
402
-
403
- state.scatter_event = torch.cuda.Event()
404
- state.scatter_event.record(comm_stream)
405
- inner_off += n
406
-
407
- assert inner_off == block
408
- off += block
409
-
410
-
411
- def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
412
- compute_stream):
413
- """
414
- Update sharded parameter p with the scattered_u.
415
- Only worker_rank frees computed_u.
416
- """
417
- with torch.cuda.stream(compute_stream):
418
- if state.scatter_event is None:
419
- raise RuntimeError("Scatter event must be set before update")
420
- compute_stream.wait_event(state.scatter_event)
421
- u_dtensor = DTensor.from_local(
422
- state.scattered_u,
423
- placements=p.placements,
424
- device_mesh=p.device_mesh,
425
- )
426
-
427
- state.scattered_u = u_dtensor
428
-
429
- if rank == state.worker_rank:
430
- # Free computed_u
431
- state.computed_u = None
432
-
433
- Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
434
- state.scattered_u = None
435
- u_dtensor = None
436
-
437
- scales_full = Muon._compute_scales(
438
- p,
439
- state.qk_clip_state) if state.qk_clip_state is not None else None
440
- if scales_full is not None:
441
- # Have to slice scales_full among dim 0
442
- weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
- state.shard_placements)
444
- ratio = p.shape[0] // scales_full.shape[0]
445
- scales_slice = slice(
446
- None if weight_slices[0].start is None else
447
- weight_slices[0].start // ratio,
448
- None if weight_slices[0].stop is None else
449
- weight_slices[0].stop // ratio,
450
- None,
451
- )
452
-
453
- scales_local = scales_full[scales_slice]
454
- scales_local = DTensor.from_local(
455
- scales_local,
456
- placements=p.placements,
457
- device_mesh=p.device_mesh,
458
- )
459
- Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
460
-
461
-
462
- def default_is_muon(name, x):
463
- skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
464
- return x.ndim >= 2 and not any(key in name for key in skip_keys)
465
-
466
-
467
- def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
468
- muon_params, muon_names = [], []
469
- non_muon_params = []
470
-
471
- for n, p in model.named_parameters():
472
- if not p.requires_grad:
473
- continue
474
- if is_muon_func(n, p):
475
- muon_params.append(p)
476
- muon_names.append(n)
477
- else:
478
- non_muon_params.append(p)
479
-
480
- return [
481
- {
482
- "params": muon_params,
483
- "names": muon_names,
484
- "use_muon": True,
485
- },
486
- {
487
- "params": non_muon_params,
488
- "use_muon": False,
489
- },
490
- ]
491
-
492
-
493
- def parse_qk_layer(name: str) -> tuple[str | None, int]:
494
- """
495
- Parse a parameter name to check if it is a query/key projection layer
496
- ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
497
-
498
- Returns:
499
- (kind, layer_idx) or (None, -1) if not matched.
500
-
501
- Example:
502
- 'model.3.attn.wq.weight' -> ('wq', 3)
503
- 'model.5.attn.wk.weight' -> ('wk', 5)
504
- 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
505
- 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
506
- 'model.4.attn.v_proj.weight' -> (None, -1)
507
- """
508
- parts = name.split('.')
509
- if len(parts) < 3:
510
- return None, -1
511
-
512
- kind = parts[-2]
513
-
514
- layer_idx = -1
515
- for part in reversed(parts):
516
- if part.isdigit():
517
- layer_idx = int(part)
518
- break
519
-
520
- if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
521
- return kind, layer_idx
522
-
523
- return None, -1
524
-
525
-
526
- @dataclass
527
- class QKClipInfo:
528
- """Per-parameter dynamic info computed from config + runtime logits."""
529
- kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
- indices: list[int] # which heads to consider for clipping
531
- head_dim: int # from config
532
- threshold: float # from config
533
- logit: torch.Tensor | None
534
-
535
-
536
- class Muon(torch.optim.Optimizer):
537
- """
538
- Muon - MomentUm Orthogonalized by Newton-schulz
539
-
540
- Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
541
- processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
542
- matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
543
- the advantage that it can be stably run in bfloat16 on the GPU.
544
-
545
- Some warnings:
546
- - We believe this optimizer is unlikely to work well for training with small batch size.
547
- - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
548
-
549
- Arguments:
550
- params: The parameter groups to be optimized (see get_default_muon_param_groups).
551
- is_muon_func: A function that takes a parameter name and the parameter, and returns whether it should be optimized by Muon.
552
- lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
553
- momentum: The momentum used by the internal SGD. (0.95 is a good default)
554
- nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
555
- ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
556
- weight_decay: The weight decay for Muon and AdamW.
557
- Parameters that are {0, 1}-D or are detected as the embed or lm_head will be optimized by AdamW as well.
558
- adamw_lr: The learning rate for the internal AdamW.
559
- adamw_betas: The betas for the internal AdamW.
560
- adamw_eps: The epsilon for the internal AdamW.
561
- none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
562
- debug: Whether to print debug information.
563
- clip_config : Configuration for QK clipping. Expected keys:
564
- - "q_indices" (list[int]): Indices of query heads to consider.
565
- - "k_indices" (list[int]): Indices of key heads to consider.
566
- - "head_dim" (int): Dimensionality of each attention head.
567
- - "threshold" (float): Threshold value; heads whose QK logits exceed
568
- this value will be scaled down.
569
- Default is:
570
- {
571
- "q_indices": [],
572
- "k_indices": [],
573
- "head_dim": 128,
574
- "threshold": 100
575
- }
576
- warmup_step : How many all2all gather, compute operations are launched in advance
577
- before the corresponding all2all scatter steps begin.
578
- A higher warmup_step increases memory usage but can improve
579
- performance by overlapping communication.
580
- Parallel muon only.
581
- chunk_size : Batch size of parameters to process in each
582
- all2all gather/compute/scatter step.
583
- Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
- use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
- For testing purposes only.
586
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
587
- """
588
-
589
- def __init__(self,
590
- params,
591
- lr=1e-3,
592
- momentum=0.95,
593
- nesterov=True,
594
- ns_steps=5,
595
- weight_decay=0.1,
596
- adamw_betas=(0.9, 0.95),
597
- adamw_eps=1e-8,
598
- none_grad=True,
599
- debug=False,
600
- clip_config={
601
- "q_indices": [],
602
- "k_indices": [],
603
- "head_dim": 128,
604
- "threshold": 100
605
- },
606
- warmup_step=5,
607
- chunk_size=-1,
608
- use_distributed_muon=False,
609
- small_param_numel_threshold=65536):
610
- defaults = dict(
611
- lr=lr,
612
- weight_decay=weight_decay,
613
- momentum=momentum,
614
- nesterov=nesterov,
615
- ns_steps=ns_steps,
616
- adamw_betas=adamw_betas,
617
- adamw_eps=adamw_eps,
618
- none_grad=none_grad,
619
- use_muon=True,
620
- )
621
- error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
622
- instruction_code = "\n\nPlease follow this code snippet:\n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
623
-
624
- if isinstance(params, types.GeneratorType):
625
- raise ValueError(error_message.format(idx=0) + instruction_code)
626
- for _idx, param_group in enumerate(params):
627
- if param_group.get("use_muon", None) is None:
628
- raise ValueError(
629
- error_message.format(idx=_idx) + instruction_code)
630
-
631
- super().__init__(params, defaults)
632
-
633
- self.rank = None
634
-
635
- self.comm_stream = torch.cuda.Stream()
636
- self.compute_stream = torch.cuda.Stream()
637
- self.debug = debug
638
- self.clip_config = clip_config
639
- self.warmup_step = warmup_step
640
- self.chunk_size = chunk_size
641
- self.use_distributed_muon = use_distributed_muon
642
- self.small_param_numel_threshold = small_param_numel_threshold
643
-
644
- def _calc_flops(self, G, steps):
645
- assert len(G.shape) == 2
646
- M, N = G.shape
647
- if M > N:
648
- M, N = N, M
649
-
650
- return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
651
-
652
- def adjust_lr_for_muon(self, lr, param_shape):
653
- A, B = param_shape[:2]
654
- # We adjust the learning rate and weight decay based on the size of the parameter matrix
655
- # as described in the paper
656
- adjusted_ratio = 0.2 * math.sqrt(max(A, B))
657
- adjusted_lr = lr * adjusted_ratio
658
- return adjusted_lr
659
-
660
- def set_rank_once(self, rank):
661
- if self.rank is None:
662
- self.rank = rank
663
- else:
664
- assert self.rank == rank
665
-
666
- def get_shard_mesh(self, p):
667
- """
668
- Get the shard mesh for a parameter p on the given rank.
669
- """
670
- assert isinstance(
671
- p, DTensor), "Parallel Muon only supports DTensor parameters."
672
-
673
- shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
674
- p.placements, p.device_mesh)
675
-
676
- # set rank with the local rank in the shard process group
677
- self.set_rank_once(dist.get_rank(group=shard_pg))
678
-
679
- return shard_mesh, shard_pg, shard_placements
680
-
681
- def init_state_and_assign_params(self, names, params, group, qk_logits):
682
- param_to_state = {}
683
- param_to_flops = {}
684
-
685
- total_flops = 0
686
- for p in params:
687
- g = p.grad
688
- if g is None:
689
- continue
690
- assert g.ndim == 2, "Muon only supports 2D parameters."
691
-
692
- flops = self._calc_flops(g, group["ns_steps"])
693
- param_to_flops[id(p)] = flops
694
- total_flops += flops
695
-
696
- if self.debug:
697
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
698
- flush=True)
699
-
700
- paired = list(zip(names, params))
701
-
702
- paired_sorted = sorted(paired,
703
- key=lambda x: param_to_flops[id(x[1])],
704
- reverse=True)
705
-
706
- names_sorted, params_sorted = zip(*paired_sorted)
707
- ordered_names = list(names_sorted)
708
- ordered_params = list(params_sorted)
709
-
710
- round_robin = 0
711
- mesh = ordered_params[0].device_mesh
712
- placements = ordered_params[0].placements
713
-
714
- shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
715
- ordered_params[0])
716
- shard_mesh_flattened = shard_mesh.mesh.flatten()
717
- num_ranks = dist.get_world_size(group=shard_pg)
718
-
719
- for n, p in zip(ordered_names, ordered_params):
720
- if mesh != p.device_mesh:
721
- raise ValueError("All parameters must be on the same mesh.")
722
- if placements != p.placements:
723
- raise ValueError("All parameters must have same placements.")
724
-
725
- worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
726
- round_robin = (round_robin + 1) % len(shard_mesh_flattened)
727
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
728
-
729
- param_to_state[id(p)] = _muon_state(
730
- worker_rank=worker_rank,
731
- process_group=shard_pg,
732
- shard_mesh=shard_mesh,
733
- shard_placements=shard_placements,
734
- name=n,
735
- qk_clip_state=qk_clip_state,
736
- )
737
-
738
- return param_to_state, ordered_params
739
-
740
- def base(self, names, params, group, lr, weight_decay, momentum,
741
- qk_logits):
742
- # generate weight updates in distributed fashion
743
- for n, p in zip(names, params):
744
- g = p.grad
745
- if g is None:
746
- continue
747
- if g.ndim > 2:
748
- g = g.view(g.size(0), -1)
749
- assert g is not None
750
-
751
- g = self._update_g(p, g, group, momentum)
752
-
753
- u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
754
- steps=group["ns_steps"])
755
-
756
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
757
- Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
758
-
759
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
760
-
761
- scales_full = self._compute_scales(
762
- p, qk_clip_state) if qk_clip_state is not None else None
763
- if scales_full is not None:
764
- Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
765
-
766
- def distributed_muon(
767
- self,
768
- names: list[str],
769
- params: list[torch.nn.Parameter],
770
- group: dict[str, Any],
771
- lr: float,
772
- weight_decay: float,
773
- momentum: float,
774
- qk_logits: list[torch.Tensor | DTensor] | None,
775
- ):
776
- """ Implementation of Distributed Muon by Liu et al. """
777
-
778
- for n, p in zip(names, params):
779
- g = p.grad
780
- if g is None:
781
- continue
782
- if g.ndim > 2:
783
- g = g.view(g.size(0), -1)
784
- assert g is not None
785
-
786
- g = self._update_g(p, g, group, momentum)
787
-
788
- # Gather G
789
- if isinstance(p.data, DTensor):
790
- g_full = g.full_tensor()
791
- p_full = p.data.full_tensor()
792
- else:
793
- g_full = g
794
- p_full = p
795
-
796
- u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
797
- steps=group["ns_steps"])
798
-
799
- adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
800
- Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
801
-
802
- qk_clip_state = self.get_qk_clip_info(n, qk_logits)
803
-
804
- scales_full = self._compute_scales(
805
- p_full, qk_clip_state) if qk_clip_state is not None else None
806
-
807
- if scales_full is not None:
808
- Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
809
-
810
- if isinstance(p.data, DTensor):
811
- ndims = len(p.device_mesh.mesh.shape)
812
- p_replicate = DTensor.from_local(
813
- p_full,
814
- device_mesh=p.device_mesh,
815
- placements=[Replicate() for _ in range(ndims)],
816
- )
817
-
818
- p_sharded = p_replicate.redistribute(
819
- device_mesh=p.device_mesh,
820
- placements=p.placements,
821
- )
822
-
823
- p.copy_(p_sharded)
824
-
825
- def _update_g(self, p, g, group, momentum):
826
- # calc update
827
- state = self.state[p]
828
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
829
- torch.add(g, buf, alpha=momentum, out=buf)
830
- if group["nesterov"]:
831
- g.add_(buf, alpha=momentum)
832
- return g
833
- return buf
834
-
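# In other words (illustrative note): the two lines above compute
#   buf <- g + momentum * buf        (momentum accumulation, in place)
# and with nesterov=True the returned direction is g + momentum * buf_new
# (a look-ahead step); otherwise the accumulated buffer itself is returned.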
835
- @staticmethod
836
- def _update_p(p, u, lr, adjusted_lr, weight_decay):
837
- if isinstance(p, torch.nn.Parameter):
838
- # apply weight decay
839
- p.data.mul_(1 - lr * weight_decay)
840
- # apply update
841
- p.data.add_(u, alpha=-adjusted_lr)
842
- else:
843
- p.mul_(1 - lr * weight_decay)
844
- p.add_(u, alpha=-adjusted_lr)
845
-
846
- def get_qk_clip_info(self, n, qk_logits):
847
- if self.clip_config is None:
848
- return None
849
-
850
- head_dim = self.clip_config.get('head_dim')
851
- threshold = self.clip_config.get('threshold')
852
- kind, layer_idx = parse_qk_layer(n)
853
-
854
- logit, indices = None, []
855
- if qk_logits is not None and kind is not None:
856
- logit = qk_logits[layer_idx]
857
- indices_key = 'q_indices' if 'q' in kind else 'k_indices'
858
- indices = self.clip_config.get(indices_key, []) or []
859
-
860
- if isinstance(logit, DTensor):
861
- # In TP settings, qk_logits may be DTensor
862
- # We convert it to full tensor here for simplicity
863
- logit = logit.full_tensor()
864
-
865
- return QKClipInfo(
866
- kind=kind,
867
- indices=indices,
868
- head_dim=head_dim,
869
- threshold=threshold,
870
- logit=logit,
871
- )
872
-
873
- @staticmethod
874
- def _compute_scales(p, qk_clip_state):
875
- kind = qk_clip_state.kind
876
- indices = qk_clip_state.indices
877
- head_dim = qk_clip_state.head_dim
878
- threshold = qk_clip_state.threshold
879
- logit = qk_clip_state.logit
880
-
881
- H_global = p.shape[0] // head_dim
882
- scales_full = torch.ones(H_global, device=p.data.device)
883
- scaling = 0
884
-
885
- for logit_idx, head_idx in enumerate(indices):
886
- v_ele = float(logit[logit_idx])
887
- if v_ele > threshold:
888
- new_scale = math.sqrt(threshold / v_ele)
889
- if new_scale < scales_full[head_idx]:
890
- scales_full[head_idx] = new_scale
891
- logger.info(
892
- f"[{kind}] Head {head_idx} exceeded threshold "
893
- f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
894
- )
895
- scaling += 1
896
-
897
- return scales_full if scaling > 0 else None
898
-
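# Worked example of the rule above (illustrative, assuming threshold=100 and
# a per-head max QK logit of 120): the offending head is rescaled by
#   scale = sqrt(100 / 120) ~ 0.9129
# which multiplies that head's weight rows so its future logits fall back
# under the threshold; heads at or below the threshold keep scale 1.0.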
899
- @staticmethod
900
- def _qk_clip(p, scales, head_dim):
901
- if isinstance(p, torch.nn.Parameter):
902
- W = p.data.view(-1, head_dim, p.data.shape[1])
903
- W.mul_(scales.view(-1, 1, 1))
904
- else:
905
- W = p.view(-1, head_dim, p.shape[1])
906
- W.mul_(scales.view(-1, 1, 1))
907
-
908
- def parallel(self, names, params, group, lr, weight_decay, momentum,
909
- qk_logits):
910
- """
911
- Perform a parallel optimization step using Muon.
912
- """
913
-
914
- for p in params:
915
- g = p.grad
916
- if g is None:
917
- continue
918
- if g.ndim > 2:
919
- g = g.view(g.size(0), -1)
920
-
921
- # Update g in the local rank
922
- g = self._update_g(
923
- p,
924
- g,
925
- group,
926
- momentum=momentum,
927
- )
928
- p.grad = g
929
-
930
- param_to_state, ordered_params = self.init_state_and_assign_params(
931
- names, params, group, qk_logits)
932
-
933
- assert self.rank is not None
934
-
935
- def enqueue_all2all_gather(start_idx, chunk_size):
936
- target_params = ordered_params[start_idx:start_idx + chunk_size]
937
- if target_params:
938
- alloc_event = _alloc_gathered_grad(target_params,
939
- param_to_state, self.rank,
940
- self.compute_stream)
941
- _all2all_gather(target_params, param_to_state, self.rank,
942
- self.comm_stream, group["none_grad"],
943
- alloc_event)
944
-
945
- def enqueue_computes(start_idx, chunk_size):
946
- for p in ordered_params[start_idx:start_idx + chunk_size]:
947
- state = param_to_state[id(p)]
948
- _compute_u(p, state, group["ns_steps"], self.rank,
949
- self.compute_stream)
950
-
951
- def enqueue_all2all_scatter(start_idx, chunk_size):
952
- target_params = ordered_params[start_idx:start_idx + chunk_size]
953
- if target_params:
954
- alloc_event = _alloc_scattered_u(target_params, param_to_state,
955
- self.rank,
956
- self.compute_stream)
957
- _all2all_scatter(target_params, param_to_state, self.rank,
958
- self.comm_stream, alloc_event)
959
-
960
- def enqueue_update_param(start_idx, chunk_size):
961
- for p in ordered_params[start_idx:start_idx + chunk_size]:
962
- state = param_to_state[id(p)]
963
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
964
- _update_param(p, state, lr, adjusted_lr, weight_decay,
965
- self.rank, self.compute_stream)
966
-
967
- if self.chunk_size == -1:
968
- shard_ranks = dist.get_world_size(param_to_state[id(
969
- params[0])].process_group)
970
- chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
971
- elif self.chunk_size > 0:
972
- chunk_size = self.chunk_size
973
- else:
974
- raise ValueError("chunk_size must be -1 or a positive integer.")
975
-
976
- # Wait for the gradient update to finish
977
- self.comm_stream.wait_stream(torch.cuda.current_stream())
978
-
979
- warmup_step = self.warmup_step
980
- for i in range(0, warmup_step):
981
- enqueue_all2all_gather(i * chunk_size, chunk_size)
982
- enqueue_computes(i * chunk_size, chunk_size)
983
-
984
- for i in range(0, len(params) + chunk_size - 1, chunk_size):
985
- enqueue_all2all_scatter(i, chunk_size)
986
- enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
987
- enqueue_update_param(i, chunk_size)
988
- enqueue_computes(i + warmup_step * chunk_size, chunk_size)
989
-
990
- # Wait for the last update_param to finish
991
- torch.cuda.current_stream().wait_stream(self.compute_stream)
992
-
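# Shape of the schedule above (illustrative trace, assuming warmup_step=2 and
# chunk_size=4): chunks 0 and 1 are gathered and computed up front, then each
# iteration scatters/updates chunk i while the gather/compute for chunk i+2
# runs ahead, so all2all traffic on comm_stream overlaps Newton-Schulz work
# on compute_stream:
#   warmup: gather(0) compute(0) gather(1) compute(1)
#   i=0   : scatter(0) gather(2) update(0) compute(2)
#   i=1   : scatter(1) gather(3) update(1) compute(3)
#   ...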
993
- @staticmethod
994
- def _fused_adamw(
995
- params: list[torch.Tensor],
996
- grads: list[torch.Tensor],
997
- exp_avgs: list[torch.Tensor],
998
- exp_avg_sqs: list[torch.Tensor],
999
- max_exp_avg_sqs: list[torch.Tensor],
1000
- state_steps: list[torch.Tensor],
1001
- amsgrad: bool,
1002
- beta1: float,
1003
- beta2: float,
1004
- lr: float | torch.Tensor,
1005
- weight_decay: float,
1006
- eps: float,
1007
- maximize: bool,
1008
- ) -> None:
1009
- if not params:
1010
- return
1011
-
1012
- # We only shuffle around the lr when it is a Tensor and on CUDA; otherwise, we prefer
1013
- # treating it as a scalar.
1014
- lr_dict: DeviceDict | None = ({
1015
- lr.device: lr
1016
- } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1017
- None)
1018
- grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1019
- [
1020
- params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
1021
- state_steps
1022
- ] # type: ignore[list-item]
1023
- )
1024
- for (device, _), (
1025
- (
1026
- device_params_,
1027
- device_grads_,
1028
- device_exp_avgs_,
1029
- device_exp_avg_sqs_,
1030
- device_max_exp_avg_sqs,
1031
- device_state_steps_,
1032
- ),
1033
- _,
1034
- ) in grouped_tensors.items():
1035
- device_params = cast(list[torch.Tensor], device_params_)
1036
- device_grads = cast(list[torch.Tensor], device_grads_)
1037
- device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
1038
- device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
1039
- device_state_steps = cast(list[torch.Tensor], device_state_steps_)
1040
-
1041
- if lr_dict is not None and device not in lr_dict:
1042
- lr_dict[device] = lr.to(
1043
- device=device,
1044
- non_blocking=True) # type: ignore[union-attr]
1045
- lr = lr_dict[device]
1046
- torch._foreach_add_(device_state_steps, 1)
1047
- func = torch._fused_adamw_
1048
- func(
1049
- device_params,
1050
- device_grads,
1051
- device_exp_avgs,
1052
- device_exp_avg_sqs,
1053
- device_max_exp_avg_sqs, # type: ignore[arg-type]
1054
- device_state_steps,
1055
- amsgrad=amsgrad,
1056
- lr=lr, # type: ignore[arg-type]
1057
- beta1=beta1,
1058
- beta2=beta2,
1059
- weight_decay=weight_decay,
1060
- eps=eps,
1061
- maximize=maximize,
1062
- )
1063
-
1064
- def _step_muon(self, group, qk_logits=None):
1065
- params = group["params"]
1066
- lr = group["lr"]
1067
- weight_decay = group["weight_decay"]
1068
- momentum = group["momentum"]
1069
- names = group["names"]
1070
-
1071
- param_dtensors = []
1072
- name_dtensors = []
1073
-
1074
- param_tensors = []
1075
- name_tensors = []
1076
-
1077
- param_dtensors_small = []
1078
- name_dtensors_small = []
1079
-
1080
- if self.use_distributed_muon:
1081
- self.distributed_muon(names=names,
1082
- params=params,
1083
- group=group,
1084
- lr=lr,
1085
- weight_decay=weight_decay,
1086
- momentum=momentum,
1087
- qk_logits=qk_logits)
1088
- return
1089
-
1090
- # For simplicity, we use distributed Muon for small parameters
1091
- # whose number of elements is below a threshold.
1092
- for n, p in zip(names, params):
1093
- if p is None or p.grad is None:
1094
- continue
1095
- if isinstance(p.data, DTensor):
1096
- if all(
1097
- isinstance(placement, Replicate)
1098
- for placement in p.placements):
1099
- param_tensors.append(p)
1100
- name_tensors.append(n)
1101
- elif p.data.numel() <= self.small_param_numel_threshold:
1102
- param_dtensors_small.append(p)
1103
- name_dtensors_small.append(n)
1104
- else:
1105
- param_dtensors.append(p)
1106
- name_dtensors.append(n)
1107
- elif isinstance(p.data, torch.Tensor):
1108
- param_tensors.append(p)
1109
- name_tensors.append(n)
1110
- else:
1111
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1112
-
1113
- logger.debug(
1114
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
1115
- f"{len(param_dtensors_small)} Small DTensors")
1116
-
1117
- def group_dtensors(dtensors, names):
1118
- # To support different placements, we group parameters by placements
1119
- # and run parallel Muon on each group.
1120
-
1121
- placement_to_params = defaultdict(lambda: ([], []))
1122
- # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1123
-
1124
- assert len(dtensors) == len(names)
1125
- for p, n in zip(dtensors, names):
1126
- placement_to_params[tuple([p.placements,
1127
- p.device_mesh])][0].append(n)
1128
- placement_to_params[tuple([p.placements,
1129
- p.device_mesh])][1].append(p)
1130
- return placement_to_params
1131
-
1132
- if len(param_dtensors_small) > 0:
1133
- if not dist.is_initialized():
1134
- raise RuntimeError(
1135
- "Parallel Muon requires torch.distributed to be initialized."
1136
- )
1137
-
1138
- self.distributed_muon(
1139
- params=param_dtensors_small,
1140
- names=name_dtensors_small,
1141
- group=group,
1142
- lr=lr,
1143
- weight_decay=weight_decay,
1144
- momentum=momentum,
1145
- qk_logits=qk_logits,
1146
- )
1147
-
1148
- if len(param_dtensors) > 0:
1149
- if not dist.is_initialized():
1150
- raise RuntimeError(
1151
- "Parallel Muon requires torch.distributed to be initialized."
1152
- )
1153
-
1154
- dtensor_group = group_dtensors(param_dtensors, name_dtensors)
1155
- for _, (names, params) in dtensor_group.items():
1156
- self.parallel(
1157
- names,
1158
- params,
1159
- group,
1160
- lr=lr,
1161
- weight_decay=weight_decay,
1162
- momentum=momentum,
1163
- qk_logits=qk_logits,
1164
- )
1165
-
1166
- if len(param_tensors) > 0:
1167
- self.base(
1168
- name_tensors,
1169
- param_tensors,
1170
- group,
1171
- lr=lr,
1172
- weight_decay=weight_decay,
1173
- momentum=momentum,
1174
- qk_logits=qk_logits,
1175
- )
1176
-
1177
- def _step_adamw_params(self, params, group):
1178
- params_with_grads = []
1179
- grads = []
1180
- moment1 = []
1181
- moment2 = []
1182
- max_exp_avg_sqs = []
1183
- state_steps = []
1184
- lr = group["lr"]
1185
- beta1, beta2 = group["adamw_betas"]
1186
- eps = group["adamw_eps"]
1187
- weight_decay = group["weight_decay"]
1188
-
1189
- for p in params:
1190
- g = p.grad
1191
- if g is None:
1192
- continue
1193
- state = self.state[p]
1194
- params_with_grads.append(p)
1195
- grads.append(g)
1196
- if "step" not in state:
1197
- state["step"] = (torch.zeros((),
1198
- dtype=torch.float32,
1199
- device=p.device))
1200
- state["moment1"] = torch.zeros_like(g)
1201
- state["moment2"] = torch.zeros_like(g)
1202
- moment1.append(state["moment1"])
1203
- moment2.append(state["moment2"])
1204
- if not isinstance(state["step"], torch.Tensor):
1205
- step_tensor = torch.tensor(state["step"],
1206
- dtype=torch.float32,
1207
- device=p.device)
1208
- else:
1209
- step_tensor = state["step"]
1210
- state_steps.append(step_tensor)
1211
-
1212
- self._fused_adamw(
1213
- params_with_grads,
1214
- grads,
1215
- moment1,
1216
- moment2,
1217
- max_exp_avg_sqs,
1218
- state_steps,
1219
- amsgrad=False,
1220
- beta1=beta1,
1221
- beta2=beta2,
1222
- lr=lr,
1223
- weight_decay=weight_decay,
1224
- eps=eps,
1225
- maximize=False,
1226
- )
1227
-
1228
- def _step_adamw(self, group):
1229
- params = group["params"]
1230
-
1231
- # group params by their type and placement
1232
- placement_to_params: dict[tuple[Placement | type,
1233
- DeviceMesh | None]] = defaultdict(list)
1234
- for p in params:
1235
- match p:
1236
- case DTensor():
1237
- placement_to_params[tuple([p.placements,
1238
- p.device_mesh])].append(p)
1239
- case torch.Tensor():
1240
- placement_to_params[tuple([torch.Tensor, None])].append(p)
1241
-
1242
- for params in placement_to_params.values():
1243
- self._step_adamw_params(params, group)
1244
-
1245
- @torch.no_grad
1246
- def step(self, closure=None, qk_logits=None):
1247
- """Perform a single optimization step.
1248
-
1249
- Args:
1250
- closure (Callable, optional): A closure that reevaluates the model
1251
- and returns the loss.
1252
- qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
1253
- to 1D tensors of shape (num_heads,), representing the maximum
1254
- QK logits across all tokens, computed as
1255
- (1 / sqrt(head_dim)) * (Q @ K^T).
1256
- """
1257
- loss = None
1258
- if closure is not None:
1259
- with torch.enable_grad():
1260
- loss = closure()
1261
-
1262
- for group in self.param_groups:
1263
- if group["use_muon"]:
1264
- self._step_muon(group, qk_logits=qk_logits)
1265
- else:
1266
- self._step_adamw(group)
1267
-
1268
- return loss
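A minimal end-to-end sketch of the constructor contract above, expanding the snippet embedded in its error message (assumptions: `get_kernel` is the Hugging Face `kernels` loader, and `model`/`inputs` are placeholders; the helper names are taken verbatim from the source):

    from kernels import get_kernel

    optimizer = get_kernel("motif-technologies/optimizer")
    # Every param group must carry an explicit "use_muon" key; the helper
    # below builds groups where 2D weights use Muon and the rest use AdamW.
    params = optimizer.muon.get_default_muon_param_groups(model)
    optim = optimizer.Muon(params, lr=1e-3, momentum=0.95, weight_decay=0.1)

    loss = model(inputs).loss  # placeholder forward pass
    loss.backward()
    optim.step()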
build/torch210-cxx11-rocm71-x86_64-linux/optimizer/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- import ctypes
2
- import sys
3
-
4
- import importlib.util
5
- from pathlib import Path
6
- from types import ModuleType
7
-
8
- def _import_from_path(file_path: Path) -> ModuleType:
9
- # We cannot use the module name as-is: after adding it to `sys.modules`,
10
- # it would also be used for other imports. So, we make a module name that
11
- # depends on the path for it to be unique using the hex-encoded hash of
12
- # the path.
13
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
- module_name = path_hash
15
- spec = importlib.util.spec_from_file_location(module_name, file_path)
16
- if spec is None:
17
- raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
- module = importlib.util.module_from_spec(spec)
19
- if module is None:
20
- raise ImportError(f"Cannot load module {module_name} from spec")
21
- sys.modules[module_name] = module
22
- spec.loader.exec_module(module) # type: ignore
23
- return module
24
-
25
-
26
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
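The shim above exists so each build directory can re-export the shared package `__init__.py` without module-name collisions; a minimal sketch of the mechanism (paths below are placeholders):

    from pathlib import Path

    mod = _import_from_path(Path("/opt/builds/variant-a") / "__init__.py")
    # The module is registered in sys.modules under a hex hash of its absolute
    # path, so loading the same file from two build variants never clashes.
    print(mod.__name__)  # a hash string such as "1a2b3c..." rather than "__init__"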
build/{torch210-cxx11-cu126-x86_64-linux → torch26-cxx11-cu118-x86_64-linux/optimizer}/__init__.py RENAMED
File without changes
build/{torch210-cxx11-rocm70-x86_64-linux → torch26-cxx11-cu118-x86_64-linux/optimizer}/_ops.py RENAMED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_06a260a_dirty
3
- ops = torch.ops._optimizer_06a260a_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_06a260a_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_036642a_dirty
3
+ ops = torch.ops._optimizer_036642a_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_036642a_dirty::{op_name}"
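The helper above only namespaces op names under the build-specific extension module; for example (with a hypothetical op name):

    add_op_namespace_prefix("matmul_transpose")
    # -> "_optimizer_036642a_dirty::matmul_transpose", which resolves through
    # torch.ops._optimizer_036642a_dirty.matmul_transpose once the .so loads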
build/{torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so → torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5384da54f22f488e0646e09915b821b3235cb404b163a570aa377967f853e3cf
3
- size 1940944
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c77e5647b6056bfaee25050cca7948c40859db0a88fa4fcf40b67a85c947d8c
3
+ size 1787272
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,455 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor
7
+
8
+
9
+ # This code snippet is a modified version adapted from the following GitHub repositories:
10
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
11
+ @torch.no_grad()
12
+ def _zeropower_via_newtonschulz5(G, steps):
13
+ """
14
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
15
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
16
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
17
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
18
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
19
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
20
+ performance at all relative to UV^T, where USV^T = G is the SVD.
21
+ """
22
+ assert len(G.shape) == 2
23
+ a, b, c = (3.4445, -4.7750, 2.0315)
24
+ X = G # no manual typecast
25
+ if G.size(0) > G.size(1):
26
+ X = X.T
27
+ # Ensure spectral norm is at most 1
28
+ X = X / (X.norm() + 1e-7)
29
+ X = X.bfloat16()
30
+ # Perform the NS iterations
31
+ for _ in range(steps):
32
+ A = X @ X.T
33
+ # B = (
34
+ # b * A + c * A @ A
35
+ # )
36
+ B = torch.addmm(A, A, A, alpha=c, beta=b)
37
+ # X = a * X + B @ X
38
+ X = torch.addmm(X, B, X, alpha=1.0, beta=a)
39
+
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ return X.to(G.dtype)
43
+
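# A quick property check for the iteration above (illustrative sketch): the
# output is only approximately orthogonal, with singular values roughly in
# (0.5, 1.5) as the docstring notes, rather than exactly 1:
#   G = torch.randn(512, 1024, device="cuda")
#   Q = _zeropower_via_newtonschulz5(G, steps=5)
#   S = torch.linalg.svdvals(Q.float())
#   assert 0.3 < S.min() and S.max() < 1.7  # loose bounds for bf16 noise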
44
+
45
+ @dataclass
46
+ class _muon_state:
47
+ # TODO: use Optional
48
+ worker_rank: int | None = None
49
+ gathered_grad: torch.Tensor | None = None
50
+ computed_u: torch.Tensor | None = None
51
+ gather_event: torch.cuda.Event | None = None
52
+ compute_event: torch.cuda.Event | None = None
53
+
54
+
55
+ @torch.no_grad()
56
+ def _gather(p, state, rank, comm_stream, none_grad):
57
+ g = p.grad
58
+ mesh = g.device_mesh
59
+
60
+ if rank == state.worker_rank:
61
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
62
+ else:
63
+ gather_list = None
64
+
65
+ with torch.cuda.stream(comm_stream):
66
+ torch.distributed.gather(
67
+ g.to_local(),
68
+ dst=state.worker_rank,
69
+ gather_list=gather_list,
70
+ group=mesh.get_group(),
71
+ )
72
+ if rank == state.worker_rank:
73
+ if state.gathered_grad is not None:
74
+ raise RuntimeError(
75
+ "Gather event already exists, which should not happen."
76
+ )
77
+ state.gathered_grad = torch.cat(gather_list, dim=0)
78
+ state.gather_event = torch.cuda.Event()
79
+ state.gather_event.record()
80
+ else:
81
+ state.gathered_grad = None
82
+ state.gather_event = None
83
+ if none_grad:
84
+ p.grad = None
85
+
86
+
87
+ @torch.no_grad()
88
+ def _compute_u(state, steps, rank, compute_stream):
89
+ with torch.cuda.stream(compute_stream):
90
+ if rank == state.worker_rank:
91
+ if state.gather_event is None:
92
+ raise RuntimeError("Gather event must be set before compute.")
93
+ compute_stream.wait_event(state.gather_event)
94
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
95
+ state.computed_u = u
96
+ state.compute_event = torch.cuda.Event()
97
+ state.compute_event.record()
98
+ # Clear the gathered gradient to free memory
99
+ state.gathered_grad = None
100
+ else:
101
+ state.computed_u = None
102
+ state.compute_event = None
103
+
104
+
105
+ @torch.no_grad()
106
+ def _scatter(p, state, lr, wd, rank, comm_stream):
107
+ u = state.computed_u
108
+ mesh = p.device_mesh
109
+
110
+ with torch.cuda.stream(comm_stream):
111
+ if rank == state.worker_rank:
112
+ if state.compute_event is None:
113
+ raise RuntimeError("Compute event must be set before scatter.")
114
+ comm_stream.wait_event(state.compute_event)
115
+ scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
+ else:
117
+ scatter_list = None
118
+
119
+ u = torch.empty_like(p.to_local())
120
+ torch.distributed.scatter(
121
+ u,
122
+ scatter_list=scatter_list,
123
+ src=state.worker_rank,
124
+ group=mesh.get_group(),
125
+ )
126
+ if rank == state.worker_rank:
127
+ # Clear u to free memory
128
+ state.computed_u = None
129
+ u = DTensor.from_local(
130
+ u,
131
+ placements=p.placements,
132
+ device_mesh=mesh,
133
+ )
134
+ p.data.mul_(1 - lr * wd)
135
+ p.data.add_(u, alpha=-lr)
136
+
137
+
138
+ class Muon(torch.optim.Optimizer):
139
+ """
140
+ Muon - MomentUm Orthogonalized by Newton-schulz
141
+
142
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
143
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
144
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
145
+ the advantage that it can be stably run in bfloat16 on the GPU.
146
+
147
+ Some warnings:
148
+ - We believe this optimizer is unlikely to work well for training with small batch size.
149
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
150
+
151
+ Arguments:
152
+ muon_params: The parameters to be optimized by Muon.
153
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
154
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
155
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
156
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
157
+ adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
158
+ {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
159
+ adamw_lr: The learning rate for the internal AdamW.
160
+ adamw_betas: The betas for the internal AdamW.
161
+ adamw_eps: The epsilon for the internal AdamW.
162
+ adamw_wd: The weight decay for the internal AdamW.
163
+ """
164
+
165
+ def __init__(
166
+ self,
167
+ model,
168
+ is_muon_func,
169
+ lr=1e-3,
170
+ momentum=0.95,
171
+ nesterov=True,
172
+ ns_steps=5,
173
+ adamw_wd=0.1,
174
+ adamw_betas=(0.9, 0.95),
175
+ adamw_eps=1e-8,
176
+ none_grad=True,
177
+ debug=False,
178
+ ):
179
+ defaults = dict(
180
+ lr=lr,
181
+ wd=adamw_wd,
182
+ momentum=momentum,
183
+ nesterov=nesterov,
184
+ ns_steps=ns_steps,
185
+ adamw_betas=adamw_betas,
186
+ adamw_eps=adamw_eps,
187
+ none_grad=none_grad,
188
+ )
189
+
190
+ super().__init__(model.parameters(), defaults)
191
+ self.is_muon_func = is_muon_func
192
+ self.model = model
193
+
194
+ if not dist.is_initialized():
195
+ raise RuntimeError(
196
+ "Muon optimizer requires distributed training to be initialized."
197
+ )
198
+
199
+ self.rank = dist.get_rank()
200
+
201
+ self.comm_stream = torch.cuda.Stream()
202
+ self.compute_stream = torch.cuda.Stream()
203
+ self.debug = debug
204
+
205
+ def __setstate__(self, state):
206
+ # Sort parameters into those for which we will use Muon, and those for which we will not
207
+ super().__setstate__(state)
208
+ for name, p in self.model.named_parameters():
209
+ if self.is_muon_func(p, name):
210
+ # Use Muon for parameters selected by is_muon_func; these must be 2D and not look like an embedding or head layer
211
+ assert p.ndim == 2, p.ndim
212
+ self.state[p]["use_muon"] = True
213
+ self.state[p]["orig_shape"] = p.shape
214
+ else:
215
+ # Do not use Muon for parameters in adamw_params
216
+ self.state[p]["use_muon"] = False
217
+
218
+ def _calc_flops(self, G, steps):
219
+ assert len(G.shape) == 2
220
+ M, N = G.shape
221
+ if M > N:
222
+ M, N = N, M
223
+
224
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
225
+
226
+ def adjust_lr_for_muon(self, lr, param_shape):
227
+ A, B = param_shape[:2]
228
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
229
+ # as described in the paper
230
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
231
+ adjusted_lr = lr * adjusted_ratio
232
+ return adjusted_lr
233
+
234
+ def init_state_and_assign_params(self, params, group):
235
+ param_to_state = {}
236
+ param_to_flops = {}
237
+
238
+ total_flops = 0
239
+ for p in params:
240
+ g = p.grad
241
+ if g is None:
242
+ continue
243
+ assert g.ndim == 2, "Muon only supports 2D parameters."
244
+
245
+ flops = self._calc_flops(g, group["ns_steps"])
246
+ param_to_flops[id(p)] = flops
247
+ total_flops += flops
248
+
249
+ if self.debug:
250
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
251
+
252
+ ordered_params = sorted(
253
+ params, key=lambda p: param_to_flops[id(p)], reverse=True
254
+ )
255
+
256
+ round_robin = 0
257
+ mesh = None
258
+ for p in ordered_params:
259
+ if mesh is None:
260
+ mesh = p.device_mesh
261
+ if mesh.ndim != 1:
262
+ raise NotImplementedError(
263
+ "Muon requires a 1D mesh for distributed training yet."
264
+ )
265
+ elif mesh != p.device_mesh:
266
+ raise ValueError("All parameters must be on the same mesh.")
267
+
268
+ param_to_state[id(p)] = _muon_state()
269
+ param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
270
+
271
+ round_robin = (round_robin + 1) % mesh.mesh.numel()
272
+
273
+ return param_to_state, ordered_params
274
+
275
+ def base(self, params, group, lr, wd, momentum):
276
+ # generate weight updates in distributed fashion
277
+ for p in params:
278
+ g = p.grad
279
+ if g is None:
280
+ continue
281
+ if g.ndim > 2:
282
+ g = g.view(g.size(0), -1)
283
+ assert g is not None
284
+
285
+ # calc update
286
+ state = self.state[p]
287
+ if "momentum_buffer" not in state:
288
+ state["momentum_buffer"] = torch.zeros_like(g)
289
+ buf = state["momentum_buffer"]
290
+ buf.mul_(momentum).add_(g)
291
+ if group["nesterov"]:
292
+ g = g.add(buf, alpha=momentum)
293
+ else:
294
+ g = buf
295
+
296
+ u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
297
+
298
+ # scale update
299
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
300
+
301
+ # apply weight decay
302
+ p.data.mul_(1 - lr * wd)
303
+
304
+ # apply update
305
+ p.data.add_(u, alpha=-adjusted_lr)
306
+
307
+ def _update_g(self, p, g, group, momentum):
308
+ # calc update
309
+ state = self.state[p]
310
+ if "momentum_buffer" not in state:
311
+ state["momentum_buffer"] = torch.zeros_like(g)
312
+ buf = state["momentum_buffer"]
313
+ buf.mul_(momentum).add_(g)
314
+ if group["nesterov"]:
315
+ g = g.add(buf, alpha=momentum)
316
+ else:
317
+ g = buf
318
+ return g
319
+
320
+ def _update_p(self, p, u, lr, wd):
321
+ # scale update
322
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
323
+ # apply weight decay
324
+ p.data.mul_(1 - lr * wd)
325
+ # apply update
326
+ p.data.add_(u, alpha=-adjusted_lr)
327
+
328
+ def parallel(self, params, group, lr, wd, momentum):
329
+ """
330
+ Perform a parallel optimization step using Muon.
331
+ """
332
+
333
+ for p in params:
334
+ g = p.grad
335
+ if g is None:
336
+ continue
337
+ if g.ndim > 2:
338
+ g = g.view(g.size(0), -1)
339
+
340
+ # Update g in the local rank
341
+ g = self._update_g(
342
+ p,
343
+ g,
344
+ group,
345
+ momentum=momentum,
346
+ )
347
+ p.grad = g
348
+
349
+ param_to_state, ordered_params = self.init_state_and_assign_params(
350
+ params, group
351
+ )
352
+
353
+ def enqueue_gathers(start_idx, chunk_size):
354
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
355
+ state = param_to_state[id(p)]
356
+ _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
357
+
358
+ def enqueue_computes(start_idx, chunk_size):
359
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
360
+ state = param_to_state[id(p)]
361
+ _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
362
+
363
+ def enqueue_scatters(start_idx, chunk_size):
364
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
365
+ state = param_to_state[id(p)]
366
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
367
+ _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
368
+
369
+ chunk_size = params[0].device_mesh.mesh.numel()
370
+
371
+ # Wait for the gradient update to finish
372
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
373
+
374
+ enqueue_gathers(0, chunk_size)
375
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
376
+ enqueue_computes(i, chunk_size)
377
+ enqueue_gathers(i + chunk_size, chunk_size)
378
+ enqueue_scatters(i, chunk_size)
379
+
380
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
381
+
382
+ def step(self, closure=None):
383
+ """Perform a single optimization step.
384
+
385
+ Args:
386
+ closure (Callable, optional): A closure that reevaluates the model
387
+ and returns the loss.
388
+ """
389
+ loss = None
390
+ if closure is not None:
391
+ with torch.enable_grad():
392
+ loss = closure()
393
+
394
+ for group in self.param_groups:
395
+ ############################
396
+ # Muon #
397
+ ############################
398
+
399
+ params = [p for p in group["params"] if self.state[p]["use_muon"]]
400
+ lr = group["lr"]
401
+ wd = group["wd"]
402
+ momentum = group["momentum"]
403
+
404
+ if isinstance(params[0].data, DTensor):
405
+ self.parallel(
406
+ params,
407
+ group,
408
+ lr=lr,
409
+ wd=wd,
410
+ momentum=momentum,
411
+ )
412
+ else:
413
+ self.base(
414
+ params,
415
+ group,
416
+ lr=lr,
417
+ wd=wd,
418
+ momentum=momentum,
419
+ )
420
+
421
+ ############################
422
+ # AdamW backup #
423
+ ############################
424
+
425
+ params = [p for p in group["params"] if not self.state[p]["use_muon"]]
426
+ lr = group["lr"]
427
+ beta1, beta2 = group["adamw_betas"]
428
+ eps = group["adamw_eps"]
429
+ weight_decay = group["wd"]
430
+
431
+ for p in params:
432
+ g = p.grad
433
+ if g is None:
434
+ continue
435
+ state = self.state[p]
436
+ if "step" not in state:
437
+ state["step"] = 0
438
+ state["moment1"] = torch.zeros_like(g)
439
+ state["moment2"] = torch.zeros_like(g)
440
+ state["step"] += 1
441
+ step = state["step"]
442
+ buf1 = state["moment1"]
443
+ buf2 = state["moment2"]
444
+ buf1.lerp_(g, 1 - beta1)
445
+ buf2.lerp_(g.square(), 1 - beta2)
446
+
447
+ g = buf1 / (eps + buf2.sqrt())
448
+
449
+ bias_correction1 = 1 - beta1**step
450
+ bias_correction2 = 1 - beta2**step
451
+ scale = bias_correction1 / bias_correction2**0.5
452
+ p.data.mul_(1 - lr * weight_decay)
453
+ p.data.add_(g, alpha=-lr / scale)
454
+
455
+ return loss
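A minimal usage sketch for this variant of the optimizer (assumptions: the predicate below is one plausible `is_muon_func`, and `model`/`inputs` are placeholders; the constructor signature comes from the source, which requires torch.distributed to be initialized and sends non-Muon parameters through the built-in AdamW path):

    import torch.distributed as dist

    dist.init_process_group("nccl")

    def is_muon_func(param, name):
        # Route 2D hidden weights to Muon; embeddings and the head stay on AdamW.
        return param.ndim == 2 and "embed" not in name and "lm_head" not in name

    opt = Muon(model, is_muon_func, lr=1e-3, momentum=0.95, adamw_wd=0.1)
    model(inputs).loss.backward()
    opt.step()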
build/{torch210-cxx11-cu128-x86_64-linux → torch26-cxx11-cu124-x86_64-linux/optimizer}/__init__.py RENAMED
File without changes
build/{torch210-cxx11-cu126-x86_64-linux → torch26-cxx11-cu124-x86_64-linux/optimizer}/_ops.py RENAMED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_06a260a_dirty
3
- ops = torch.ops._optimizer_06a260a_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_06a260a_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_036642a_dirty
3
+ ops = torch.ops._optimizer_036642a_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_036642a_dirty::{op_name}"
build/{torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so → torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:976df6a1ec3ec4c462dea18477b56dfb75bcff76f504d55b592ce417931597c0
3
- size 2004144
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94ea66089cc8d9eda72b017733a9e05e4fee5a2f04c50658b690d2c19f0d3068
3
+ size 1824224
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,455 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor
7
+
8
+
9
+ # This code snippet is a modified version adapted from the following GitHub repositories:
10
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
11
+ @torch.no_grad()
12
+ def _zeropower_via_newtonschulz5(G, steps):
13
+ """
14
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
15
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
16
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
17
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
18
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
19
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
20
+ performance at all relative to UV^T, where USV^T = G is the SVD.
21
+ """
22
+ assert len(G.shape) == 2
23
+ a, b, c = (3.4445, -4.7750, 2.0315)
24
+ X = G # no manual typecast
25
+ if G.size(0) > G.size(1):
26
+ X = X.T
27
+ # Ensure spectral norm is at most 1
28
+ X = X / (X.norm() + 1e-7)
29
+ X = X.bfloat16()
30
+ # Perform the NS iterations
31
+ for _ in range(steps):
32
+ A = X @ X.T
33
+ # B = (
34
+ # b * A + c * A @ A
35
+ # )
36
+ B = torch.addmm(A, A, A, alpha=c, beta=b)
37
+ # X = a * X + B @ X
38
+ X = torch.addmm(X, B, X, alpha=1.0, beta=a)
39
+
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ return X.to(G.dtype)
43
+
44
+
45
+ @dataclass
46
+ class _muon_state:
47
+ # TODO: use Optional
48
+ worker_rank: int | None = None
49
+ gathered_grad: torch.Tensor | None = None
50
+ computed_u: torch.Tensor | None = None
51
+ gather_event: torch.cuda.Event | None = None
52
+ compute_event: torch.cuda.Event | None = None
53
+
54
+
55
+ @torch.no_grad()
56
+ def _gather(p, state, rank, comm_stream, none_grad):
57
+ g = p.grad
58
+ mesh = g.device_mesh
59
+
60
+ if rank == state.worker_rank:
61
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
62
+ else:
63
+ gather_list = None
64
+
65
+ with torch.cuda.stream(comm_stream):
66
+ torch.distributed.gather(
67
+ g.to_local(),
68
+ dst=state.worker_rank,
69
+ gather_list=gather_list,
70
+ group=mesh.get_group(),
71
+ )
72
+ if rank == state.worker_rank:
73
+ if state.gathered_grad is not None:
74
+ raise RuntimeError(
75
+ "Gather event already exists, which should not happen."
76
+ )
77
+ state.gathered_grad = torch.cat(gather_list, dim=0)
78
+ state.gather_event = torch.cuda.Event()
79
+ state.gather_event.record()
80
+ else:
81
+ state.gathered_grad = None
82
+ state.gather_event = None
83
+ if none_grad:
84
+ p.grad = None
85
+
86
+
87
+ @torch.no_grad()
88
+ def _compute_u(state, steps, rank, compute_stream):
89
+ with torch.cuda.stream(compute_stream):
90
+ if rank == state.worker_rank:
91
+ if state.gather_event is None:
92
+ raise RuntimeError("Gather event must be set before compute.")
93
+ compute_stream.wait_event(state.gather_event)
94
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
95
+ state.computed_u = u
96
+ state.compute_event = torch.cuda.Event()
97
+ state.compute_event.record()
98
+ # Clear the gathered gradient to free memory
99
+ state.gathered_grad = None
100
+ else:
101
+ state.computed_u = None
102
+ state.compute_event = None
103
+
104
+
105
+ @torch.no_grad()
106
+ def _scatter(p, state, lr, wd, rank, comm_stream):
107
+ u = state.computed_u
108
+ mesh = p.device_mesh
109
+
110
+ with torch.cuda.stream(comm_stream):
111
+ if rank == state.worker_rank:
112
+ if state.compute_event is None:
113
+ raise RuntimeError("Compute event must be set before scatter.")
114
+ comm_stream.wait_event(state.compute_event)
115
+ scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
+ else:
117
+ scatter_list = None
118
+
119
+ u = torch.empty_like(p.to_local())
120
+ torch.distributed.scatter(
121
+ u,
122
+ scatter_list=scatter_list,
123
+ src=state.worker_rank,
124
+ group=mesh.get_group(),
125
+ )
126
+ if rank == state.worker_rank:
127
+ # Clear u to free memory
128
+ state.computed_u = None
129
+ u = DTensor.from_local(
130
+ u,
131
+ placements=p.placements,
132
+ device_mesh=mesh,
133
+ )
134
+ p.data.mul_(1 - lr * wd)
135
+ p.data.add_(u, alpha=-lr)
136
+
137
+
138
+ class Muon(torch.optim.Optimizer):
139
+ """
140
+ Muon - MomentUm Orthogonalized by Newton-schulz
141
+
142
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
143
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
144
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
145
+ the advantage that it can be stably run in bfloat16 on the GPU.
146
+
147
+ Some warnings:
148
+ - We believe this optimizer is unlikely to work well for training with small batch size.
149
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
150
+
151
+ Arguments:
152
+ muon_params: The parameters to be optimized by Muon.
153
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
154
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
155
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
156
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
157
+ adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
158
+ {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
159
+ adamw_lr: The learning rate for the internal AdamW.
160
+ adamw_betas: The betas for the internal AdamW.
161
+ adamw_eps: The epsilon for the internal AdamW.
162
+ adamw_wd: The weight decay for the internal AdamW.
163
+ """
164
+
165
+ def __init__(
166
+ self,
167
+ model,
168
+ is_muon_func,
169
+ lr=1e-3,
170
+ momentum=0.95,
171
+ nesterov=True,
172
+ ns_steps=5,
173
+ adamw_wd=0.1,
174
+ adamw_betas=(0.9, 0.95),
175
+ adamw_eps=1e-8,
176
+ none_grad=True,
177
+ debug=False,
178
+ ):
179
+ defaults = dict(
180
+ lr=lr,
181
+ wd=adamw_wd,
182
+ momentum=momentum,
183
+ nesterov=nesterov,
184
+ ns_steps=ns_steps,
185
+ adamw_betas=adamw_betas,
186
+ adamw_eps=adamw_eps,
187
+ none_grad=none_grad,
188
+ )
189
+
190
+ super().__init__(model.parameters(), defaults)
191
+ self.is_muon_func = is_muon_func
192
+ self.model = model
193
+
194
+ if not dist.is_initialized():
195
+ raise RuntimeError(
196
+ "Muon optimizer requires distributed training to be initialized."
197
+ )
198
+
199
+ self.rank = dist.get_rank()
200
+
201
+ self.comm_stream = torch.cuda.Stream()
202
+ self.compute_stream = torch.cuda.Stream()
203
+ self.debug = debug
204
+
205
+ def __setstate__(self, state):
206
+ # Sort parameters into those for which we will use Muon, and those for which we will not
207
+ super().__setstate__(state)
208
+ for name, p in self.model.named_parameters():
209
+ if self.is_muon_func(p, name):
210
+ # Use Muon for parameters selected by is_muon_func; these must be 2D and not look like an embedding or head layer
211
+ assert p.ndim == 2, p.ndim
212
+ self.state[p]["use_muon"] = True
213
+ self.state[p]["orig_shape"] = p.shape
214
+ else:
215
+ # Do not use Muon for parameters in adamw_params
216
+ self.state[p]["use_muon"] = False
217
+
218
+ def _calc_flops(self, G, steps):
219
+ assert len(G.shape) == 2
220
+ M, N = G.shape
221
+ if M > N:
222
+ M, N = N, M
223
+
224
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
225
+
226
+ def adjust_lr_for_muon(self, lr, param_shape):
227
+ A, B = param_shape[:2]
228
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
229
+ # as described in the paper
230
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
231
+ adjusted_lr = lr * adjusted_ratio
232
+ return adjusted_lr
233
+
234
+ def init_state_and_assign_params(self, params, group):
235
+ param_to_state = {}
236
+ param_to_flops = {}
237
+
238
+ total_flops = 0
239
+ for p in params:
240
+ g = p.grad
241
+ if g is None:
242
+ continue
243
+ assert g.ndim == 2, "Muon only supports 2D parameters."
244
+
245
+ flops = self._calc_flops(g, group["ns_steps"])
246
+ param_to_flops[id(p)] = flops
247
+ total_flops += flops
248
+
249
+ if self.debug:
250
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
251
+
252
+ ordered_params = sorted(
253
+ params, key=lambda p: param_to_flops[id(p)], reverse=True
254
+ )
255
+
256
+ round_robin = 0
257
+ mesh = None
258
+ for p in ordered_params:
259
+ if mesh is None:
260
+ mesh = p.device_mesh
261
+ if mesh.ndim != 1:
262
+ raise NotImplementedError(
263
+ "Muon requires a 1D mesh for distributed training yet."
264
+ )
265
+ elif mesh != p.device_mesh:
266
+ raise ValueError("All parameters must be on the same mesh.")
267
+
268
+ param_to_state[id(p)] = _muon_state()
269
+ param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
270
+
271
+ round_robin = (round_robin + 1) % mesh.mesh.numel()
272
+
273
+ return param_to_state, ordered_params
274
+
275
+ def base(self, params, group, lr, wd, momentum):
276
+ # generate weight updates in distributed fashion
277
+ for p in params:
278
+ g = p.grad
279
+ if g is None:
280
+ continue
281
+ if g.ndim > 2:
282
+ g = g.view(g.size(0), -1)
283
+ assert g is not None
284
+
285
+ # calc update
286
+ state = self.state[p]
287
+ if "momentum_buffer" not in state:
288
+ state["momentum_buffer"] = torch.zeros_like(g)
289
+ buf = state["momentum_buffer"]
290
+ buf.mul_(momentum).add_(g)
291
+ if group["nesterov"]:
292
+ g = g.add(buf, alpha=momentum)
293
+ else:
294
+ g = buf
295
+
296
+ u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
297
+
298
+ # scale update
299
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
300
+
301
+ # apply weight decay
302
+ p.data.mul_(1 - lr * wd)
303
+
304
+ # apply update
305
+ p.data.add_(u, alpha=-adjusted_lr)
306
+
307
+ def _update_g(self, p, g, group, momentum):
308
+ # calc update
309
+ state = self.state[p]
310
+ if "momentum_buffer" not in state:
311
+ state["momentum_buffer"] = torch.zeros_like(g)
312
+ buf = state["momentum_buffer"]
313
+ buf.mul_(momentum).add_(g)
314
+ if group["nesterov"]:
315
+ g = g.add(buf, alpha=momentum)
316
+ else:
317
+ g = buf
318
+ return g
319
+
320
+ def _update_p(self, p, u, lr, wd):
321
+ # scale update
322
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
323
+ # apply weight decay
324
+ p.data.mul_(1 - lr * wd)
325
+ # apply update
326
+ p.data.add_(u, alpha=-adjusted_lr)
327
+
328
+ def parallel(self, params, group, lr, wd, momentum):
329
+ """
330
+ Perform a parallel optimization step using Muon.
331
+ """
332
+
333
+ for p in params:
334
+ g = p.grad
335
+ if g is None:
336
+ continue
337
+ if g.ndim > 2:
338
+ g = g.view(g.size(0), -1)
339
+
340
+ # Update g in the local rank
341
+ g = self._update_g(
342
+ p,
343
+ g,
344
+ group,
345
+ momentum=momentum,
346
+ )
347
+ p.grad = g
348
+
349
+ param_to_state, ordered_params = self.init_state_and_assign_params(
350
+ params, group
351
+ )
352
+
353
+ def enqueue_gathers(start_idx, chunk_size):
354
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
355
+ state = param_to_state[id(p)]
356
+ _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
357
+
358
+ def enqueue_computes(start_idx, chunk_size):
359
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
360
+ state = param_to_state[id(p)]
361
+ _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
362
+
363
+ def enqueue_scatters(start_idx, chunk_size):
364
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
365
+ state = param_to_state[id(p)]
366
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
367
+ _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
368
+
369
+ chunk_size = params[0].device_mesh.mesh.numel()
370
+
371
+ # Wait for the gradient update to finish
372
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
373
+
374
+ enqueue_gathers(0, chunk_size)
375
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
376
+ enqueue_computes(i, chunk_size)
377
+ enqueue_gathers(i + chunk_size, chunk_size)
378
+ enqueue_scatters(i, chunk_size)
379
+
380
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
381
+
382
+ def step(self, closure=None):
383
+ """Perform a single optimization step.
384
+
385
+ Args:
386
+ closure (Callable, optional): A closure that reevaluates the model
387
+ and returns the loss.
388
+ """
389
+ loss = None
390
+ if closure is not None:
391
+ with torch.enable_grad():
392
+ loss = closure()
393
+
394
+ for group in self.param_groups:
395
+ ############################
396
+ # Muon #
397
+ ############################
398
+
399
+ params = [p for p in group["params"] if self.state[p]["use_muon"]]
400
+ lr = group["lr"]
401
+ wd = group["wd"]
402
+ momentum = group["momentum"]
403
+
404
+ if isinstance(params[0].data, DTensor):
405
+ self.parallel(
406
+ params,
407
+ group,
408
+ lr=lr,
409
+ wd=wd,
410
+ momentum=momentum,
411
+ )
412
+ else:
413
+ self.base(
414
+ params,
415
+ group,
416
+ lr=lr,
417
+ wd=wd,
418
+ momentum=momentum,
419
+ )
420
+
421
+ ############################
422
+ # AdamW backup #
423
+ ############################
424
+
425
+ params = [p for p in group["params"] if not self.state[p]["use_muon"]]
426
+ lr = group["lr"]
427
+ beta1, beta2 = group["adamw_betas"]
428
+ eps = group["adamw_eps"]
429
+ weight_decay = group["wd"]
430
+
431
+ for p in params:
432
+ g = p.grad
433
+ if g is None:
434
+ continue
435
+ state = self.state[p]
436
+ if "step" not in state:
437
+ state["step"] = 0
438
+ state["moment1"] = torch.zeros_like(g)
439
+ state["moment2"] = torch.zeros_like(g)
440
+ state["step"] += 1
441
+ step = state["step"]
442
+ buf1 = state["moment1"]
443
+ buf2 = state["moment2"]
444
+ buf1.lerp_(g, 1 - beta1)
445
+ buf2.lerp_(g.square(), 1 - beta2)
446
+
447
+ g = buf1 / (eps + buf2.sqrt())
448
+
449
+ bias_correction1 = 1 - beta1**step
450
+ bias_correction2 = 1 - beta2**step
451
+ scale = bias_correction1 / bias_correction2**0.5
452
+ p.data.mul_(1 - lr * weight_decay)
453
+ p.data.add_(g, alpha=-lr / scale)
454
+
455
+ return loss
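For reference, both copies of `adjust_lr_for_muon` scale the step by the larger matrix dimension; a worked example (assuming lr=1e-3 and a 4096 x 1024 weight):

    import math
    adjusted_lr = 1e-3 * 0.2 * math.sqrt(max(4096, 1024))
    # 0.2 * sqrt(4096) = 0.2 * 64 = 12.8, so adjusted_lr = 0.0128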
build/{torch210-cxx11-cu130-x86_64-linux → torch26-cxx11-cu126-x86_64-linux/optimizer}/__init__.py RENAMED
File without changes
build/{torch210-cxx11-cu130-x86_64-linux → torch26-cxx11-cu126-x86_64-linux/optimizer}/_ops.py RENAMED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_06a260a_dirty
3
- ops = torch.ops._optimizer_06a260a_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_06a260a_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_036642a_dirty
3
+ ops = torch.ops._optimizer_036642a_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_036642a_dirty::{op_name}"
build/{torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so → torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:330aaa6cb247ba3b5df7a13ced6ef7eff3e5d7a72a0b88f674f948aeaed66ee2
-size 2004728
+oid sha256:46e01e1d957ada2d485b30cd60bc3ef7230b8857dffc59f2e7924339761ec577
+size 1824224
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,455 @@
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.distributed as dist
+from torch.distributed._tensor import DTensor
+
+
+# This code snippet is a modified version adapted from the following GitHub repository:
+# https://github.com/KellerJordan/Muon/blob/master/muon.py
+@torch.no_grad()
+def _zeropower_via_newtonschulz5(G, steps):
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert len(G.shape) == 2
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G  # no manual typecast
+    if G.size(0) > G.size(1):
+        X = X.T
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm() + 1e-7)
+    X = X.bfloat16()
+    # Perform the NS iterations
+    for _ in range(steps):
+        A = X @ X.T
+        # B = b * A + c * A @ A
+        B = torch.addmm(A, A, A, alpha=c, beta=b)
+        # X = a * X + B @ X
+        X = torch.addmm(X, B, X, alpha=1.0, beta=a)
+
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X.to(G.dtype)
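+
+# _zeropower_via_newtonschulz5 is the workhorse of Muon: the single-device path
+# (Muon.base) and the sharded path (_compute_u) both call it on a full 2D
+# gradient matrix, always in bfloat16.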
+
+
+@dataclass
+class _muon_state:
+    # TODO: use Optional
+    worker_rank: int | None = None
+    gathered_grad: torch.Tensor | None = None
+    computed_u: torch.Tensor | None = None
+    gather_event: torch.cuda.Event | None = None
+    compute_event: torch.cuda.Event | None = None
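+    # Lifecycle: on the owner rank, _gather fills gathered_grad and records
+    # gather_event, _compute_u converts it into computed_u (freeing the
+    # gathered grad), and _scatter consumes computed_u, so at most one
+    # full-size buffer per parameter is alive at a time.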
+
+
+@torch.no_grad()
+def _gather(p, state, rank, comm_stream, none_grad):
+    g = p.grad
+    mesh = g.device_mesh
+
+    if rank == state.worker_rank:
+        gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
+    else:
+        gather_list = None
+
+    with torch.cuda.stream(comm_stream):
+        torch.distributed.gather(
+            g.to_local(),
+            dst=state.worker_rank,
+            gather_list=gather_list,
+            group=mesh.get_group(),
+        )
+        if rank == state.worker_rank:
+            if state.gathered_grad is not None:
+                raise RuntimeError(
+                    "Gathered gradient already exists, which should not happen."
+                )
+            state.gathered_grad = torch.cat(gather_list, dim=0)
+            state.gather_event = torch.cuda.Event()
+            state.gather_event.record()
+        else:
+            state.gathered_grad = None
+            state.gather_event = None
+        if none_grad:
+            p.grad = None
+
+
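+# _compute_u does real work only on the owner rank; non-owner ranks leave
+# computed_u unset and simply take part in the later scatter.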
+@torch.no_grad()
+def _compute_u(state, steps, rank, compute_stream):
+    with torch.cuda.stream(compute_stream):
+        if rank == state.worker_rank:
+            if state.gather_event is None:
+                raise RuntimeError("Gather event must be set before compute.")
+            compute_stream.wait_event(state.gather_event)
+            u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
+            state.computed_u = u
+            state.compute_event = torch.cuda.Event()
+            state.compute_event.record()
+            # Clear the gathered gradient to free memory
+            state.gathered_grad = None
+        else:
+            state.computed_u = None
+            state.compute_event = None
+
+
+@torch.no_grad()
+def _scatter(p, state, lr, wd, rank, comm_stream):
+    u = state.computed_u
+    mesh = p.device_mesh
+
+    with torch.cuda.stream(comm_stream):
+        if rank == state.worker_rank:
+            if state.compute_event is None:
+                raise RuntimeError("Compute event must be set before scatter.")
+            comm_stream.wait_event(state.compute_event)
+            scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
+        else:
+            scatter_list = None
+
+        u = torch.empty_like(p.to_local())
+        torch.distributed.scatter(
+            u,
+            scatter_list=scatter_list,
+            src=state.worker_rank,
+            group=mesh.get_group(),
+        )
+        if rank == state.worker_rank:
+            # Clear u to free memory
+            state.computed_u = None
+        u = DTensor.from_local(
+            u,
+            placements=p.placements,
+            device_mesh=mesh,
+        )
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
+
+
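+# Every rank receives only its shard of the orthogonalized update via scatter
+# and applies weight decay plus the step to its local piece of p, so the full
+# weights are never broadcast.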
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Some warnings:
+    - We believe this optimizer is unlikely to work well for training with small batch size.
+    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
+
+    Arguments:
+        model: The model to optimize; its parameters are routed to Muon or to the
+            AdamW backup based on `is_muon_func`.
+        is_muon_func: Callable `(param, name) -> bool` selecting the 2D parameters
+            Muon should handle; all other parameters are optimized by AdamW.
+        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_wd: The weight decay for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        none_grad: Whether to set `p.grad = None` once a gradient has been gathered, to free memory.
+        debug: Whether to print the total Newton-Schulz FLOPs per step.
+    """
+
+    def __init__(
+        self,
+        model,
+        is_muon_func,
+        lr=1e-3,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_wd=0.1,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+        none_grad=True,
+        debug=False,
+    ):
+        defaults = dict(
+            lr=lr,
+            wd=adamw_wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+            none_grad=none_grad,
+        )
+
+        super().__init__(model.parameters(), defaults)
+        self.is_muon_func = is_muon_func
+        self.model = model
+
+        if not dist.is_initialized():
+            raise RuntimeError(
+                "Muon optimizer requires distributed training to be initialized."
+            )
+
+        self.rank = dist.get_rank()
+
+        self.comm_stream = torch.cuda.Stream()
+        self.compute_stream = torch.cuda.Stream()
+        self.debug = debug
+
+    def __setstate__(self, state):
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        super().__setstate__(state)
+        for name, p in self.model.named_parameters():
+            if self.is_muon_func(p, name):
+                # Use Muon for every parameter for which is_muon_func returns True; these must be 2D
+                assert p.ndim == 2, p.ndim
+                self.state[p]["use_muon"] = True
+                self.state[p]["orig_shape"] = p.shape
+            else:
+                # Fall back to AdamW for everything else
+                self.state[p]["use_muon"] = False
+
+    def _calc_flops(self, G, steps):
+        assert len(G.shape) == 2
+        M, N = G.shape
+        if M > N:
+            M, N = N, M
+
+        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate and weight decay based on the size of the
+        # parameter matrix, as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
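+        # Example: for a 4096 x 1024 weight, max(A, B) = 4096, so updates are
+        # applied with adjusted_lr = 0.2 * sqrt(4096) * lr = 12.8 * lr.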
+        return adjusted_lr
+
+    def init_state_and_assign_params(self, params, group):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters."
+
+            flops = self._calc_flops(g, group["ns_steps"])
+            param_to_flops[id(p)] = flops
+            total_flops += flops
+
+        if self.debug:
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+
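+        # Order parameters by descending FLOPs and assign owner ranks round-robin
+        # below, so Newton-Schulz work stays roughly balanced across ranks. Every
+        # param is expected to carry a grad here; otherwise it has no entry in
+        # param_to_flops and the sort key below would fail.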
+        ordered_params = sorted(
+            params, key=lambda p: param_to_flops[id(p)], reverse=True
+        )
+
+        round_robin = 0
+        mesh = None
+        for p in ordered_params:
+            if mesh is None:
+                mesh = p.device_mesh
+                if mesh.ndim != 1:
+                    raise NotImplementedError(
+                        "Muon currently requires a 1D device mesh for distributed training."
+                    )
+            elif mesh != p.device_mesh:
+                raise ValueError("All parameters must be on the same mesh.")
+
+            param_to_state[id(p)] = _muon_state()
+            param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
+
+            round_robin = (round_robin + 1) % mesh.mesh.numel()
+
+        return param_to_state, ordered_params
+
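+    # Single-device / non-DTensor path: momentum, orthogonalization, and the
+    # parameter update all happen locally with no collectives.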
+    def base(self, params, group, lr, wd, momentum):
+        # generate weight updates locally, one parameter at a time
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            if g.ndim > 2:
+                g = g.view(g.size(0), -1)
+            assert g is not None
+
+            # calc update
+            state = self.state[p]
+            if "momentum_buffer" not in state:
+                state["momentum_buffer"] = torch.zeros_like(g)
+            buf = state["momentum_buffer"]
+            buf.mul_(momentum).add_(g)
+            if group["nesterov"]:
+                g = g.add(buf, alpha=momentum)
+            else:
+                g = buf
+
+            u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+
+            # scale update
+            adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+
+            # apply weight decay
+            p.data.mul_(1 - lr * wd)
+
+            # apply update
+            p.data.add_(u, alpha=-adjusted_lr)
+
+    def _update_g(self, p, g, group, momentum):
+        # calc update
+        state = self.state[p]
+        if "momentum_buffer" not in state:
+            state["momentum_buffer"] = torch.zeros_like(g)
+        buf = state["momentum_buffer"]
+        buf.mul_(momentum).add_(g)
+        if group["nesterov"]:
+            g = g.add(buf, alpha=momentum)
+        else:
+            g = buf
+        return g
+
+    def _update_p(self, p, u, lr, wd):
+        # scale update
+        adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+        # apply weight decay
+        p.data.mul_(1 - lr * wd)
+        # apply update
+        p.data.add_(u, alpha=-adjusted_lr)
+
+    def parallel(self, params, group, lr, wd, momentum):
+        """
+        Perform a parallel optimization step using Muon.
+        """
+
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            if g.ndim > 2:
+                g = g.view(g.size(0), -1)
+
+            # Update g in the local rank
+            g = self._update_g(
+                p,
+                g,
+                group,
+                momentum=momentum,
+            )
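+            # Write the momentum-processed gradient back so that _gather ships
+            # exactly what Newton-Schulz should consume.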
+            p.grad = g
+
+        param_to_state, ordered_params = self.init_state_and_assign_params(
+            params, group
+        )
+
+        def enqueue_gathers(start_idx, chunk_size):
+            for p in ordered_params[start_idx : start_idx + chunk_size]:
+                state = param_to_state[id(p)]
+                _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
+
+        def enqueue_computes(start_idx, chunk_size):
+            for p in ordered_params[start_idx : start_idx + chunk_size]:
+                state = param_to_state[id(p)]
+                _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
+
+        def enqueue_scatters(start_idx, chunk_size):
+            for p in ordered_params[start_idx : start_idx + chunk_size]:
+                state = param_to_state[id(p)]
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
+
+        chunk_size = params[0].device_mesh.mesh.numel()
+
+        # Wait for the gradient updates on the default stream before communicating
+        self.comm_stream.wait_stream(torch.cuda.current_stream())
+
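+        # Software pipeline: gathers are issued one chunk ahead of computes, so
+        # communication on comm_stream overlaps Newton-Schulz work on
+        # compute_stream; the overshoot in range() is safe because enqueue_*
+        # slice ordered_params and out-of-range slices are empty.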
+        enqueue_gathers(0, chunk_size)
+        for i in range(0, len(params) + chunk_size - 1, chunk_size):
+            enqueue_computes(i, chunk_size)
+            enqueue_gathers(i + chunk_size, chunk_size)
+            enqueue_scatters(i, chunk_size)
+
+        torch.cuda.current_stream().wait_stream(self.comm_stream)
+
+    def step(self, closure=None):
+        """Perform a single optimization step.
+
+        Args:
+            closure (Callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            ############################
+            #           Muon           #
+            ############################
+
+            params = [p for p in group["params"] if self.state[p]["use_muon"]]
+            lr = group["lr"]
+            wd = group["wd"]
+            momentum = group["momentum"]
+
+            if isinstance(params[0].data, DTensor):
+                self.parallel(
+                    params,
+                    group,
+                    lr=lr,
+                    wd=wd,
+                    momentum=momentum,
+                )
+            else:
+                self.base(
+                    params,
+                    group,
+                    lr=lr,
+                    wd=wd,
+                    momentum=momentum,
+                )
+
+            ############################
+            #       AdamW backup       #
+            ############################
+
+            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
+            lr = group["lr"]
+            beta1, beta2 = group["adamw_betas"]
+            eps = group["adamw_eps"]
+            weight_decay = group["wd"]
+
+            for p in params:
+                g = p.grad
+                if g is None:
+                    continue
+                state = self.state[p]
+                if "step" not in state:
+                    state["step"] = 0
+                    state["moment1"] = torch.zeros_like(g)
+                    state["moment2"] = torch.zeros_like(g)
+                state["step"] += 1
+                step = state["step"]
+                buf1 = state["moment1"]
+                buf2 = state["moment2"]
+                buf1.lerp_(g, 1 - beta1)
+                buf2.lerp_(g.square(), 1 - beta2)
+
+                g = buf1 / (eps + buf2.sqrt())
+
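+                # Bias-corrected Adam step: dividing lr by
+                # scale = (1 - beta1^t) / sqrt(1 - beta2^t) is equivalent to
+                # using m_t / (1 - beta1^t) and v_t / (1 - beta2^t) directly.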
+                bias_correction1 = 1 - beta1**step
+                bias_correction2 = 1 - beta2**step
+                scale = bias_correction1 / bias_correction2**0.5
+                p.data.mul_(1 - lr * weight_decay)
+                p.data.add_(g, alpha=-lr / scale)
+
+        return loss
build/{torch210-cxx11-rocm70-x86_64-linux → torch26-cxx11-rocm62-x86_64-linux/optimizer}/__init__.py RENAMED
File without changes
build/{torch210-cxx11-cu128-x86_64-linux → torch26-cxx11-rocm62-x86_64-linux/optimizer}/_ops.py RENAMED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_06a260a_dirty
-ops = torch.ops._optimizer_06a260a_dirty
+from . import _optimizer_036642a_dirty
+ops = torch.ops._optimizer_036642a_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_06a260a_dirty::{op_name}"
+    return f"_optimizer_036642a_dirty::{op_name}"
build/{torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so → torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3562c68e8ee85fc5b268e079150ffff69d52860092d59e44fb9b3c4526c5d497
-size 1866400
+oid sha256:a825a0cd31d8c1b91aa9db4b24248d7fc0a506615f625a385b40e6002025c7dd
+size 1749744