diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml deleted file mode 100644 index b0d95d88e34537bfddd380dd054393190e7cebca..0000000000000000000000000000000000000000 --- a/.github/actionlint.yaml +++ /dev/null @@ -1,3 +0,0 @@ -self-hosted-runner: - labels: - - docker-builder-01 diff --git a/.github/workflows/build-and-commit.yml b/.github/workflows/build-and-commit.yml deleted file mode 100644 index 2b8d63d05640696ccc4fc70d3e2fcf0b35d74044..0000000000000000000000000000000000000000 --- a/.github/workflows/build-and-commit.yml +++ /dev/null @@ -1,120 +0,0 @@ -name: Nix build and commit - -on: - pull_request: - types: [opened, synchronize, reopened] - workflow_dispatch: - -permissions: - contents: write - -jobs: - check-commit: - runs-on: ubuntu-latest - outputs: - skip: ${{ steps.check.outputs.skip }} - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - id: check - run: | - if [ "${{ github.event_name }}" = "pull_request" ]; then - msg=$(git log -1 --pretty=%B "${{ github.event.pull_request.head.sha }}") - else - msg="manual dispatch" - fi - echo "Commit message: $msg" - if echo "$msg" | grep -q '\[skip-build\]'; then - echo "skip=true" >> "$GITHUB_OUTPUT" - else - echo "skip=false" >> "$GITHUB_OUTPUT" - fi - - build_and_commit: - needs: check-commit - if: needs.check-commit.outputs.skip == 'false' - runs-on: docker-builder-01 - steps: - - name: Show disk usage - run: df -h - - - name: Notify build start on Slack - id: slack_start - run: | - msg="*Build started* for \`${{ github.repository }}\`\nBranch: \`${{ github.ref_name }}\`\n<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Workflow>" - response=$(curl -s -X POST \ - -H "Authorization: Bearer ${{ secrets.SLACK_TOKEN }}" \ - -H "Content-type: application/json; charset=utf-8" \ - --data "{\"channel\":\"${{ secrets.SLACK_CHANNEL_ID }}\",\"text\":\"$msg\"}" \ - https://slack.com/api/chat.postMessage) - ts=$(echo "$response" | jq -r '.ts') - echo 
"thread_ts=$ts" >> "$GITHUB_OUTPUT" - echo "$response" - - - name: Checkout repository - uses: actions/checkout@v4 - with: - fetch-depth: 0 - lfs: true - ref: ${{ github.head_ref || github.ref }} - - - name: Install Nix - uses: cachix/install-nix-action@v31 - - - name: Setup huggingface cachix - uses: cachix/cachix-action@v15 - with: - name: huggingface - - - name: Clean build directory - run: | - rm -rf build - - - name: Build with Nix - run: | - nix run .#build-and-copy \ - --override-input kernel-builder github:huggingface/kernel-builder \ - --max-jobs 8 \ - -j 8 \ - -L - - - name: List built binaries - run: | - ls build - - - name: Commit build artifact - run: | - git config user.name "github-actions[bot]" - git config user.email "41898282+github-actions[bot]@users.noreply.github.com" - git add build/* - git commit -m "Add built binary [skip-build]" - - - name: Push changes - run: | - git push origin HEAD:"$HEAD_REF" - env: - HEAD_REF: ${{ github.head_ref || github.ref }} - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Notify success on Slack (thread) - if: success() - run: | - ts="${{ steps.slack_start.outputs.thread_ts }}" - msg="*Build succeeded* for \`${{ github.repository }}\`\nBranch: \`${{ github.ref_name }}\`\n<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Workflow>" - curl -s -X POST \ - -H "Authorization: Bearer ${{ secrets.SLACK_TOKEN }}" \ - -H "Content-type: application/json; charset=utf-8" \ - --data "{\"channel\":\"${{ secrets.SLACK_CHANNEL_ID }}\",\"text\":\"$msg\",\"thread_ts\":\"$ts\"}" \ - https://slack.com/api/chat.postMessage - - - name: Notify failure on Slack (thread) - if: failure() - run: | - ts="${{ steps.slack_start.outputs.thread_ts }}" - msg="*Build failed* for \`${{ github.repository }}\`\nBranch: \`${{ github.ref_name }}\`\n<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Workflow>" - curl -s -X POST \ - -H "Authorization: Bearer ${{ 
secrets.SLACK_TOKEN }}" \ - -H "Content-type: application/json; charset=utf-8" \ - --data "{\"channel\":\"${{ secrets.SLACK_CHANNEL_ID }}\",\"text\":\"$msg\",\"thread_ts\":\"$ts\"}" \ - https://slack.com/api/chat.postMessage diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml deleted file mode 100644 index 6d2ff056b79fbd65f2ed61018c5730b234fee9b7..0000000000000000000000000000000000000000 --- a/.github/workflows/pre-commit.yml +++ /dev/null @@ -1,30 +0,0 @@ -name: pre-commit - -on: - pull_request: - push: - branches: [ main, master ] - -jobs: - run-pre-commit: - runs-on: ubuntu-latest - permissions: - contents: read - pull-requests: read - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - - name: Cache pre-commit - uses: actions/cache@v4 - with: - path: ~/.cache/pre-commit - key: pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }} - restore-keys: | - pre-commit-${{ runner.os }}- - - - name: Run pre-commit - uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/push-to-hf.yml b/.github/workflows/push-to-hf.yml deleted file mode 100644 index ae31d9649bab120a948956ea6cf422dfa33b74a8..0000000000000000000000000000000000000000 --- a/.github/workflows/push-to-hf.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Push to HF Repo - -on: - push: - branches: - - main - workflow_dispatch: - -jobs: - push_to_hf: - runs-on: ubuntu-latest - steps: - # 1. Checkout the repo - - name: Checkout repository - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Install Git LFS - run: | - git lfs install - git lfs fetch --all - git lfs pull - # 2. Set up Git - - name: Configure Git - run: | - git config user.name "MotifTech" - git config user.email "huggingface@motiftech.io" - - # 3. Add HF remote - - name: Add Hugging Face remote - run: | - git remote add hf https://huggingface.co/Motif-Technologies/optimizer - git fetch hf || true - - # 4. 
Push to HF repo - - name: Push to Hugging Face - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - git push "https://hf_token:${HF_TOKEN}@huggingface.co/Motif-Technologies/optimizer" HEAD:main diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 3c818f24455eb15948604ef3c9a32c5e351cb27c..0000000000000000000000000000000000000000 --- a/.gitignore +++ /dev/null @@ -1,21 +0,0 @@ -__pycache__ -.idea -.DS_Store -*.egg-info -outputs -dist/* -.vscode - -# data -data -out -wandb - -torchtitan/datasets/**/*.model -torchtitan/experiments/flux/assets/* - -# temp files -*.log -error.json -_remote_module_non_scriptable.py -.git_disabled/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 55f8e34c04aac06db5a3137a475e13e3e5ecf8d5..0000000000000000000000000000000000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,33 +0,0 @@ -default_install_hook_types: - - pre-commit - - commit-msg -default_stages: - - pre-commit # Run locally - - manual # Run in CI -exclude: '(build|result)/.*|__pycache__/.*|.*\.(png|html)$' -repos: -- repo: https://github.com/google/yapf - rev: v0.43.0 - hooks: - - id: yapf - args: [--in-place, --verbose] -- repo: https://github.com/crate-ci/typos - rev: v1.34.0 - hooks: - - id: typos - exclude: '.gitattributes' -- repo: https://github.com/PyCQA/isort - rev: 6.0.1 - hooks: - - id: isort -- repo: https://github.com/pre-commit/mirrors-clang-format - rev: v20.1.3 - hooks: - - id: clang-format - types_or: [c++, cuda] - args: [--style=file, --verbose] -- repo: https://github.com/jackdewinter/pymarkdown - rev: v0.9.29 - hooks: - - id: pymarkdown - args: [fix] diff --git a/README.md b/README.md index 59d1c3567a56a6978c0a714c956cf845e443fda8..e24a75af8cfd719a94a499a644188b3164b2d1cb 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@ --- tags: -- kernels -license: apache-2.0 +- kernel --- # Optimizer @@ -10,14 +9,8 @@ Optimizer is a python package that provides: - PyTorch implementation of recent 
optimizer algorithms - with support for parallelism techniques for efficient large-scale training. -## Currently implemented -- Parallel Muon with N-D sharding - - [arxiv URL](https://arxiv.org/abs/2511.07464) - - Supports **general N-D sharding configurations** - - The implementation is not tied to any specific parallel strategy. - - Verified from basic FSDP2 setups up to hybrid configurations such as - **(2 TP + 2 DP-Replicate + 2 DP-Shard)**. - - Verified configurations can be found in [test_muon.py](./test/test_muon.py) +### Currently implemented +- [Parallel Muon with FSDP2](./docs/muon/parallel_muon.pdf) ## Usage @@ -27,72 +20,14 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from kernels import get_kernel optimizer = get_kernel("motif-technologies/optimizer") -get_default_muon_param_groups = optimizer.muon.get_default_muon_param_groups model = None # your model here fsdp_model = FSDP(model) -# muon, in nature, cannot use 1-d tensor -# we provide helper function to group such tensors -# you can use your own function, if necessary -params = get_default_muon_param_groups(model) # user can write own is_muon_func, if necessary - optim = optimizer.Muon( - params, + fsdp_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4, ) ``` - -## Test -- Check [test/README.md](./test/README.md) for how to run the tests. - -## Pre-commit Hooks - -This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits. - -### Setup - -1. Install pre-commit: - - ```bash - pip install pre-commit - ``` - -2. Install the git hooks: - -```bash - pre-commit install - ``` - -Once installed, the configured hooks will run automatically on each commit. 
- -### Included Hooks - -The following tools are run via pre-commit: - -- **[yapf](https://github.com/google/yapf)** – Python code formatter -- **[typos](https://github.com/crate-ci/typos)** – Spell checker for common typos -- **[isort](https://github.com/PyCQA/isort)** – Organizes and sorts Python imports -- **[clang-format](https://clang.llvm.org/docs/ClangFormat.html)** – Formats C++/CUDA code (`--style=file`) -- **[pymarkdown](https://github.com/jackdewinter/pymarkdown)** – Lints and auto-fixes Markdown files -- **[actionlint](https://github.com/rhysd/actionlint)** – Validates GitHub Actions workflows - -### Usage - -- Run all checks on the entire codebase: - - ```bash - pre-commit run --all-files - ``` - -- Run a specific hook (example: isort): - - ```bash - pre-commit run isort --all-files - ``` - -### Test - -- There is a [simple unittest for Parallel Muon](./test/test_muon/README.md) diff --git a/build.toml b/build.toml index ebabc676bfe40eb07e2bb447ff0c17605ac42844..b80854db0a67cdde4e5c3dcb8d95f18704812383 100644 --- a/build.toml +++ b/build.toml @@ -1,33 +1,23 @@ [general] name = "optimizer" -backends = [ - "cuda", - "rocm", -] +universal = false [torch] src = [ - "torch-ext/torch_binding.cpp", - "torch-ext/torch_binding.h", + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h", ] -[kernel.optimizer] -backend = "cuda" -depends = ["torch"] -src = ["optimizer/dummy.cu"] - -[kernel.optimizer_rocm] +[kernel.activation] backend = "rocm" -rocm-archs = [ - "gfx906", - "gfx908", - "gfx90a", - "gfx940", - "gfx941", - "gfx942", - "gfx1030", - "gfx1100", - "gfx1101", +src = [ + "optimizer/dummy.cu", +] +depends = [ "torch" ] + +[kernel.activation_cuda] +backend = "cuda" +src = [ + "optimizer/dummy.cu", ] -depends = ["torch"] -src = ["optimizer/dummy.cu"] +depends = [ "torch" ] diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py deleted file mode 100644 index 
e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 6015e5b4ea5da27e0002b298d9a1ab55142f88ab..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5384da54f22f488e0646e09915b821b3235cb404b163a570aa377967f853e3cf -size 1940944 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. 
- shard_placements (tuple[Placement]): The shard placements. - """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. 
- - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py 
b/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-x86_64-linux/muon.py b/build/torch210-cxx11-cu126-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its 
own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index a2b4992c68bd2d564fa8ac804bce7a9f9d0a09d9..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:976df6a1ec3ec4c462dea18477b56dfb75bcff76f504d55b592ce417931597c0 -size 2004144 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py 
b/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/muon.py b/build/torch210-cxx11-cu128-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its 
own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 62bbc727da9606819a23c43dda20add2be7c1fe3..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:330aaa6cb247ba3b5df7a13ced6ef7eff3e5d7a72a0b88f674f948aeaed66ee2 -size 2004728 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py 
b/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/muon.py b/build/torch210-cxx11-cu130-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its 
own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index a2bbc913106abe6d784d7634ad119d969ff23a3c..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3562c68e8ee85fc5b268e079150ffff69d52860092d59e44fb9b3c4526c5d497 -size 1866400 diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py 
b/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json b/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/muon.py b/build/torch210-cxx11-rocm70-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm70-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration 
its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/optimizer/__init__.py b/build/torch210-cxx11-rocm70-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm70-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index ed70a8ee48aca9da47db195b5e73c86aca32b153..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d804ba4d3ed9716c80e9819ba16a2bef300fb23fa4c456c550f4a96167a2eb00 -size 1866112 diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py 
b/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json b/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/muon.py b/build/torch210-cxx11-rocm71-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm71-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration 
its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/optimizer/__init__.py b/build/torch210-cxx11-rocm71-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm71-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from build/torch210-cxx11-cu126-x86_64-linux/__init__.py rename to build/torch26-cxx11-cu118-x86_64-linux/optimizer/__init__.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..7cf68eab4638da3512b5e49541c916ebd12301f0 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_036642a_dirty +ops = torch.ops._optimizer_036642a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_036642a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..4df57b1d0e99209ec349328d9fa3a61cea7f97da --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c77e5647b6056bfaee25050cca7948c40859db0a88fa4fcf40b67a85c947d8c +size 1787272 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..0d614d55d721efac406c147b4f62e6c703a91107 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py @@ -0,0 +1,455 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.no_grad() +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +@torch.no_grad() +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." 
+ ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # Clear u to free memory + state.computed_u = None + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. 
To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): + defaults = dict( + lr=lr, + wd=adamw_wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + ) + + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model + + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." 
+ ) + + self.rank = dist.get_rank() + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def init_state_and_assign_params(self, params, group): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." + + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) + + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) + + round_robin = 0 + mesh = None + for p in ordered_params: + if mesh is None: + mesh = p.device_mesh + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. 
+ """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from build/torch210-cxx11-cu128-x86_64-linux/__init__.py rename to build/torch26-cxx11-cu124-x86_64-linux/optimizer/__init__.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..7cf68eab4638da3512b5e49541c916ebd12301f0 
--- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_036642a_dirty +ops = torch.ops._optimizer_036642a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_optimizer_036642a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..078af88758fab815617b6fae432c3ce4d18f6271 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94ea66089cc8d9eda72b017733a9e05e4fee5a2f04c50658b690d2c19f0d3068 +size 1824224 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..0d614d55d721efac406c147b4f62e6c703a91107 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py @@ -0,0 +1,455 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.no_grad() +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. 
For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +@torch.no_grad() +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." 
+ ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # Clear u to free memory + state.computed_u = None + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. 
To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): + defaults = dict( + lr=lr, + wd=adamw_wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + ) + + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model + + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." 
+ ) + + self.rank = dist.get_rank() + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def init_state_and_assign_params(self, params, group): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." + + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) + + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) + + round_robin = 0 + mesh = None + for p in ordered_params: + if mesh is None: + mesh = p.device_mesh + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. 
+ """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from build/torch210-cxx11-cu130-x86_64-linux/__init__.py rename to build/torch26-cxx11-cu126-x86_64-linux/optimizer/__init__.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..7cf68eab4638da3512b5e49541c916ebd12301f0 
--- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_036642a_dirty +ops = torch.ops._optimizer_036642a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_optimizer_036642a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d8368d3904638b07e920f72204987d1820114e0a --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46e01e1d957ada2d485b30cd60bc3ef7230b8857dffc59f2e7924339761ec577 +size 1824224 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..0d614d55d721efac406c147b4f62e6c703a91107 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py @@ -0,0 +1,455 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.no_grad() +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. 
For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +@torch.no_grad() +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." 
+ ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # Clear u to free memory + state.computed_u = None + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. 
To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): + defaults = dict( + lr=lr, + wd=adamw_wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + ) + + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model + + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." 
+ ) + + self.rank = dist.get_rank() + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def init_state_and_assign_params(self, params, group): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." + + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) + + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) + + round_robin = 0 + mesh = None + for p in ordered_params: + if mesh is None: + mesh = p.device_mesh + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. 
+ """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from build/torch210-cxx11-rocm70-x86_64-linux/__init__.py rename to build/torch26-cxx11-rocm62-x86_64-linux/optimizer/__init__.py diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py new file mode 100755 index 
0000000000000000000000000000000000000000..7cf68eab4638da3512b5e49541c916ebd12301f0 --- /dev/null +++ b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_036642a_dirty +ops = torch.ops._optimizer_036642a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_optimizer_036642a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..873c79d168145a8e956f609a3be376e9f1817b41 --- /dev/null +++ b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a825a0cd31d8c1b91aa9db4b24248d7fc0a506615f625a385b40e6002025c7dd +size 1749744 diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..0d614d55d721efac406c147b4f62e6c703a91107 --- /dev/null +++ b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py @@ -0,0 +1,455 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.no_grad() +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. 
For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +@torch.no_grad() +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." 
+ ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # Clear u to free memory + state.computed_u = None + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. 
To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): + defaults = dict( + lr=lr, + wd=adamw_wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + ) + + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model + + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." 
+ ) + + self.rank = dist.get_rank() + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def init_state_and_assign_params(self, params, group): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." + + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) + + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) + + round_robin = 0 + mesh = None + for p in ordered_params: + if mesh is None: + mesh = p.device_mesh + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. 
+ """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from build/torch210-cxx11-rocm71-x86_64-linux/__init__.py rename to build/torch26-cxx98-cu118-x86_64-linux/optimizer/__init__.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..7cf68eab4638da3512b5e49541c916ebd12301f0 
--- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_036642a_dirty +ops = torch.ops._optimizer_036642a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_optimizer_036642a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..65af7b83eba70e3218649b947d359fda84b41be0 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:579e9ddf66a4f17ead9232c2f32e6327fe6a3f16dd235e2e73e6cb282de1797e +size 1787192 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..0d614d55d721efac406c147b4f62e6c703a91107 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py @@ -0,0 +1,455 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.no_grad() +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. 
For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +@torch.no_grad() +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." 
+ ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # Clear u to free memory + state.computed_u = None + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. 
To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): + defaults = dict( + lr=lr, + wd=adamw_wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + ) + + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model + + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." 
+ ) + + self.rank = dist.get_rank() + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def init_state_and_assign_params(self, params, group): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." + + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) + + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) + + round_robin = 0 + mesh = None + for p in ordered_params: + if mesh is None: + mesh = p.device_mesh + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. 
+ """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from build/torch28-cxx11-cu126-x86_64-linux/__init__.py rename to build/torch26-cxx98-cu124-x86_64-linux/optimizer/__init__.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..7cf68eab4638da3512b5e49541c916ebd12301f0 --- 
/dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_036642a_dirty +ops = torch.ops._optimizer_036642a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_optimizer_036642a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..c1a231bc7234f19a04d1c05f491511c3d21ebaca --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:beacb4ba2d56463b6d444875728b3462cb3ff6c1449e3c9693cd665bfbbbbb73 +size 1824184 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..0d614d55d721efac406c147b4f62e6c703a91107 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py @@ -0,0 +1,455 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.no_grad() +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +@torch.no_grad() +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." 
+ ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # Clear u to free memory + state.computed_u = None + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. 
To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): + defaults = dict( + lr=lr, + wd=adamw_wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + ) + + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model + + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." 
+ ) + + self.rank = dist.get_rank() + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def init_state_and_assign_params(self, params, group): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." + + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) + + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) + + round_robin = 0 + mesh = None + for p in ordered_params: + if mesh is None: + mesh = p.device_mesh + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. 
+ """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from build/torch28-cxx11-cu128-x86_64-linux/__init__.py rename to build/torch26-cxx98-cu126-x86_64-linux/optimizer/__init__.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..7cf68eab4638da3512b5e49541c916ebd12301f0 --- 
/dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_036642a_dirty +ops = torch.ops._optimizer_036642a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_optimizer_036642a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..ac6d1fa4c9a0f286c2b7597d1f2f337aff89a570 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b04b011803d328d8dcd2edcf4c3840ddbb1bb2f093464c208f0ba2faf4f16bc +size 1824184 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..0d614d55d721efac406c147b4f62e6c703a91107 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py @@ -0,0 +1,455 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.no_grad() +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +@torch.no_grad() +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." 
+ ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # Clear u to free memory + state.computed_u = None + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. 
To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): + defaults = dict( + lr=lr, + wd=adamw_wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + ) + + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model + + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." 
+ ) + + self.rank = dist.get_rank() + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def init_state_and_assign_params(self, params, group): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." + + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) + + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) + + round_robin = 0 + mesh = None + for p in ordered_params: + if mesh is None: + mesh = p.device_mesh + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. 
+ """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__init__.py old mode 100644 new mode 100755 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py old mode 100644 new mode 100755 index cb9efd677b388ebc299d6c4747eee701c96211f6..7cf68eab4638da3512b5e49541c916ebd12301f0 --- a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 
@@ import torch -from . import _optimizer_b0230e7_dirty -ops = torch.ops._optimizer_b0230e7_dirty +from . import _optimizer_036642a_dirty +ops = torch.ops._optimizer_036642a_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_b0230e7_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_036642a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..72737f2570d97f57b8f15140e6fd9e12623152f2 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad6c725009f2e776b99d3134c75f15e11dd7fe75fe4ba1fa94779018c7871f8c +size 1787368 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so deleted file mode 100755 index ac384d7194105cc1fd531bed4212a63de4d9be00..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:236bb0d67cbb2718b076637569923cf240de1c7a074790623ecb9c049fca9732 -size 1787368 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/matmul_transpose_triton.py b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to 
any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py old mode 100644 new mode 100755 index 4af25d55c528fb0db2272540838ab70eb3619194..0d614d55d721efac406c147b4f62e6c703a91107 --- a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py @@ -1,26 +1,14 @@ -import logging import math -import types from dataclasses import dataclass -from typing import List, Optional, Union, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard - -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 +from torch.distributed._tensor import DTensor # This code snippet is a modified version adapted from the following GitHub repositories: # https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a @@ -32,31 +20,26 @@ def _zeropower_via_newtonschulz5(G, steps): performance at all relative to UV^T, where USV^T = G is the SVD. 
""" assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE + a, b, c = (3.4445, -4.7750, 2.0315) X = G # no manual typecast - if G.size(0) > G.size(1): X = X.T # Ensure spectral norm is at most 1 X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + X = X.bfloat16() # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) if G.size(0) > G.size(1): X = X.T - return X + return X.to(G.dtype) @dataclass @@ -64,425 +47,92 @@ class _muon_state: # TODO: use Optional worker_rank: int | None = None gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None - - -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols @torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - num_ranks = 
dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) - assert g.numel() == shard_elems - per_dst[dst].append(g) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, 
dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - write_offsets = {id(p): 0 for p in owned_params} - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) - dst.copy_(sg) - - write_offsets[id(p)] += n - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." 
+ ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None @torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ +def _compute_u(state, steps, rank, compute_stream): with torch.cuda.stream(compute_stream): if rank == state.worker_rank: if state.gather_event is None: raise RuntimeError("Gather event must be set before compute.") compute_stream.wait_event(state.gather_event) u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None state.computed_u = u state.compute_event = torch.cuda.Event() state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None else: state.computed_u = None state.compute_event = None @torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * 
num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) - - offset = 0 - for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) - assert n > 0 - - su = u_full.narrow(0, offset, n) - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += split_elems_for_src(p, rank, num_ranks) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, + scatter_list = None + + u = 
torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = split_elems_for_src(p, rank, num_ranks) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - if rank == state.worker_rank: - # Free computed_u + # Clear u to free memory state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales(p, state.qk_clip_state) - if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. 
- - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: Optional[torch.Tensor] + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) class Muon(torch.optim.Optimizer): @@ -499,87 +149,71 @@ class Muon(torch.optim.Optimizer): - We believe it may not work well for finetuning pretrained models, but we haven't tested this. Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + muon_params: The parameters to be optimized by Muon. lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) momentum: The momentum used by the internal SGD. (0.95 is a good default) nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. 
adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + adamw_wd: The weight decay for the internal AdamW. """ - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - overlap_step=5): + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): defaults = dict( lr=lr, - weight_decay=weight_decay, + wd=adamw_wd, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, adamw_betas=adamw_betas, adamw_eps=adamw_eps, none_grad=none_grad, - use_muon=True, ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." 
- instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model - super().__init__(params, defaults) + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." + ) - self.rank = None + self.rank = dist.get_rank() self.comm_stream = torch.cuda.Stream() self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config - self.overlap_step = overlap_step + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,30 +231,7 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." 
- - if p.placements == (Shard(dim=0), ): - # Case for FSDP - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") - - def init_state_and_assign_params(self, names, params, group, qk_logits): + def init_state_and_assign_params(self, params, group): param_to_state = {} param_to_flops = {} @@ -636,44 +247,34 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) round_robin = 0 mesh = None - shard_mesh = None - process_group = None - for n, p in zip(ordered_names, ordered_params): + for p in ordered_params: if mesh is None: mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) elif mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) + param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def base(self, params, group, lr, wd, momentum): # generate weight updates in distributed fashion - for n, p in zip(names, params): + for p in params: g = p.grad if g is None: continue @@ -692,87 +293,39 @@ class Muon(torch.optim.Optimizer): else: g = buf - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + # scale update adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + # apply weight decay + p.data.mul_(1 - lr * wd) - scales_full = self._compute_scales(p, qk_clip_state) - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + # apply update + p.data.add_(u, alpha=-adjusted_lr) def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + g = g.add(buf, alpha=momentum) + 
else: + g = buf + return g - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) # apply weight decay - p.data.mul_(1 - lr * weight_decay) + p.data.mul_(1 - lr * wd) # apply update p.data.add_(u, alpha=-adjusted_lr) - def get_qk_clip_info(self, n, qk_logits): - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, params, group, lr, wd, momentum): """ Perform a parallel optimization step using Muon. 
""" @@ -794,143 +347,44 @@ class Muon(torch.optim.Optimizer): p.grad = g param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None + params, group + ) - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: + for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) - chunk_size = 
dist.get_world_size(param_to_state[id( - params[0])].process_group) + chunk_size = params[0].device_mesh.mesh.numel() # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - + enqueue_gathers(0, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: Union[float, torch.Tensor], - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: Optional[DeviceDict] = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) - def step(self, closure=None, qk_logits=None): + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def step(self, closure=None): """Perform a single optimization step. Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None if closure is not None: @@ -938,127 +392,64 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, + self.base( + params, + group, lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, + wd=wd, + momentum=momentum, ) + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not 
self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + return loss diff --git a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/__init__.py old mode 100644 new mode 100755 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py old mode 100644 new mode 100755 index cb9efd677b388ebc299d6c4747eee701c96211f6..7cf68eab4638da3512b5e49541c916ebd12301f0 --- a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_b0230e7_dirty -ops = torch.ops._optimizer_b0230e7_dirty +from . import _optimizer_036642a_dirty +ops = torch.ops._optimizer_036642a_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_b0230e7_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_036642a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..6c71757d9ed8f7ccfae12cd8eb6837ff15c9f773 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50cb5819ff08a2179d78cd98164d07fd3cef1b66ee7703d599a310dfb140b9d1 +size 1824256 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so deleted file mode 100755 index 7ccc064e07fe7031403165bb3c78e31e84cbdf19..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:69525fcbfbe640264f4d52c9843b395b17f1828d38e1eceb97cec6bf46b0d8d0 -size 1824256 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) 
- pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. - a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D 
tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py old mode 100644 new mode 100755 index 4af25d55c528fb0db2272540838ab70eb3619194..0d614d55d721efac406c147b4f62e6c703a91107 --- a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py @@ -1,26 +1,14 @@ -import logging import math -import types from dataclasses import dataclass -from typing import List, Optional, Union, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard - -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 +from torch.distributed._tensor import DTensor # This code snippet is a modified version adapted from the following GitHub repositories: # https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. 
@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a @@ -32,31 +20,26 @@ def _zeropower_via_newtonschulz5(G, steps): performance at all relative to UV^T, where USV^T = G is the SVD. """ assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE + a, b, c = (3.4445, -4.7750, 2.0315) X = G # no manual typecast - if G.size(0) > G.size(1): X = X.T # Ensure spectral norm is at most 1 X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + X = X.bfloat16() # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) if G.size(0) > G.size(1): X = X.T - return X + return X.to(G.dtype) @dataclass @@ -64,425 +47,92 @@ class _muon_state: # TODO: use Optional worker_rank: int | None = None gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None - - -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * 
cols @torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) - assert g.numel() == shard_elems - per_dst[dst].append(g) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving 
buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - write_offsets = {id(p): 0 for p in owned_params} - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) - dst.copy_(sg) - - write_offsets[id(p)] += n - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." 
+ ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None @torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ +def _compute_u(state, steps, rank, compute_stream): with torch.cuda.stream(compute_stream): if rank == state.worker_rank: if state.gather_event is None: raise RuntimeError("Gather event must be set before compute.") compute_stream.wait_event(state.gather_event) u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None state.computed_u = u state.compute_event = torch.cuda.Event() state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None else: state.computed_u = None state.compute_event = None @torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * 
num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) - - offset = 0 - for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) - assert n > 0 - - su = u_full.narrow(0, offset, n) - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += split_elems_for_src(p, rank, num_ranks) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, + scatter_list = None + + u = 
torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = split_elems_for_src(p, rank, num_ranks) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - if rank == state.worker_rank: - # Free computed_u + # Clear u to free memory state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales(p, state.qk_clip_state) - if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. 
- - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: Optional[torch.Tensor] + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) class Muon(torch.optim.Optimizer): @@ -499,87 +149,71 @@ class Muon(torch.optim.Optimizer): - We believe it may not work well for finetuning pretrained models, but we haven't tested this. Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + muon_params: The parameters to be optimized by Muon. lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) momentum: The momentum used by the internal SGD. (0.95 is a good default) nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. 
adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + adamw_wd: The weight decay for the internal AdamW. """ - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - overlap_step=5): + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): defaults = dict( lr=lr, - weight_decay=weight_decay, + wd=adamw_wd, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, adamw_betas=adamw_betas, adamw_eps=adamw_eps, none_grad=none_grad, - use_muon=True, ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." 
- instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model - super().__init__(params, defaults) + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." + ) - self.rank = None + self.rank = dist.get_rank() self.comm_stream = torch.cuda.Stream() self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config - self.overlap_step = overlap_step + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,30 +231,7 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." 
- - if p.placements == (Shard(dim=0), ): - # Case for FSDP - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") - - def init_state_and_assign_params(self, names, params, group, qk_logits): + def init_state_and_assign_params(self, params, group): param_to_state = {} param_to_flops = {} @@ -636,44 +247,34 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) round_robin = 0 mesh = None - shard_mesh = None - process_group = None - for n, p in zip(ordered_names, ordered_params): + for p in ordered_params: if mesh is None: mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) elif mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) + param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def base(self, params, group, lr, wd, momentum): # generate weight updates in distributed fashion - for n, p in zip(names, params): + for p in params: g = p.grad if g is None: continue @@ -692,87 +293,39 @@ class Muon(torch.optim.Optimizer): else: g = buf - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + # scale update adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + # apply weight decay + p.data.mul_(1 - lr * wd) - scales_full = self._compute_scales(p, qk_clip_state) - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + # apply update + p.data.add_(u, alpha=-adjusted_lr) def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + g = g.add(buf, alpha=momentum) + 
else: + g = buf + return g - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) # apply weight decay - p.data.mul_(1 - lr * weight_decay) + p.data.mul_(1 - lr * wd) # apply update p.data.add_(u, alpha=-adjusted_lr) - def get_qk_clip_info(self, n, qk_logits): - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, params, group, lr, wd, momentum): """ Perform a parallel optimization step using Muon. 
""" @@ -794,143 +347,44 @@ class Muon(torch.optim.Optimizer): p.grad = g param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None + params, group + ) - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: + for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) - chunk_size = 
dist.get_world_size(param_to_state[id( - params[0])].process_group) + chunk_size = params[0].device_mesh.mesh.numel() # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - + enqueue_gathers(0, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: Union[float, torch.Tensor], - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: Optional[DeviceDict] = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) - def step(self, closure=None, qk_logits=None): + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def step(self, closure=None): """Perform a single optimization step. Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None if closure is not None: @@ -938,127 +392,64 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, + self.base( + params, + group, lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, + wd=wd, + momentum=momentum, ) + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not 
self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + return loss diff --git a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/__init__.py old mode 100644 new mode 100755 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py old mode 100644 new mode 100755 index cb9efd677b388ebc299d6c4747eee701c96211f6..7cf68eab4638da3512b5e49541c916ebd12301f0 --- a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_b0230e7_dirty -ops = torch.ops._optimizer_b0230e7_dirty +from . import _optimizer_036642a_dirty +ops = torch.ops._optimizer_036642a_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_b0230e7_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_036642a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..18e4199f161721ea032af3872fa5f7d3cd95a862 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c75e42265f382addc71327ad5628e8a2414da5872791c975e384708c4acd549 +size 1883352 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so deleted file mode 100755 index b49ddd7a4dd2c980f9693a477e26221042cc85c5..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:331cc0bc5ee469afdfe0fc590bf52910c118cd0cec62ccbf85778c12ae367a95 -size 1883344 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) 
- pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. - a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D 
tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py old mode 100644 new mode 100755 index 4af25d55c528fb0db2272540838ab70eb3619194..0d614d55d721efac406c147b4f62e6c703a91107 --- a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py @@ -1,26 +1,14 @@ -import logging import math -import types from dataclasses import dataclass -from typing import List, Optional, Union, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard - -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 +from torch.distributed._tensor import DTensor # This code snippet is a modified version adapted from the following GitHub repositories: # https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. 
@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a @@ -32,31 +20,26 @@ def _zeropower_via_newtonschulz5(G, steps): performance at all relative to UV^T, where USV^T = G is the SVD. """ assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE + a, b, c = (3.4445, -4.7750, 2.0315) X = G # no manual typecast - if G.size(0) > G.size(1): X = X.T # Ensure spectral norm is at most 1 X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + X = X.bfloat16() # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) if G.size(0) > G.size(1): X = X.T - return X + return X.to(G.dtype) @dataclass @@ -64,425 +47,92 @@ class _muon_state: # TODO: use Optional worker_rank: int | None = None gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None - - -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * 
cols @torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) - assert g.numel() == shard_elems - per_dst[dst].append(g) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving 
buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - write_offsets = {id(p): 0 for p in owned_params} - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) - dst.copy_(sg) - - write_offsets[id(p)] += n - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." 
+ ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None @torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ +def _compute_u(state, steps, rank, compute_stream): with torch.cuda.stream(compute_stream): if rank == state.worker_rank: if state.gather_event is None: raise RuntimeError("Gather event must be set before compute.") compute_stream.wait_event(state.gather_event) u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None state.computed_u = u state.compute_event = torch.cuda.Event() state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None else: state.computed_u = None state.compute_event = None @torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * 
num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) - - offset = 0 - for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) - assert n > 0 - - su = u_full.narrow(0, offset, n) - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += split_elems_for_src(p, rank, num_ranks) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, + scatter_list = None + + u = 
torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = split_elems_for_src(p, rank, num_ranks) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - if rank == state.worker_rank: - # Free computed_u + # Clear u to free memory state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales(p, state.qk_clip_state) - if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. 
- - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: Optional[torch.Tensor] + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) class Muon(torch.optim.Optimizer): @@ -499,87 +149,71 @@ class Muon(torch.optim.Optimizer): - We believe it may not work well for finetuning pretrained models, but we haven't tested this. Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + muon_params: The parameters to be optimized by Muon. lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) momentum: The momentum used by the internal SGD. (0.95 is a good default) nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. 
adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + adamw_wd: The weight decay for the internal AdamW. """ - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - overlap_step=5): + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): defaults = dict( lr=lr, - weight_decay=weight_decay, + wd=adamw_wd, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, adamw_betas=adamw_betas, adamw_eps=adamw_eps, none_grad=none_grad, - use_muon=True, ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." 
- instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model - super().__init__(params, defaults) + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." + ) - self.rank = None + self.rank = dist.get_rank() self.comm_stream = torch.cuda.Stream() self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config - self.overlap_step = overlap_step + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,30 +231,7 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." 
- - if p.placements == (Shard(dim=0), ): - # Case for FSDP - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") - - def init_state_and_assign_params(self, names, params, group, qk_logits): + def init_state_and_assign_params(self, params, group): param_to_state = {} param_to_flops = {} @@ -636,44 +247,34 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) round_robin = 0 mesh = None - shard_mesh = None - process_group = None - for n, p in zip(ordered_names, ordered_params): + for p in ordered_params: if mesh is None: mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) elif mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) + param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def base(self, params, group, lr, wd, momentum): # generate weight updates in distributed fashion - for n, p in zip(names, params): + for p in params: g = p.grad if g is None: continue @@ -692,87 +293,39 @@ class Muon(torch.optim.Optimizer): else: g = buf - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + # scale update adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + # apply weight decay + p.data.mul_(1 - lr * wd) - scales_full = self._compute_scales(p, qk_clip_state) - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + # apply update + p.data.add_(u, alpha=-adjusted_lr) def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + g = g.add(buf, alpha=momentum) + 
else: + g = buf + return g - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) # apply weight decay - p.data.mul_(1 - lr * weight_decay) + p.data.mul_(1 - lr * wd) # apply update p.data.add_(u, alpha=-adjusted_lr) - def get_qk_clip_info(self, n, qk_logits): - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, params, group, lr, wd, momentum): """ Perform a parallel optimization step using Muon. 
""" @@ -794,143 +347,44 @@ class Muon(torch.optim.Optimizer): p.grad = g param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None + params, group + ) - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: + for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) - chunk_size = 
dist.get_world_size(param_to_state[id( - params[0])].process_group) + chunk_size = params[0].device_mesh.mesh.numel() # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - + enqueue_gathers(0, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: Union[float, torch.Tensor], - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: Optional[DeviceDict] = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) - def step(self, closure=None, qk_logits=None): + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def step(self, closure=None): """Perform a single optimization step. Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None if closure is not None: @@ -938,127 +392,64 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, + self.base( + params, + group, lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, + wd=wd, + momentum=momentum, ) + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not 
self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + return loss diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__init__.py old mode 100644 new mode 100755 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py old mode 100644 new mode 100755 index cb9efd677b388ebc299d6c4747eee701c96211f6..7cf68eab4638da3512b5e49541c916ebd12301f0 --- a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_b0230e7_dirty -ops = torch.ops._optimizer_b0230e7_dirty +from . import _optimizer_036642a_dirty +ops = torch.ops._optimizer_036642a_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_b0230e7_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_036642a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..73f6a2b6170eaf2e12e73fa4cf70c620ae925bdd --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_036642a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a2363d4311d6a75fbcc03e6d4a71c73dae4d54e00a30135d25198d4078c6b0f +size 1749648 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so deleted file mode 100755 index 52c84f7911385aeed54b2abce05b257b4b498f14..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:21d5da3673206b979eaba9dd6d8918d7745ecd3bd3715e55105fd57c234a3a42 -size 1749776 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) 
- pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. - a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D 
tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py old mode 100644 new mode 100755 index 4af25d55c528fb0db2272540838ab70eb3619194..0d614d55d721efac406c147b4f62e6c703a91107 --- a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py @@ -1,26 +1,14 @@ -import logging import math -import types from dataclasses import dataclass -from typing import List, Optional, Union, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard - -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 +from torch.distributed._tensor import DTensor # This code snippet is a modified version adapted from the following GitHub repositories: # https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. 
@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a @@ -32,31 +20,26 @@ def _zeropower_via_newtonschulz5(G, steps): performance at all relative to UV^T, where USV^T = G is the SVD. """ assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE + a, b, c = (3.4445, -4.7750, 2.0315) X = G # no manual typecast - if G.size(0) > G.size(1): X = X.T # Ensure spectral norm is at most 1 X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + X = X.bfloat16() # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) if G.size(0) > G.size(1): X = X.T - return X + return X.to(G.dtype) @dataclass @@ -64,425 +47,92 @@ class _muon_state: # TODO: use Optional worker_rank: int | None = None gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None - - -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * 
cols @torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) - assert g.numel() == shard_elems - per_dst[dst].append(g) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving 
buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - write_offsets = {id(p): 0 for p in owned_params} - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) - dst.copy_(sg) - - write_offsets[id(p)] += n - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." 
+ ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None @torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ +def _compute_u(state, steps, rank, compute_stream): with torch.cuda.stream(compute_stream): if rank == state.worker_rank: if state.gather_event is None: raise RuntimeError("Gather event must be set before compute.") compute_stream.wait_event(state.gather_event) u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None state.computed_u = u state.compute_event = torch.cuda.Event() state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None else: state.computed_u = None state.compute_event = None @torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * 
num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) - - offset = 0 - for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) - assert n > 0 - - su = u_full.narrow(0, offset, n) - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += split_elems_for_src(p, rank, num_ranks) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, + scatter_list = None + + u = 
torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = split_elems_for_src(p, rank, num_ranks) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - if rank == state.worker_rank: - # Free computed_u + # Clear u to free memory state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales(p, state.qk_clip_state) - if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. 
- - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: Optional[torch.Tensor] + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) class Muon(torch.optim.Optimizer): @@ -499,87 +149,71 @@ class Muon(torch.optim.Optimizer): - We believe it may not work well for finetuning pretrained models, but we haven't tested this. Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + muon_params: The parameters to be optimized by Muon. lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) momentum: The momentum used by the internal SGD. (0.95 is a good default) nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. 
adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + adamw_wd: The weight decay for the internal AdamW. """ - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - overlap_step=5): + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): defaults = dict( lr=lr, - weight_decay=weight_decay, + wd=adamw_wd, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, adamw_betas=adamw_betas, adamw_eps=adamw_eps, none_grad=none_grad, - use_muon=True, ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." 
- instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model - super().__init__(params, defaults) + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." + ) - self.rank = None + self.rank = dist.get_rank() self.comm_stream = torch.cuda.Stream() self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config - self.overlap_step = overlap_step + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,30 +231,7 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." 
- - if p.placements == (Shard(dim=0), ): - # Case for FSDP - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") - - def init_state_and_assign_params(self, names, params, group, qk_logits): + def init_state_and_assign_params(self, params, group): param_to_state = {} param_to_flops = {} @@ -636,44 +247,34 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) round_robin = 0 mesh = None - shard_mesh = None - process_group = None - for n, p in zip(ordered_names, ordered_params): + for p in ordered_params: if mesh is None: mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) elif mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) + param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def base(self, params, group, lr, wd, momentum): # generate weight updates in distributed fashion - for n, p in zip(names, params): + for p in params: g = p.grad if g is None: continue @@ -692,87 +293,39 @@ class Muon(torch.optim.Optimizer): else: g = buf - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + # scale update adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + # apply weight decay + p.data.mul_(1 - lr * wd) - scales_full = self._compute_scales(p, qk_clip_state) - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + # apply update + p.data.add_(u, alpha=-adjusted_lr) def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + g = g.add(buf, alpha=momentum) + 
else: + g = buf + return g - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) # apply weight decay - p.data.mul_(1 - lr * weight_decay) + p.data.mul_(1 - lr * wd) # apply update p.data.add_(u, alpha=-adjusted_lr) - def get_qk_clip_info(self, n, qk_logits): - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, params, group, lr, wd, momentum): """ Perform a parallel optimization step using Muon. 
""" @@ -794,143 +347,44 @@ class Muon(torch.optim.Optimizer): p.grad = g param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None + params, group + ) - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: + for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) - chunk_size = 
dist.get_world_size(param_to_state[id( - params[0])].process_group) + chunk_size = params[0].device_mesh.mesh.numel() # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - + enqueue_gathers(0, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: Union[float, torch.Tensor], - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: Optional[DeviceDict] = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) - def step(self, closure=None, qk_logits=None): + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def step(self, closure=None): """Perform a single optimization step. Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None if closure is not None: @@ -938,127 +392,64 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, + self.base( + params, + group, lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, + wd=wd, + momentum=momentum, ) + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not 
self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + return loss diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index a218cd77694938fb0914270a5c6416a684d50cb3..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:222315672693e6d4544b1eee4772dc7be744b3794cfd6ff370a6f46d782386a1 -size 1936664 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py 
b/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/muon.py b/build/torch28-cxx11-cu126-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 
coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 1cf60567b59ce1b343c5a44301e443953b674f78..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:119adc22cd57de6d6d78c1f5094310b57083050f40836a5455bdb6c35bed104b -size 1999872 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py 
b/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/muon.py b/build/torch28-cxx11-cu128-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 
coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py deleted file mode 100644 index 239c7a65f8293e7d0df28f05fce645af56d628c0..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .muon import Muon - -__all__ = [ - "Muon", -] diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index e996c45edb033c93ec8a41716764cdcbbd04593d..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7e8463be5f48aba32d645183945d258cdb532b238ef40665db396b459367cad1 -size 1999872 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py 
b/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/muon.py b/build/torch28-cxx11-cu129-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 
coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py deleted file mode 100644 index 239c7a65f8293e7d0df28f05fce645af56d628c0..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .muon import Muon - -__all__ = [ - "Muon", -] diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 19ee075424c40e1714e4ef6561d68c368e933792..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:90ac494e1381bedf95832a91c108ff18d900442203f9b0612efa5519956def2e -size 1865080 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py 
b/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json b/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/muon.py b/build/torch28-cxx11-rocm63-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its 
own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py deleted file mode 100644 index 239c7a65f8293e7d0df28f05fce645af56d628c0..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .muon import Muon - -__all__ = [ - "Muon", -] diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index d23cf944ec31a3606755cdac0f39bae6455816d5..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5ddeadf7e678e0ff7e84b9e4f869ef45ed6840b06e9093e20210769fd15b8cad -size 1865168 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py 
b/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json b/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/muon.py b/build/torch28-cxx11-rocm64-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its 
own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py deleted file mode 100644 index 239c7a65f8293e7d0df28f05fce645af56d628c0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .muon import Muon - -__all__ = [ - "Muon", -] diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index ca73c2a576e1ad27e2c5a403c459246792b9a9d1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:07135e56b4c66b79fcb062c0bd39e61dae7e4251f164638cd09f8e360075f215 -size 1936664 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py 
b/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/muon.py b/build/torch29-cxx11-cu126-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 
coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py deleted file mode 100644 index 239c7a65f8293e7d0df28f05fce645af56d628c0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .muon import Muon - -__all__ = [ - "Muon", -] diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index b4ccc5bd24c68e412968b43af9a352dd5ac27863..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f048516a9820c335263f335df545e404e22ee146355b49669c95a54852448542 -size 1999872 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py 
b/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/muon.py b/build/torch29-cxx11-cu128-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 
coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py deleted file mode 100644 index 239c7a65f8293e7d0df28f05fce645af56d628c0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .muon import Muon - -__all__ = [ - "Muon", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 67ccafc522c41f14eaf682f265f2bc7d3f56b114..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ef9fba09368a2296ebad017f6576f119ebe2b9513be0d51b66b403fe942bb6d5 -size 2000456 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py 
b/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/muon.py b/build/torch29-cxx11-cu130-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 
coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py b/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py deleted file mode 100644 index 239c7a65f8293e7d0df28f05fce645af56d628c0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .muon import Muon - -__all__ = [ - "Muon", -] diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 926869eca5ee9c6a8f6899f3966ba361bc640faa..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c1574fefc74653a663d8c4c53dda381d92c60cdc29358f15618b1b746dc4ae4e -size 1865112 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py 
b/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json b/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/muon.py b/build/torch29-cxx11-rocm63-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its 
own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/__init__.py b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py b/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py deleted file mode 100644 index 239c7a65f8293e7d0df28f05fce645af56d628c0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .muon import Muon - -__all__ = [ - "Muon", -] diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py deleted file mode 100644 index e6f6fcf6280e969b1761926112147d3146e27b59..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 95d54a0288c1e9cea520f5e3042a163cb9222346..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6ad69fa088ef05b1697f74d59c1a5a12f17dbf2a3cddb8c6b92ed7543b4cbdbc -size 1865232 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py 
b/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. 
- a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/metadata.json b/build/torch29-cxx11-rocm64-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/muon.py b/build/torch29-cxx11-rocm64-x86_64-linux/muon.py deleted file mode 100644 index dbf25575f185ff379789482068e4ecf55b9455a9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/muon.py +++ /dev/null @@ -1,1268 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its 
own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon - """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def 
adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." 
- - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - 
- scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) - - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) - - p.copy_(p_sharded) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply 
weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, 
weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. - """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - 
self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is 
below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params - - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - 
self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/__init__.py b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/docs/muon/balanced.png b/docs/muon/balanced.png new file mode 100644 index 0000000000000000000000000000000000000000..2076978a5a0149d598b419bfc45c508405dca0df --- /dev/null +++ b/docs/muon/balanced.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9933e2cd5490513593dd6cf1c5c4f18b7f33fd6e6b11c696784269c2bb78055b +size 98003 diff --git a/docs/muon/distributed_muon.png b/docs/muon/distributed_muon.png new file mode 100644 index 0000000000000000000000000000000000000000..26544c9e035afae48d1b32cd6ae729c600a47f33 --- /dev/null +++ b/docs/muon/distributed_muon.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31caea472991fd24a7934bf211b5adcbf154b5295bfe364bba5b603851c2cfae +size 407912 diff --git a/docs/muon/distributed_muon_execution.png b/docs/muon/distributed_muon_execution.png new file mode 100644 index 0000000000000000000000000000000000000000..824c728b78c73ca0d5b70a169ed2e5e50a59946c --- /dev/null +++ b/docs/muon/distributed_muon_execution.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72ab4d8076f1e182900d71636dd22c32b20bface38890cef72a0c94c496d5f02 +size 57140 diff --git a/docs/muon/imbalance.png b/docs/muon/imbalance.png new file mode 100644 index 0000000000000000000000000000000000000000..d63f0a034912195910cfac8a49f0533ac99968b1 --- /dev/null +++ b/docs/muon/imbalance.png @@ -0,0 
+1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c71d5faed05d46b2269fefa3b6bea6791d7bf51744f47aa4bb8c311eda1b27ff +size 56528 diff --git a/docs/muon/main.tex b/docs/muon/main.tex new file mode 100644 index 0000000000000000000000000000000000000000..41ad35f7b47a0bcfd9cc5bcc879b5fa1bf56c6f4 --- /dev/null +++ b/docs/muon/main.tex @@ -0,0 +1,142 @@ +\documentclass{article} +\usepackage{graphicx} +\usepackage{hyperref} +\usepackage{amsmath} +\usepackage{caption} +\usepackage{tgtermes} +\usepackage{float} +\usepackage[a4paper, margin=1in]{geometry} +\usepackage{booktabs} +\usepackage{algorithm} +\usepackage{algorithmicx} +\usepackage{algpseudocode} +\date{} + +\begin{document} + +{\LARGE \bfseries Parallelize Muon with FSDP2 \par} +\vspace{1em} % adjust the spacing below the title + +\section*{Motivation} + +\begin{figure}[H] + \centering + \includegraphics[width=0.8\textwidth]{distributed_muon.png} + \caption*{Distributed Muon by Moonlight} +\end{figure} + +While a distributed version of Muon is available, it has the drawback of redundant computations across GPUs.
+ +\begin{figure}[H] + \centering + \includegraphics[width=1.0\textwidth]{distributed_muon_execution.png} + \caption*{Execution timeline of Distributed Muon} +\end{figure} + +\begin{itemize} + \item \texttt{C[i]} : Compute Newton-Schulz(G) for i-th gradient + \item \texttt{AG[i]} : AllGather i-th gradient + \item \texttt{G[i]} : Gather i-th gradient + \item \texttt{SC[i]} : Scatter i-th update (the Newton-Schulz output) +\end{itemize} +\clearpage +\section*{Algorithm} + +\subsection*{Parallel Muon} + +\begin{algorithm} +\caption{Parallel Muon} +\textbf{Require:} DP partitioned gradient $\mathbf{g}$, DP partitioned momentum buffer $\mathbf{m}$, DP partitioned parameter $\mathbf{p}$, momentum coefficient $\mu$, local rank $\mathbf{r}$ +\begin{algorithmic}[1] +\State \texttt{// Apply momentum to $\mathbf{g}$ using local partitioned momentum $\mathbf{m}$} +\State $\mathbf{g'} \gets \text{update\_with\_momentum}(\mathbf{g}, \mathbf{m}, \mu)$ +\State \texttt{// Schedule $\mathbf{g'}$ to rank $\mathbf{R}$} +\State $\mathbf{R} \gets \text{schedule}(\mathbf{g'}, \text{dp\_group})$ +\State \texttt{// Gather $\mathbf{g'}$ across DP into a full matrix $\mathbf{G}$ to rank $\mathbf{R}$} +\State $\mathbf{G} \gets \text{gather}(\mathbf{g'}, \text{dp\_group}, \text{dst=}\mathbf{R})$ +\State \texttt{// Calculate Newton-Schulz only in $\mathbf{R}$} +\If{$\mathbf{r} = \mathbf{R}$} + \State $\mathbf{u} \gets \text{Newton-Schulz}(\mathbf{G})$ +\Else + \State $\mathbf{u} \gets \text{None}$ +\EndIf + +\State \texttt{// Scatter a full matrix $\mathbf{u}$ across DP} +\State $\mathbf{u'} \gets \text{scatter}(\mathbf{u},\text{dp\_group},\text{src=}\mathbf{R})$ +\State \texttt{// Apply DP partitioned $\mathbf{u'}$ to $\mathbf{p}$} +\State $\mathbf{p'} \gets \text{apply\_update}(\mathbf{p}, \mathbf{u'})$ +\State \textbf{return $\mathbf{p'}$} +\end{algorithmic} +\end{algorithm} + +We eliminate redundant computation by assigning each parameter to a specific GPU.
+ +However, without proper scheduling, this optimization can lead to poor GPU utilization. In particular, although redundant computation is avoided by assigning each parameter to a specific rank, it causes idle time, since all other ranks must wait for the scatter communication to complete before proceeding. + +\begin{figure}[H] + \centering + \includegraphics[width=1.0\textwidth]{naive_execution.png} + \caption*{Execution timeline of Parallel Muon} +\end{figure} + +\subsection*{Scheduling Sub-Operations} + +We can reschedule the sub-operations as shown below, for the following reasons: +\begin{itemize} + \item There are no dependencies between parameters. + \item GPUs can execute computation and communication concurrently. +\end{itemize} + +\begin{figure}[H] + \centering + \includegraphics[width=1.0\textwidth]{pipelined.png} + \caption*{Execution timeline of re-scheduled Parallel Muon} +\end{figure} + +We define the chunk size $C$ as the number of GPUs and schedule each sub-operation in batches of size $C$. This scheduling allows each GPU to continue computation even while waiting for collective communication to complete. + +\textbf{[Algorithm]} (To be written) +\clearpage +\subsection*{Load Balancing} + +If parameters in a chunk have imbalanced computation loads, idle bubbles may occur. \\ +To mitigate this, we apply load balancing based on per-parameter FLOPs. + +\vspace{1em} +\textbf{Imbalanced (Round Robin)} + +\begin{figure}[H] + \centering + \includegraphics[width=1.0\textwidth]{imbalance.png} +\end{figure} + +\textbf{After Load Balancing} + +\begin{figure}[H] + \centering + \includegraphics[width=1.0\textwidth]{balanced.png} +\end{figure} + +\section*{Implementation} + +The full implementation is available in \texttt{optimizer/torch-ext/optimizer/muon.py}.
+To enable concurrent computation and communication, we use separate compute and communication streams (\texttt{torch.cuda.Stream}) and use \texttt{torch.cuda.Event} to synchronize between sub-operations. + +Thanks to the simplicity of \texttt{torch.DTensor} and \texttt{torch.distributed}, the implementation remains straightforward and low in complexity. + +\section*{Evaluation} +We evaluated the performance using a 10B model currently in development, achieving 151 TFLOPS per GPU during the optimizer step. + +\begin{table}[H] + \centering + \begin{tabular}{@{}lllll@{}} + \toprule + Model Size & TFLOPs for Muon & GPUs & Elapsed time & TFLOPS/GPU \\ + \midrule + 10B & 847.45 & 4xMI250 (8 devices) & 1.4 s & 151 \\ + \bottomrule + \end{tabular} +\end{table} +Based on the breakdown, 7\% of the time is attributed to updating sharded gradients and parameters, 78\% to GEMM operations, and the remaining 15\% to non-overlapped communication overhead. + +\end{document} \ No newline at end of file diff --git a/docs/muon/naive_execution.png b/docs/muon/naive_execution.png new file mode 100644 index 0000000000000000000000000000000000000000..e8f3c4ce721cda02eb95f569c58739d36008b525 --- /dev/null +++ b/docs/muon/naive_execution.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eaacd3625f33cee9735ed0d96b95f98c696dfc771976be970a38c991e2ce84ab +size 42729 diff --git a/docs/muon/parallel_muon.pdf b/docs/muon/parallel_muon.pdf new file mode 100644 index 0000000000000000000000000000000000000000..8321c572edfae32e963a013d69187d58971fc27e --- /dev/null +++ b/docs/muon/parallel_muon.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1a88537a50ecc3db52d6e148d3513b31e2c9810c09df0da8f6aff03fa652fe5 +size 654538 diff --git a/docs/muon/pipelined.png b/docs/muon/pipelined.png new file mode 100644 index 0000000000000000000000000000000000000000..7e3d51f98c8f2e501704298c6ec48dca08203884 --- /dev/null +++ b/docs/muon/pipelined.png @@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1 +oid sha256:a1f8043cc58e7d8d9da5694ad7bccd1b9fe0210349b9aa9a62652a97f75cf097 +size 64316 diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000000000000000000000000000000000000..368754a84e467fe6ba68962628649fc9ab6121cc --- /dev/null +++ b/flake.lock @@ -0,0 +1,167 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1733328505, + "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1748598786, 
+ "owner": "huggingface", + "repo": "hf-nix", + "rev": "6ca679441494139fde1f2355691ddb5dc8170269", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1749822059, + "narHash": "sha256-zype8KSqESZUIQpsY6sbf4f9pPxM/Zwem+KuH5LeHFk=", + "owner": "huggingface", + "repo": "kernel-builder", + "rev": "96abd968baa5fa16217413050fa7372d5db3baa5", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernel-builder", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1747820358, + "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=", + "owner": "danieldk", + "repo": "nixpkgs", + "rev": "d3c1681180717528068082103bf323147de6ab0b", + "type": "github" + }, + "original": { + "owner": "danieldk", + "ref": "cudatoolkit-12.9-kernel-builder", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/optimizer/dummy.cu b/optimizer/dummy.cu index 
4a37c61c2feb239dbe063d08fd763ee7d0bb86ad..9a9780b635e46a946e7c836cd390c24e41da3385 100644 --- a/optimizer/dummy.cu +++ b/optimizer/dummy.cu @@ -3,4 +3,4 @@ namespace { __global__ void dummy() { // This kernel does nothing but serves as a placeholder } -} // namespace +} diff --git a/test/README.md b/test/README.md deleted file mode 100644 index 35ae009acbca01163cc3a3cddac000fa957c0a4f..0000000000000000000000000000000000000000 --- a/test/README.md +++ /dev/null @@ -1,42 +0,0 @@ -# Muon Optimizer Test - -This directory contains a test script for the **Muon optimizer**. - -## Prerequisites - -- **GPU Requirement** - - All tests require **8 GPUs** by default. - - If you have fewer GPUs available: - - Modify the parallelism configurations in `test_muon.py`. - -- **Model Access** - - The tests require access to the private model repository: - - `Motif-Technologies/Motif-2.6B-4layer-random` on Hugging Face. - - Set your Hugging Face token via the environment variable `HF_TOKEN`. - - If you don’t have access, please contact the maintainer. - -- **Using a Different Model (Optional)** - - You may modify the test to use a different model by: - - Updating the model name in `conftest.py::inputs`. - - Adjusting the tensor parallel rules in `utils.py::_apply_tp`. - -## Usage - -- To execute the test with 8 GPUs, simply run: - -```bash -./run_test.sh -``` - -- To check the other available options, you can use: - -```bash -pytest --help -... -Custom options: - --measure-perf Measure execution time and peak memory usage during optimizer step. - --do-profile Enable profiling during tests. - --skip-verify Skip verification of optimizer step correctness with sequential implementation. - This can be useful when GPU memory is limited. -... 
-``` diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index 15177262eb39e8f60c95742bb372faf2f3857ae9..0000000000000000000000000000000000000000 --- a/test/conftest.py +++ /dev/null @@ -1,124 +0,0 @@ -import logging - -import pytest -import torch -import torch.distributed as dist -from packaging import version -from transformers import AutoModelForCausalLM - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - -SEED = 0xdeadbeef - - -def pytest_addoption(parser): - parser.addoption( - "--measure-perf", - action="store_true", - default=False, - help= - "Measure execution time and peak memory usage during optimizer step.", - ) - - parser.addoption( - "--do-profile", - action="store_true", - default=False, - help="Enable profiling during tests.", - ) - - parser.addoption( - "--skip-verify", - action="store_true", - default=False, - help= - "Skip verification of optimizer step correctness with sequential implementation.\n" - "This can be useful when GPU memory is limited.", - ) - - -def pytest_configure(config): - if config.getoption( - "--do-profile") and not config.getoption("--measure-perf"): - raise pytest.UsageError( - "--do-profile requires --measure-perf. 
Please enable both flags.") - - -@pytest.fixture(scope="session") -def measure_perf(request): - return request.config.getoption("--measure-perf") - - -@pytest.fixture(scope="session") -def do_profile(request): - return request.config.getoption("--do-profile") - - -@pytest.fixture(scope="session") -def skip_verify(request): - return request.config.getoption("--skip-verify") - - -@pytest.fixture(scope="session", autouse=True) -def init_dist(request): - if version.parse(torch.__version__) < version.parse("2.8"): - pytest.skip("torch>=2.8.0 is required for parallel muon") - return - - try: - dist.init_process_group(backend="nccl") - torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) - except Exception as e: - print(f"Failed to initialize torch.distributed: {e}") - pytest.skip("Failed to initialize torch.distributed") - - if dist.get_world_size() != 8: - pytest.skip("Need 8 processes in dist group. " - "You can run with `torchrun --nproc-per-node=8 " - "--local-ranks-filter 0 -m pytest " - "test_rms_norm_sequence_parallel.py`." - "To run with less than 8 gpus, modify " - "the test cases accordingly.") - - yield - dist.destroy_process_group() - - -@pytest.fixture(scope="session") -def inputs(): - """Load Motif-2.6B model and generate random gradients for testing. - Returns: - tuple[torch.nn.Module, list[torch.Tensor], dict[int, torch.Tensor]]: - - torch.nn.Module: The Motif-2.6B model. - - list[torch.Tensor]: A list of random gradients for each model parameter. - - dict[int, torch.Tensor]: A dictionary mapping layer indices to random QK logits. - """ - model_name = "Motif-Technologies/Motif-2.6B-4layer-random" - - torch.manual_seed(SEED) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(SEED) - - model = AutoModelForCausalLM.from_pretrained( - model_name, - trust_remote_code=True, - ) - logger.info( - f"Loaded model {model_name}. 
({len(list(model.parameters()))} parameters)" - ) - - grads: list[torch.Tensor] = [] - for param in model.parameters(): - grad = torch.randn_like(param, device=param.device, dtype=param.dtype) - grads.append(grad) - - qk_logits: dict[int, torch.Tensor] = { - i: - torch.randn(model.config.num_attention_heads, - device=model.device, - dtype=torch.bfloat16) - for i in range(model.config.num_hidden_layers) - } - - return [model, grads, qk_logits] diff --git a/test/optimizer b/test/optimizer deleted file mode 120000 index c7ff828a90e1c2a67535184c5e89724fb52bea24..0000000000000000000000000000000000000000 --- a/test/optimizer +++ /dev/null @@ -1 +0,0 @@ -../torch-ext/optimizer/ \ No newline at end of file diff --git a/test/pytest.ini b/test/pytest.ini deleted file mode 100644 index 11c72fa2e2812b16b1c2e92fb0d78cb4adbda2e5..0000000000000000000000000000000000000000 --- a/test/pytest.ini +++ /dev/null @@ -1,3 +0,0 @@ -[pytest] -log_cli = true -log_cli_level = INFO diff --git a/test/run_test.sh b/test/run_test.sh deleted file mode 100755 index 2c2bd5b362a36dd4facebee1454eb7b4809118f1..0000000000000000000000000000000000000000 --- a/test/run_test.sh +++ /dev/null @@ -1 +0,0 @@ -torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon.py diff --git a/test/test_muon.py b/test/test_muon.py deleted file mode 100644 index 3c4085963941120b0c089bfbdfad3a840c00da20..0000000000000000000000000000000000000000 --- a/test/test_muon.py +++ /dev/null @@ -1,244 +0,0 @@ -import copy -import logging -import time -from contextlib import nullcontext - -import pytest -import torch -import torch.distributed as dist -from optimizer.muon import Muon, get_default_muon_param_groups -from torch.distributed.tensor import DTensor, Replicate -from torch.profiler import ProfilerActivity, profile - -from .utils import (ParallelDims, assert_params_equal, parallelize_motif, - parallelize_qk_logits) - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - - -def 
apply_muon_step( - model: torch.nn.Module, - parallel_dims: ParallelDims | None, - grads: list[torch.Tensor], - warmup_step: int, - chunk_size: int, - small_param_numel_threshold: int, - qk_logits: dict[int, torch.Tensor] | None = None, - use_distributed_muon: bool = False, - measure_perf: bool = False, - do_profile: bool = False, -) -> tuple[torch.nn.Module, tuple[float, float] | None]: - """ apply single Muon step with optional QK clipping """ - - # 1. Apply gradients to model parameters - assert len(grads) == len(list(model.parameters())) - for grad, param in zip(grads, model.parameters()): - grad = grad.to(param.device) - if isinstance(param.data, DTensor): - unsharded_grad = DTensor.from_local( - grad, - device_mesh=param.data.device_mesh, - placements=[Replicate()] * param.data.device_mesh.ndim, - ) - sharded_grad = unsharded_grad.redistribute( - device_mesh=param.data.device_mesh, - placements=param.data.placements) - param.grad = sharded_grad - else: - param.grad = grad - - # 2. 
Setup Muon optimizer - params = get_default_muon_param_groups(model) - clip_config = dict({ - "q_indices": - list(range(model.config.num_attention_heads)), - "k_indices": - list(range(model.config.num_attention_heads)), - "head_dim": - model.config.hidden_size // model.config.num_attention_heads, - "threshold": - 0.5 - }) - optim = Muon( - params=params, - clip_config=clip_config if qk_logits is not None else None, - none_grad=False, - warmup_step=warmup_step, - chunk_size=chunk_size, - small_param_numel_threshold=small_param_numel_threshold, - use_distributed_muon=use_distributed_muon, - ) - - optim.step(qk_logits=qk_logits) - - timing_result: tuple[float, float] | None = None - - if measure_perf: - # extra warm up - optim.step(qk_logits=qk_logits) - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - num_iters = 20 - current_mem = torch.cuda.memory_allocated() - - if do_profile: - context = profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - record_shapes=True) - else: - context = nullcontext() - - with context as prof: - for _i in range(num_iters): - optim.step(qk_logits=qk_logits) - - end.record() - end.synchronize() - - if prof is not None and dist.get_rank() == 0: - date = time.strftime("%Y%m%d_%H%M%S", time.localtime()) - profile_name = "trace" - profile_name += f"_{date}" - profile_name += f"_{parallel_dims}" - profile_name += f"_{chunk_size}" - profile_name += f"_{warmup_step}" - profile_name += f"_{qk_logits is not None}" - profile_name += f"_{use_distributed_muon}" - - prof.export_chrome_trace(f"{profile_name}.json") - - peak_memory = torch.cuda.max_memory_allocated() - current_mem - - elapsed_time_ms = start.elapsed_time(end) / num_iters - - timing_result = (elapsed_time_ms, peak_memory) - - return model, timing_result - - -@pytest.fixture(scope="session") -def sequential_muon_result( - skip_verify, # from conftest.py - inputs # from conftest.py -) -> dict[bool, 
torch.nn.Module]: - """Run Muon optimizer to sequential model for baseline results.""" - if skip_verify: - logger.info("Skipping verification tests as per user request") - return None - - model, grads, qk_logits = inputs - - result = apply_muon_step( - model=copy.deepcopy(model).cuda(), - parallel_dims=None, - grads=grads, - warmup_step=-1, - chunk_size=-1, - small_param_numel_threshold=-1, - qk_logits=None, - )[0].cpu() - - result_qk_clip = apply_muon_step( - model=copy.deepcopy(model).cuda(), - parallel_dims=None, - grads=grads, - warmup_step=-1, - chunk_size=-1, - small_param_numel_threshold=-1, - qk_logits=qk_logits, - )[0].cpu() - - return { - False: result, - True: result_qk_clip, - } - - -OVERLAP_STEPS = [5] -CHUNK_SIZES = [8] -SMALL_PARAM_NUMEL_THRESHOLDS = [65536, 1_000_000_000] - - -@pytest.mark.parametrize("parallel_dims", [ - pytest.param(ParallelDims(8, 1, 1), id="base"), - pytest.param(ParallelDims(1, 8, 1), id="fsdp"), - pytest.param(ParallelDims(2, 4, 1), id="hsdp"), - pytest.param(ParallelDims(1, 1, 8), id="tp"), - pytest.param(ParallelDims(2, 2, 2), id="hsdp+tp"), - pytest.param(ParallelDims(1, 2, 4), id="fsdp+tp"), -]) -@pytest.mark.parametrize("apply_qk_clip", [False, True]) -@pytest.mark.parametrize("use_distributed_muon", [False]) -@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS) -@pytest.mark.parametrize("chunk_size", CHUNK_SIZES) -@pytest.mark.parametrize("small_param_numel_threshold", - SMALL_PARAM_NUMEL_THRESHOLDS) -def test_parallel_muon( - request, - sequential_muon_result: dict[bool, torch.nn.Module], - parallel_dims: ParallelDims, - apply_qk_clip: bool, - use_distributed_muon: bool, - warmup_step: int, - chunk_size: int, - small_param_numel_threshold: int, - inputs: tuple[torch.nn.Module, list[torch.Tensor], - dict[int, torch.Tensor]], # from conftest.py - measure_perf, # from conftest.py - do_profile, # from conftest.py -) -> None: - if use_distributed_muon and chunk_size != CHUNK_SIZES[0]: - pytest.skip("Distributed Muon does 
not effected by chunk size") - if use_distributed_muon and warmup_step != OVERLAP_STEPS[0]: - pytest.skip("Distributed Muon does not effected by warmup step") - - model, grads, qk_logits = inputs - - if not apply_qk_clip: - qk_logits = None - - # Deepcopy the model to avoid in-place modification - model = copy.deepcopy(model).cuda() - - parallelized_model = parallelize_motif(model, parallel_dims) - - if qk_logits is not None: - # Deepcopy the qk logits to avoid in-place modification - qk_logits = copy.deepcopy(qk_logits) - qk_logits = parallelize_qk_logits(qk_logits, parallel_dims) - - parallelized_model, timing_result = apply_muon_step( - model=parallelized_model, - parallel_dims=parallel_dims, - grads=grads, - warmup_step=warmup_step, - chunk_size=chunk_size, - small_param_numel_threshold=small_param_numel_threshold, - qk_logits=qk_logits, - use_distributed_muon=use_distributed_muon, - measure_perf=measure_perf, - do_profile=do_profile, - ) - - if measure_perf: - assert timing_result is not None - avg_time_ms, peak_memory = timing_result - logger.info( - f"\nParallel dims: {parallel_dims}, " - f"\nUse distributed Muon: {use_distributed_muon}, " - f"\nApply QK clip: {apply_qk_clip} => " - f"\nChunk Size, Warmup Step, Avg Time (ms), Peak Memory (MB):" - f"\n{chunk_size}, {warmup_step}, {avg_time_ms:.2f}, {peak_memory / (1024**2):.2f}," - ) - - if sequential_muon_result is None: - logger.info("Skipping correctness check as sequential result is None") - elif measure_perf: - logger.info("Skipping correctness check as timing is enabled") - else: - assert_params_equal(parallelized_model, - sequential_muon_result[apply_qk_clip]) diff --git a/test/utils.py b/test/utils.py deleted file mode 100644 index 494c09de1f3241a5ef5028e47f21d17c7342645a..0000000000000000000000000000000000000000 --- a/test/utils.py +++ /dev/null @@ -1,241 +0,0 @@ -from dataclasses import dataclass - -import torch -import torch.distributed as dist -from torch.distributed.fsdp import fully_shard -from 
torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard -from torch.distributed.tensor.parallel import (ColwiseParallel, - PrepareModuleInput, - RowwiseParallel, - SequenceParallel, - parallelize_module) - - -@dataclass -class ParallelDims: - dp_replicate_degree: int - dp_shard_degree: int - tp_degree: int - - def __str__(self) -> str: - return (f"dp_replicate-{self.dp_replicate_degree}_" - f"dp_shard-{self.dp_shard_degree}_" - f"tp-{self.tp_degree}") - - -def _construct_device_mesh(parallel_dims: ParallelDims) -> DeviceMesh: - """Constructs a DeviceMesh based on the given parallel dimensions. - - Args: - parallel_dims (ParallelDims): The parallelism configuration. - - Returns: - DeviceMesh: The constructed device mesh. - """ - world_size = dist.get_world_size() - expected_devices = (parallel_dims.dp_replicate_degree * - parallel_dims.dp_shard_degree * - parallel_dims.tp_degree) - if world_size < expected_devices: - raise ValueError( - f"Not enough devices: found {world_size}, " - f"but expected at least {expected_devices}. ({parallel_dims})") - - degrees = [ - parallel_dims.dp_replicate_degree, parallel_dims.dp_shard_degree, - parallel_dims.tp_degree - ] - dim_names = ["dp_replicate", "dp_shard", "tp"] - - mesh_shape = [] - mesh_dim_names = [] - for degree, dim_name in zip(degrees, dim_names): - if degree > 1: - mesh_shape.append(degree) - mesh_dim_names.append(dim_name) - - device_mesh = dist.init_device_mesh("cuda", - mesh_shape, - mesh_dim_names=mesh_dim_names) - - return device_mesh - - -def _apply_tp( - model: torch.nn.Module, - tp_mesh: DeviceMesh, -): - """Apply tensor parallelism.""" - - # Layer names must match Motif model definition - # https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py - - assert type(model).__name__ == "MotifForCausalLM" - - # 1. Parallelize the embedding and shard its outputs (which are the first - # transformer block's inputs) - # 2. 
Parallelize the root norm layer over the sequence dim - # 3. Parallelize the final linear output layer - - parallelize_module( - model, - tp_mesh, - { - # This below separate tie_weights and make difficult to compare - # the answer with non-tensor-parallel version. - # TODO(jeesoo): check correctness for training semantic - - #"model.embed_tokens": - #RowwiseParallel( - # input_layouts=Replicate(), - # output_layouts=Shard(1), - #), - "model.norm": - SequenceParallel(), - "output": - ColwiseParallel( - input_layouts=Shard(1), - output_layouts=Shard(-1), # loss_parallel - use_local_output=False, - ), - }, - ) - - # Apply tensor + sequence parallelism to every transformer block - for transformer_block in model.model.layers: - layer_plan = { - "input_layernorm": - SequenceParallel(), - "post_attention_layernorm": - SequenceParallel(), - "self_attn": - PrepareModuleInput( - # x, freqs_cis, attention_mask, position_ids, qk_clip - input_layouts=(Shard(1), Replicate(), None, None, None), - desired_input_layouts=(Replicate(), Replicate(), None, None, - None), - ), - "self_attn.q_proj": - ColwiseParallel(), - "self_attn.k_proj": - ColwiseParallel(), - "self_attn.v_proj": - ColwiseParallel(), - "self_attn.o_proj": - RowwiseParallel(output_layouts=Shard(1)), - "mlp": - PrepareModuleInput( - input_layouts=(Shard(1), ), - desired_input_layouts=(Replicate(), ), - ), - "mlp.gate_proj": - ColwiseParallel(), - "mlp.down_proj": - RowwiseParallel(output_layouts=Shard(1)), - "mlp.up_proj": - ColwiseParallel(), - } - - parallelize_module( - module=transformer_block, - device_mesh=tp_mesh, - parallelize_plan=layer_plan, - ) - - -def _apply_fsdp( - model: torch.nn.Module, - dp_mesh: DeviceMesh, -): - for layer in model.model.layers: - fully_shard(layer, mesh=dp_mesh) - layer.reshard() - fully_shard(model, mesh=dp_mesh) - model.reshard() - - -def parallelize_motif(model: torch.nn.Module, - parallel_dims: ParallelDims) -> torch.nn.Module: - """Parallelize the Motif model according to the 
given parallel dimensions. - - Args: - model (torch.nn.Module): The Motif model to be parallelized. - parallel_dims (ParallelDims): The parallelism configuration. - - Returns: - torch.nn.Module: The parallelized Motif model. - """ - - mesh = _construct_device_mesh(parallel_dims) - - if parallel_dims.tp_degree > 1: - _apply_tp(model, mesh["tp"]) - - if parallel_dims.dp_shard_degree > 1: - if parallel_dims.dp_replicate_degree > 1: - dp_dim_names = ("dp_replicate", "dp_shard") - else: - dp_dim_names = ("dp_shard", ) - _apply_fsdp(model, mesh[dp_dim_names]) - - return model - - -def parallelize_qk_logits( - qk_logits: dict[int, torch.Tensor], - parallel_dims: ParallelDims, -) -> dict[int, torch.Tensor]: - """Parallelize the QK logits according to the given parallel dimensions. - - Args: - qk_logits (dict[int, torch.Tensor]): The QK logits to be parallelized. - parallel_dims (ParallelDims): The parallelism configuration. - - Returns: - dict[int, torch.Tensor]: The parallelized QK logits. - """ - - mesh = _construct_device_mesh(parallel_dims) - - if parallel_dims.tp_degree > 1: - tp_rank = mesh["tp"].get_local_rank() - placements = [ - Shard(0) if dim_name == "tp" else Replicate() - for dim_name in mesh.mesh_dim_names - ] - for layer_idx, logits in qk_logits.items(): - assert logits.size(0) % parallel_dims.tp_degree == 0 - local_logits = logits.chunk(parallel_dims.tp_degree, - dim=0)[tp_rank].contiguous() - - qk_logits[layer_idx] = DTensor.from_local( - local_tensor=local_logits, - device_mesh=mesh, - placements=placements, - ) - - return qk_logits - - -def assert_params_equal(actual: torch.nn.Module, - expected: torch.nn.Module) -> None: - """Asserts that the parameters of two models are equal. - - Args: - actual (torch.nn.Module): The actual model. - expected (torch.nn.Module): The expected model. 
- Returns: - None - """ - - def get_full_param(param: torch.nn.Parameter) -> torch.Tensor: - if isinstance(param.data, DTensor): - return param.data.full_tensor() - return param.data - - for (name_p, p), (name_s, s) in zip(actual.named_parameters(), - expected.named_parameters()): - p = get_full_param(p.cuda()) - s = get_full_param(s.cuda()) - - torch.testing.assert_close(p, s, atol=0, rtol=0) diff --git a/torch-ext/optimizer/distributed/utils.py b/torch-ext/optimizer/distributed/utils.py deleted file mode 100644 index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..0000000000000000000000000000000000000000 --- a/torch-ext/optimizer/distributed/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. - """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. 
This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, - ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. 
- # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. 
Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/torch-ext/optimizer/matmul_transpose_triton.py b/torch-ext/optimizer/matmul_transpose_triton.py deleted file mode 100644 index 4565b2c4fd506a4218340d380d6c962b16774b1d..0000000000000000000000000000000000000000 --- a/torch-ext/optimizer/matmul_transpose_triton.py +++ /dev/null @@ -1,128 +0,0 @@ -# MIT License -# -# Copyright (c) 2025 Tianyang Lin -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in 
all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import torch -import triton -import triton.language as tl - - -def get_autotune_config(): - return [ - triton.Config( - { - 'BLOCK_SIZE_M': blk_m, - 'BLOCK_SIZE_K': blk_k, - 'GROUP_SIZE_M': grp_sz - }, - num_stages=n_stages, - num_warps=n_warps) for blk_m in [32, 64, 128] - for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] - for n_warps in [4, 8] - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'K'], -) -@triton.jit -def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Core kernel jit function of matmul_transpose that computes y = x @ x.T - The code is a simple adaptation from the triton `matmul` tutorial: - https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - if pid_m > pid_n: - return - - offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_k = 
tl.arange(0, BLOCK_SIZE_K) - # we use a & b ptrs to denote different rows of x. - a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) - b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) - a_ptrs += BLOCK_SIZE_K * stride_xk - b_ptrs += BLOCK_SIZE_K * stride_xk - # use dtype.element_ty to accommodate different input datatypes as in cpp templates - # https://github.com/triton-lang/triton/issues/2252 - c = accumulator.to(x.dtype.element_ty) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, c, mask=c_mask) - - # transpose and copy - if pid_m < pid_n: - ct_ptrs = y + stride_ym * offs_cn[:, - None] + stride_yn * offs_cm[None, :] - ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) - - -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - - d_in = 
d_in.contiguous() - M, K = d_in.shape - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( - M, META['BLOCK_SIZE_M']), ) - with torch.cuda.device(d_in.device.index): - mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), - d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/torch-ext/optimizer/muon.py b/torch-ext/optimizer/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..0d614d55d721efac406c147b4f62e6c703a91107 100644 --- a/torch-ext/optimizer/muon.py +++ b/torch-ext/optimizer/muon.py @@ -1,32 +1,14 @@ -import logging import math -import types -from collections import defaultdict from dataclasses import dataclass -from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 +from torch.distributed._tensor import DTensor # This code snippet is a modified version adapted from the following GitHub repositories: # https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. 
We opt to use a @@ -38,499 +20,119 @@ def _zeropower_via_newtonschulz5(G, steps): performance at all relative to UV^T, where USV^T = G is the SVD. """ assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE + a, b, c = (3.4445, -4.7750, 2.0315) X = G # no manual typecast - if G.size(0) > G.size(1): X = X.T # Ensure spectral norm is at most 1 X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + X = X.bfloat16() # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) if G.size(0) > G.size(1): X = X.T - return X + return X.to(G.dtype) @dataclass class _muon_state: # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None + worker_rank: int | None = None gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel @torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event +def _gather(p, state, rank, comm_stream, none_grad): + g = p.grad + mesh = g.device_mesh + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = 
param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = 
param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None + if rank == state.worker_rank: + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None @torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ +def _compute_u(state, steps, rank, compute_stream): with torch.cuda.stream(compute_stream): if rank == state.worker_rank: if state.gather_event is None: raise RuntimeError("Gather event must be set before compute.") compute_stream.wait_event(state.gather_event) u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None state.computed_u = u state.compute_event = torch.cuda.Event() state.compute_event.record() + # Clear the gathered gradient to free memory + state.gathered_grad = None else: state.computed_u = None state.compute_event = None @torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - +def _scatter(p, state, lr, wd, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in 
range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # 
src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. - """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), ) - - state.scattered_u = u_dtensor - if rank == state.worker_rank: - # Free computed_u + # Clear u to free memory state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None 
else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. 
- - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + p.data.mul_(1 - lr * wd) + p.data.add_(u, alpha=-lr) class Muon(torch.optim.Optimizer): @@ -547,99 +149,71 @@ class Muon(torch.optim.Optimizer): - We believe it may not work well for finetuning pretrained models, but we haven't tested this. Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + muon_params: The parameters to be optimized by Muon. lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) momentum: The momentum used by the internal SGD. (0.95 is a good default) nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. 
adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + adamw_wd: The weight decay for the internal AdamW. 
""" - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False, - small_param_numel_threshold=65536): + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + ): defaults = dict( lr=lr, - weight_decay=weight_decay, + wd=adamw_wd, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, adamw_betas=adamw_betas, adamw_eps=adamw_eps, none_grad=none_grad, - use_muon=True, ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model - super().__init__(params, defaults) + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." 
+ ) - self.rank = None + self.rank = dist.get_rank() self.comm_stream = torch.cuda.Stream() self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -657,28 +231,7 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." 
- - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): + def init_state_and_assign_params(self, params, group): param_to_state = {} param_to_flops = {} @@ -694,88 +247,34 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements + mesh = None + for p in ordered_params: + if mesh is None: + mesh = p.device_mesh + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) + round_robin = (round_robin + 1) % mesh.mesh.numel() return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def base(self, params, group, lr, wd, momentum): # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], 
- group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. """ - - for n, p in zip(names, params): + for p in params: g = p.grad if g is None: continue @@ -783,130 +282,50 @@ class Muon(torch.optim.Optimizer): g = g.view(g.size(0), -1) assert g is not None - g = self._update_g(p, g, group, momentum) - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) else: - g_full = g - p_full = p + g = buf - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None - - if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) - - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, - ) + # apply weight decay + p.data.mul_(1 - lr * wd) - p.copy_(p_sharded) + # apply update + p.data.add_(u, alpha=-adjusted_lr) def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) 
- torch.add(g, buf, alpha=momentum, out=buf) + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) + g = g.add(buf, alpha=momentum) else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, 
threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): """ Perform a parallel optimization step using Muon. """ @@ -928,341 +347,109 @@ class Muon(torch.optim.Optimizer): p.grad = g param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None + params, group + ) - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: + for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + 
chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - + enqueue_gathers(0, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: 
list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. - lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - name_dtensors = [] - - param_tensors = [] - name_tensors = [] - - param_dtensors_small = [] - 
name_dtensors_small = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") - - def group_dtensors(dtensors, names): - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + torch.cuda.current_stream().wait_stream(self.comm_stream) - assert len(dtensors) == len(names) - for p, n in zip(dtensors, names): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - return placement_to_params + def step(self, closure=None): + """Perform a single optimization step. 
- if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) + for group in self.param_groups: + ############################ + # Muon # + ############################ - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] - dtensor_group = group_dtensors(param_dtensors, name_dtensors) - for _, (names, params) in dtensor_group.items(): + if isinstance(params[0].data, DTensor): self.parallel( - names, params, group, lr=lr, - weight_decay=weight_decay, + wd=wd, momentum=momentum, - qk_logits=qk_logits, ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - 
moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - - @torch.no_grad - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. + ############################ + # AdamW backup # + ############################ - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) return loss