# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Optimizer is a PyTorch package implementing the **Muon optimizer** with support for N-D sharding parallelism for large-scale distributed training. Based on the paper at https://arxiv.org/abs/2511.07464. It supports general N-D sharding configurations (FSDP2 through hybrid setups like 2 TP + 2 DP-Replicate + 2 DP-Shard).

## Commands

### Lint & Format

```bash
pre-commit run --all-files          # Run all pre-commit hooks
pre-commit run isort --all-files    # Run a specific hook (e.g., isort)
```

Hooks: yapf (Python formatter), isort (import sorter), typos (spell checker), clang-format (C++/CUDA), pymarkdown (Markdown linter), actionlint (GitHub Actions).

### Tests

Tests require **8 GPUs**, access to `Motif-Technologies/Motif-2.6B-4layer-random` on HuggingFace (`HF_TOKEN` env var), and PyTorch >= 2.8.0.

```bash
cd test && ./run_test.sh
# Equivalent to:
cd test && torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon.py
```

Useful pytest flags: `--measure-perf` (timing/memory), `--do-profile` (profiling, requires `--measure-perf`), `--skip-verify` (skip correctness check against sequential implementation).

### Build

Uses kernel-builder infrastructure (`build.toml`, `flake.nix`). Pre-built binaries for various PyTorch/CUDA/ROCm combinations are stored in `build/`.

### Commit Convention

**Always append `[skip-build]` to every commit message.** This prevents CI from triggering unnecessary build jobs on development branches.

## Architecture

### Source Layout

```
torch-ext/optimizer/
├── __init__.py                  # Public API: exports Muon
├── muon.py                      # Muon optimizer class (~430 lines)
├── newton_schulz.py             # Newton-Schulz iteration (~50 lines)
├── qk_clip.py                   # QK clipping for attention heads (~130 lines)
├── core.py                      # Shared state, helpers, param grouping (~110 lines)
├── pipeline.py                  # Async generator pipeline for parallel mode (~290 lines)
├── async_utils.py               # AsyncTask / AsyncRuntime scheduling (~75 lines)
├── adamw.py                     # Fused AdamW for non-Muon parameters (~160 lines)
├── matmul_transpose_triton.py   # Triton kernel for X @ X.T (~130 lines)
└── distributed/
    └── utils.py                 # Shard mesh construction, DTensor slicing (~175 lines)
```

### Optimizer Modes

The `Muon` optimizer has three execution paths selected per-parameter based on its tensor type and mesh structure:

1. **Base mode** (`base()`): Single-device / non-sharded tensors. Standard Muon with Newton-Schulz orthogonalization.
2. **Distributed mode** (`distributed_muon()`): Gathers full tensors via all-gather, computes updates, redistributes. Used for small parameters or fallback.
3. **Parallel mode** (`parallel()`): Pipelined all2all communication overlapped with compute. Uses an async generator pipeline scheduled by `run_pipeline()`. This is the main advanced feature.

### Parallel Mode Pipeline

The parallel pipeline is implemented as a single generator function `muon_chunk_pipeline()` in `pipeline.py`. Parameters are split into chunks, and each chunk flows through:

```
build bufs + async all2all_gather
  → yield
  → wait + Newton-Schulz compute + async all2all_scatter
  → yield
  → wait + update_param
```

The generator yields 2 times (after launching async gather and async scatter via `async_op=True`), allowing `run_pipeline()` to interleave multiple chunks for communication overlap. `work.wait()` completes each async operation after the yield. `warmup_step` maps to `max_concurrent_tasks = warmup_step + 1` in `run_pipeline()`.

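The interleaving described above can be illustrated with a toy, dependency-free sketch. It is not the code in `pipeline.py` or `async_utils.py`: `chunk_pipeline` and `run_bounded` are hypothetical names, the stages only append to a log instead of launching collectives, and the round-robin admission policy is an assumption made for illustration. Only the two yield points and the `warmup_step + 1` bound come from the description above.

```python
from collections import deque


def chunk_pipeline(chunk_id, log):
    """Toy stand-in for one chunk: three stages separated by two yields."""
    log.append((chunk_id, "launch_gather"))               # async gather would start here
    yield                                                 # let other chunks run meanwhile
    log.append((chunk_id, "compute_and_launch_scatter"))  # wait, compute, start scatter
    yield                                                 # let other chunks run meanwhile
    log.append((chunk_id, "update_param"))                # wait, then apply the update


def run_bounded(pipelines, max_concurrent_tasks):
    """Advance generators round-robin, keeping at most max_concurrent_tasks alive."""
    pending = deque(pipelines)
    active = deque()
    while pending or active:
        # Admit at most one new pipeline per scheduling step, if there is room.
        if pending and len(active) < max_concurrent_tasks:
            active.append(pending.popleft())
        gen = active.popleft()
        try:
            next(gen)           # run this chunk up to its next yield
            active.append(gen)  # still unfinished: requeue at the back
        except StopIteration:
            pass                # chunk finished; its slot is free again


# Example: warmup_step = 1 allows two chunks in flight at once.
warmup_step = 1
log = []
run_bounded([chunk_pipeline(i, log) for i in range(3)],
            max_concurrent_tasks=warmup_step + 1)
for entry in log:
    print(entry)
```

With a bound of 2, the second chunk's gather is launched before the first chunk's parameter update, which is the overlap the pipeline is after. With a bound of 1 the chunks run strictly one after another.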
For detailed implementation documentation (pipeline internals, distributed utilities, QK clipping with strided sharding, etc.), see [`docs/implementation.md`](docs/implementation.md).

### Key Abstractions

- **`get_default_muon_param_groups(model, is_muon_func)`** (`core.py`): Separates parameters into Muon-optimizable (2D+) and AdamW groups. Skips embeddings and output layers by default.
- **`_muon_state` dataclass** (`core.py`): Per-parameter config: rank ownership (`worker_rank`), process group, precomputed shard indices (`rank_indices`, `rank_numels`), and optional QK clip state. Config-only; no transient pipeline state.
- **`muon_chunk_pipeline()` generator** (`pipeline.py`): Processes one chunk through the full gather→compute→scatter→update pipeline. Uses `async_op=True` for non-blocking all-to-all and yields to allow chunk interleaving. All intermediate buffers are generator-local variables.
- **`run_pipeline()`** (`async_utils.py`): Generator-based pipeline scheduling with bounded concurrency. Interleaves multiple chunk pipelines at yield points.
- **`construct_shard_mesh()` / `get_slices_of_dtensor()`** (`distributed/utils.py`): Utilities for building shard meshes from DTensor placements and computing per-rank local slices. Handles both `Shard` and `_StridedShard` (PyTorch 2.10+).
- **Newton-Schulz iteration** (`newton_schulz.py`): `_zeropower_via_newtonschulz5()` runs 5 quintic iterations in bfloat16 with pre-optimized coefficients for gradient orthogonalization. Uses Triton kernel `matmul_transpose_assign` for efficient X @ X.T.
- **QK Clipping** (`qk_clip.py`): Optional dynamic clipping of attention head projections when QK logits exceed a threshold. Configured via `q_indices`, `k_indices`, `head_dim`, `threshold`.
- **Fused AdamW** (`adamw.py`): Uses PyTorch's `torch._fused_adamw_` for non-Muon parameters, grouping tensors by device/dtype and DTensor placement.

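To make the Newton-Schulz entry above more concrete, here is a minimal pure-Python sketch of a quintic iteration of that kind. It is not `_zeropower_via_newtonschulz5()`: the function and helper names are hypothetical, it uses nested lists and float64 rather than bfloat16 tensors and the Triton kernel, and the coefficients are an assumption taken from the commonly circulated Muon reference implementation. The package's own pre-optimized coefficients may differ.

```python
import math

# Assumption: coefficients from the widely used Muon reference implementation.
# The "pre-optimized" values used by this package may be different.
A_COEF, B_COEF, C_COEF = 3.4445, -4.7750, 2.0315


def matmul(x, y):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*y)] for row in x]


def transpose(x):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*x)]


def newton_schulz5(g, steps=5, eps=1e-7):
    """Approximately orthogonalize g with a quintic Newton-Schulz iteration."""
    # Work in the wide orientation so that X @ X.T is the smaller Gram matrix.
    transposed = len(g) > len(g[0])
    x = transpose(g) if transposed else [row[:] for row in g]
    # Scale by the Frobenius norm so every singular value is at most 1.
    norm = math.sqrt(sum(v * v for row in x for v in row)) + eps
    x = [[v / norm for v in row] for row in x]
    for _ in range(steps):
        a = matmul(x, transpose(x))  # A = X X^T
        aa = matmul(a, a)            # A^2
        # B = b*A + c*A^2
        b = [[B_COEF * p + C_COEF * q for p, q in zip(ra, raa)]
             for ra, raa in zip(a, aa)]
        bx = matmul(b, x)
        # X <- a*X + B X
        x = [[A_COEF * p + q for p, q in zip(rx, rbx)]
             for rx, rbx in zip(x, bx)]
    return transpose(x) if transposed else x


# Example: a diagonal matrix keeps its zero off-diagonal entries, while the
# diagonal entries (its singular values) are pushed toward 1.
result = newton_schulz5([[3.0, 0.0], [0.0, 4.0]])
for row in result:
    print(row)
```

With these assumed coefficients the singular values end up in a band around 1 rather than converging to exactly 1, which is generally considered acceptable for an optimizer update.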
### Dependency Graph

```
matmul_transpose_triton.py (leaf)
        │
newton_schulz.py (leaf + triton)
        │
core.py ──── qk_clip.py (leaf, distributed/utils)
  │   │      │
  │   pipeline.py ─── async_utils.py
  │   │
  │   adamw.py
  │   │
muon.py (all above)
  │
__init__.py
```