# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview

Optimizer is a PyTorch package implementing the **Muon optimizer** with support for N-D sharded parallelism in large-scale distributed training, based on the paper at https://arxiv.org/abs/2511.07464. It supports general N-D sharding configurations, from plain FSDP2 to hybrid setups such as 2 TP x 2 DP-Replicate x 2 DP-Shard.
## Commands

### Lint & Format

```bash
pre-commit run --all-files         # Run all pre-commit hooks
pre-commit run isort --all-files   # Run a specific hook (e.g., isort)
```

Hooks: yapf (Python formatter), isort (import sorter), typos (spell checker), clang-format (C++/CUDA), pymarkdown (Markdown linter), actionlint (GitHub Actions).
### Tests

Tests require **8 GPUs**, access to `Motif-Technologies/Motif-2.6B-4layer-random` on Hugging Face (set the `HF_TOKEN` env var), and PyTorch >= 2.8.0.

```bash
cd test && ./run_test.sh
# Equivalent to:
cd test && torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon.py
```

Useful pytest flags: `--measure-perf` (timing/memory), `--do-profile` (profiling; requires `--measure-perf`), `--skip-verify` (skip the correctness check against the sequential implementation).
### Build

Uses the kernel-builder infrastructure (`build.toml`, `flake.nix`). Pre-built binaries for various PyTorch/CUDA/ROCm combinations are stored in `build/`.

### Commit Convention

**Always append `[skip-build]` to every commit message.** This prevents CI from triggering unnecessary build jobs on development branches.
## Architecture

### Source Layout

```
torch-ext/optimizer/
├── __init__.py                  # Public API: exports Muon
├── muon.py                      # Muon optimizer class (~430 lines)
├── newton_schulz.py             # Newton-Schulz iteration (~50 lines)
├── qk_clip.py                   # QK clipping for attention heads (~130 lines)
├── core.py                      # Shared state, helpers, param grouping (~110 lines)
├── pipeline.py                  # Async generator pipeline for parallel mode (~290 lines)
├── async_utils.py               # AsyncTask / AsyncRuntime scheduling (~75 lines)
├── adamw.py                     # Fused AdamW for non-Muon parameters (~160 lines)
├── matmul_transpose_triton.py   # Triton kernel for X @ X.T (~130 lines)
└── distributed/
    └── utils.py                 # Shard mesh construction, DTensor slicing (~175 lines)
```
### Optimizer Modes

The `Muon` optimizer has three execution paths, selected per parameter based on its tensor type and mesh structure:

1. **Base mode** (`base()`) - Single-device / non-sharded tensors. Standard Muon with Newton-Schulz orthogonalization.
2. **Distributed mode** (`distributed_muon()`) - Gathers full tensors via all-gather, computes updates, and redistributes. Used for small parameters or as a fallback.
3. **Parallel mode** (`parallel()`) - Pipelined all-to-all communication overlapped with compute, using an async generator pipeline scheduled by `run_pipeline()`. This is the main advanced feature.
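As a rough illustration, the three-way dispatch can be thought of as a predicate over the parameter's tensor type. Everything below is hypothetical (stand-in classes, made-up `select_path` name); the real selection in `muon.py` also inspects the DTensor's mesh and placements:

```python
class Tensor:           # stand-in for torch.Tensor
    pass

class DTensor(Tensor):  # stand-in for torch.distributed.tensor.DTensor
    pass

def select_path(param: Tensor, use_parallel: bool = True) -> str:
    """Hypothetical sketch of the per-parameter mode selection."""
    if not isinstance(param, DTensor):
        return "base"         # single-device / non-sharded tensor
    if not use_parallel:
        return "distributed"  # all-gather -> compute -> redistribute
    return "parallel"         # pipelined all-to-all with compute overlap
```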
### Parallel Mode Pipeline

The parallel pipeline is implemented as a single generator function, `muon_chunk_pipeline()`, in `pipeline.py`. Parameters are split into chunks, and each chunk flows through:

```
build bufs + async all2all_gather → yield → wait + Newton-Schulz compute + async all2all_scatter → yield → wait + update_param
```
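The two-yield shape of that flow can be sketched with stand-in work handles (everything below is illustrative; the real generator operates on DTensor shards and `torch.distributed` Work objects):

```python
class FakeWork:
    """Stand-in for the Work handle returned by an async collective."""
    def __init__(self, name, log):
        self.name, self.log = name, log

    def wait(self):
        self.log.append(f"wait:{self.name}")

def chunk_pipeline(chunk_id, log):
    # Stage 1: build buffers and launch the async all-to-all gather.
    log.append(f"launch_gather:{chunk_id}")
    work = FakeWork(f"gather:{chunk_id}", log)
    yield                             # scheduler may advance other chunks here
    work.wait()                       # gather finished
    log.append(f"newton_schulz:{chunk_id}")
    # Stage 2: launch the async all-to-all scatter of the update.
    log.append(f"launch_scatter:{chunk_id}")
    work = FakeWork(f"scatter:{chunk_id}", log)
    yield                             # scatter overlaps other chunks' compute
    work.wait()
    log.append(f"update_param:{chunk_id}")
```

Driving a single chunk to completion replays the three stages in order; the overlap only appears when a scheduler interleaves several chunks at the yields.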
The generator yields twice (after launching the async gather and the async scatter via `async_op=True`), allowing `run_pipeline()` to interleave multiple chunks for communication overlap. `work.wait()` completes each async operation after the corresponding yield.

`warmup_step` maps to `max_concurrent_tasks = warmup_step + 1` in `run_pipeline()`.
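A minimal round-robin scheduler with bounded concurrency (a sketch of the idea, not the actual `run_pipeline()` code) shows how `max_concurrent_tasks` controls the overlap window:

```python
from collections import deque

def run_pipeline(gens, max_concurrent_tasks):
    """Advance up to `max_concurrent_tasks` generators round-robin at their yields."""
    pending, active = deque(gens), deque()
    while pending or active:
        while pending and len(active) < max_concurrent_tasks:
            active.append(pending.popleft())
        gen = active.popleft()
        try:
            next(gen)           # run until the next yield (an async launch point)
            active.append(gen)
        except StopIteration:
            pass                # chunk finished; a pending chunk can now start

def chunk(i, log):
    log.append(("gather", i)); yield    # async gather launched
    log.append(("compute", i)); yield   # NS compute + async scatter launched
    log.append(("update", i))           # final param update
```

With `warmup_step = 1` (so `max_concurrent_tasks = 2`), chunk 1's gather is launched before chunk 0's compute runs, which is exactly the communication/compute overlap described above; chunk 2 cannot start until an earlier chunk finishes.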
For detailed implementation documentation (pipeline internals, distributed utilities, QK clipping with strided sharding, etc.), see [`docs/implementation.md`](docs/implementation.md).
### Key Abstractions

- **`get_default_muon_param_groups(model, is_muon_func)`** (`core.py`) - Separates parameters into Muon-optimizable (2D+) and AdamW groups. Skips embeddings and output layers by default.
- **`_muon_state` dataclass** (`core.py`) - Per-parameter config: rank ownership (`worker_rank`), process group, precomputed shard indices (`rank_indices`, `rank_numels`), and optional QK clip state. Config only; no transient pipeline state.
- **`muon_chunk_pipeline()` generator** (`pipeline.py`) - Processes one chunk through the full gather→compute→scatter→update pipeline. Uses `async_op=True` for non-blocking all-to-all and yields to allow chunk interleaving. All intermediate buffers are generator-local variables.
- **`run_pipeline()`** (`async_utils.py`) - Generator-based pipeline scheduling with bounded concurrency. Interleaves multiple chunk pipelines at yield points.
- **`construct_shard_mesh()` / `get_slices_of_dtensor()`** (`distributed/utils.py`) - Utilities for building shard meshes from DTensor placements and computing per-rank local slices. Handle both `Shard` and `_StridedShard` (PyTorch 2.10+).
- **Newton-Schulz iteration** (`newton_schulz.py`) - `_zeropower_via_newtonschulz5()`: five quintic iterations in bfloat16 with pre-optimized coefficients for gradient orthogonalization. Uses the Triton kernel `matmul_transpose_assign` for an efficient X @ X.T.
- **QK clipping** (`qk_clip.py`) - Optional dynamic clipping of attention head projections when QK logits exceed a threshold. Configured via `q_indices`, `k_indices`, `head_dim`, and `threshold`.
- **Fused AdamW** (`adamw.py`) - Uses PyTorch's `torch._fused_adamw_` for non-Muon parameters, grouping tensors by device/dtype and DTensor placement.
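The Newton-Schulz orthogonalization can be illustrated in plain Python. The coefficients below are those of the public Muon reference implementation and may differ from this package's pre-optimized ones; the real code runs in bfloat16 on tensors, with the X @ X.T product fused into a Triton kernel:

```python
import math

# Quintic NS coefficients from the public Muon reference implementation
# (assumption: this package's "pre-optimized" values may differ).
A, B, C = 3.4445, -4.7750, 2.0315

def _matmul(x, y):
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))] for i in range(len(x))]

def _transpose(x):
    return [list(row) for row in zip(*x)]

def zeropower_via_newtonschulz5(g, steps=5):
    """Pure-Python sketch: drive the singular values of g toward 1."""
    # Normalize so the spectral norm is <= 1 (Frobenius norm bounds it).
    fro = math.sqrt(sum(v * v for row in g for v in row)) + 1e-7
    x = [[v / fro for v in row] for row in g]
    for _ in range(steps):
        a = _matmul(x, _transpose(x))   # X X^T (the Triton-fused product)
        a2 = _matmul(a, a)              # (X X^T)^2
        ax, a2x = _matmul(a, x), _matmul(a2, x)
        # X <- A*X + B*(XX^T)X + C*(XX^T)^2 X  -- quintic polynomial in X
        x = [[A * x[i][j] + B * ax[i][j] + C * a2x[i][j]
              for j in range(len(x[0]))] for i in range(len(x))]
    return x
```

On a diagonal matrix the iteration acts on each singular value independently, so five steps pull entries like 0.5 and 0.2 into the neighborhood of 1 without changing the singular vectors.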
### Dependency Graph

```
matmul_transpose_triton.py (leaf)
        ↓
newton_schulz.py (leaf + triton)
        ↓
core.py ──── qk_clip.py (leaf, distributed/utils)
  ↓       ↓        ↓
  ↓   pipeline.py ←── async_utils.py
  ↓       ↓
  ↓    adamw.py
  ↓       ↓
muon.py (all above)
        ↓
__init__.py
```