# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview
Optimizer is a PyTorch package implementing the Muon optimizer with support for N-D sharding parallelism for large-scale distributed training, based on the paper at https://arxiv.org/abs/2511.07464. It supports general N-D sharding configurations, from plain FSDP2 through hybrid setups such as 2 TP + 2 DP-Replicate + 2 DP-Shard.
## Commands

### Lint & Format

```bash
pre-commit run --all-files        # Run all pre-commit hooks
pre-commit run isort --all-files  # Run a specific hook (e.g., isort)
```

Hooks: yapf (Python formatter), isort (import sorter), typos (spell checker), clang-format (C++/CUDA), pymarkdown (Markdown linter), actionlint (GitHub Actions).
### Tests

Tests require 8 GPUs, access to Motif-Technologies/Motif-2.6B-4layer-random on HuggingFace (`HF_TOKEN` env var), and PyTorch >= 2.8.0.

```bash
cd test && ./run_test.sh
# Equivalent to:
cd test && torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon.py
```

Useful pytest flags: `--measure-perf` (timing/memory), `--do-profile` (profiling, requires `--measure-perf`), `--skip-verify` (skip the correctness check against the sequential implementation).
### Build

Uses the kernel-builder infrastructure (`build.toml`, `flake.nix`). Pre-built binaries for various PyTorch/CUDA/ROCm combinations are stored in `build/`.
## Commit Convention

Always append `[skip-build]` to every commit message. This prevents CI from triggering unnecessary build jobs on development branches.
## Architecture

### Source Layout
```text
torch-ext/optimizer/
├── __init__.py                 # Public API: exports Muon
├── muon.py                     # Muon optimizer class (~430 lines)
├── newton_schulz.py            # Newton-Schulz iteration (~50 lines)
├── qk_clip.py                  # QK clipping for attention heads (~130 lines)
├── core.py                     # Shared state, helpers, param grouping (~110 lines)
├── pipeline.py                 # Async generator pipeline for parallel mode (~290 lines)
├── async_utils.py              # AsyncTask / AsyncRuntime scheduling (~75 lines)
├── adamw.py                    # Fused AdamW for non-Muon parameters (~160 lines)
├── matmul_transpose_triton.py  # Triton kernel for X @ X.T (~130 lines)
└── distributed/
    └── utils.py                # Shard mesh construction, DTensor slicing (~175 lines)
```
### Optimizer Modes

The Muon optimizer has three execution paths, selected per parameter based on its tensor type and mesh structure:

- Base mode (`base()`) → Single-device / non-sharded tensors. Standard Muon with Newton-Schulz orthogonalization.
- Distributed mode (`distributed_muon()`) → Gathers full tensors via all-gather, computes updates, redistributes. Used for small parameters or as a fallback.
- Parallel mode (`parallel()`) → Pipelined all2all communication overlapped with compute. Uses an async generator pipeline scheduled by `run_pipeline()`. This is the main advanced feature.
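The selection logic above can be sketched as a small dispatch function. This is a hypothetical illustration, not the repository's code: the name `select_mode`, the duck-typed `placements` check, and the size threshold are all assumptions.

```python
def select_mode(param, small_numel: int = 1 << 20) -> str:
    """Hypothetical sketch of per-parameter mode dispatch.

    - plain (non-sharded) tensors take the base path;
    - small sharded tensors take the simple all-gather path;
    - large sharded tensors take the pipelined parallel path.
    """
    # DTensor instances carry a `placements` attribute; duck typing keeps
    # this sketch free of torch.distributed imports.
    is_sharded = getattr(param, "placements", None) is not None
    if not is_sharded:
        return "base"         # base(): single-device Muon
    if param.numel() <= small_numel:
        return "distributed"  # distributed_muon(): all-gather fallback
    return "parallel"         # parallel(): pipelined all2all
```

The actual threshold and the exact DTensor checks in `muon.py` may differ; the point is only that dispatch is a cheap per-parameter decision made once per step.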
### Parallel Mode Pipeline
The parallel pipeline is implemented as a single generator function, `muon_chunk_pipeline()` in `pipeline.py`. Parameters are split into chunks, and each chunk flows through:

```text
build bufs + async all2all_gather → yield → wait + Newton-Schulz compute + async all2all_scatter → yield → wait + update_param
```

The generator yields twice (after launching the async gather and async scatter via `async_op=True`), allowing `run_pipeline()` to interleave multiple chunks for communication overlap. `work.wait()` completes each async operation after the yield.

`warmup_step` maps to `max_concurrent_tasks = warmup_step + 1` in `run_pipeline()`.
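The scheduling idea can be sketched without any distributed machinery. Everything below is illustrative: `fake_chunk_pipeline` stands in for `muon_chunk_pipeline()`, and the deque-based scheduler is a minimal stand-in for `run_pipeline()`, not its actual implementation.

```python
from collections import deque

def fake_chunk_pipeline(chunk, log):
    """Stand-in for muon_chunk_pipeline(): the two yields mark the points
    where an async all2all_gather / all2all_scatter would be in flight."""
    log.append(f"{chunk}:gather")   # launch async gather (async_op=True)
    yield                           # other chunks run while gather is in flight
    log.append(f"{chunk}:compute")  # wait + Newton-Schulz + launch async scatter
    yield                           # other chunks run while scatter is in flight
    log.append(f"{chunk}:update")   # wait + update_param

def run_pipeline_sketch(gens, max_concurrent_tasks):
    """Round-robin up to max_concurrent_tasks generators at their yield points."""
    pending, active = deque(gens), deque()
    while pending or active:
        while pending and len(active) < max_concurrent_tasks:
            active.append(pending.popleft())
        gen = active.popleft()
        try:
            next(gen)          # run until the generator's next yield
            active.append(gen)
        except StopIteration:  # chunk finished update_param
            pass

log = []
run_pipeline_sketch([fake_chunk_pipeline(c, log) for c in "AB"],
                    max_concurrent_tasks=2)  # warmup_step=1 -> 2 concurrent chunks
```

Running this interleaves the two chunks (`A:gather, B:gather, A:compute, ...`), which is exactly the overlap the real pipeline exploits: chunk B's communication hides behind chunk A's compute.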
For detailed implementation documentation (pipeline internals, distributed utilities, QK clipping with strided sharding, etc.), see docs/implementation.md.
### Key Abstractions
- `get_default_muon_param_groups(model, is_muon_func)` (core.py) → Separates parameters into Muon-optimizable (2D+) and AdamW groups. Skips embeddings and output layers by default.
- `_muon_state` dataclass (core.py) → Per-parameter config: rank ownership (`worker_rank`), process group, precomputed shard indices (`rank_indices`, `rank_numels`), and optional QK clip state. Config-only; no transient pipeline state.
- `muon_chunk_pipeline()` generator (pipeline.py) → Processes one chunk through the full gather → compute → scatter → update pipeline. Uses `async_op=True` for non-blocking all-to-all and yields to allow chunk interleaving. All intermediate buffers are generator-local variables.
- `run_pipeline()` (async_utils.py) → Generator-based pipeline scheduling with bounded concurrency. Interleaves multiple chunk pipelines at yield points.
- `construct_shard_mesh()` / `get_slices_of_dtensor()` (distributed/utils.py) → Utilities for building shard meshes from DTensor placements and computing per-rank local slices. Handles both `Shard` and `_StridedShard` (PyTorch 2.10+).
- Newton-Schulz iteration (newton_schulz.py) → `_zeropower_via_newtonschulz5()`: 5 quintic iterations in bfloat16 with pre-optimized coefficients for gradient orthogonalization. Uses the Triton kernel `matmul_transpose_assign` for efficient X @ X.T.
- QK Clipping (qk_clip.py) → Optional dynamic clipping of attention head projections when QK logits exceed a threshold. Configured via `q_indices`, `k_indices`, `head_dim`, `threshold`.
- Fused AdamW (adamw.py) → Uses PyTorch's `torch._fused_adamw_` for non-Muon parameters, grouping tensors by device/dtype and DTensor placement.
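The quintic Newton-Schulz step is easy to sketch in plain NumPy. The coefficients below are the widely published Muon values; assume the repository's `_zeropower_via_newtonschulz5` differs at least in dtype (bfloat16) and in using the Triton `matmul_transpose_assign` kernel for X @ X.T, so this is a reference sketch rather than the actual implementation.

```python
import numpy as np

def newton_schulz5(G: np.ndarray, steps: int = 5) -> np.ndarray:
    """Drive the singular values of G toward 1 (approximate orthogonal
    factor) via the quintic Newton-Schulz iteration used by Muon."""
    a, b, c = 3.4445, -4.7750, 2.0315   # commonly published coefficients
    X = G / (np.linalg.norm(G) + 1e-7)  # spectral norm <= Frobenius norm <= 1
    if X.shape[0] > X.shape[1]:
        X = X.T                          # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.T                      # the X @ X.T the Triton kernel fuses
        B = b * A + c * (A @ A)
        X = a * X + B @ X                # quintic polynomial in the singular values
    if G.shape[0] > G.shape[1]:
        X = X.T
    return X

rng = np.random.default_rng(0)
Y = newton_schulz5(rng.standard_normal((8, 16)))
# Singular values of Y cluster near 1: rows are approximately orthonormal.
```

Five iterations do not produce an exact orthogonalization; the coefficients are tuned so singular values land in a band around 1, which is sufficient for the optimizer update.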
### Dependency Graph

```text
matmul_transpose_triton.py (leaf)
         │
newton_schulz.py (leaf + triton)
         │
core.py ──── qk_clip.py (leaf, distributed/utils)
  │      │          │
  │  pipeline.py ── async_utils.py
  │      │
  │  adamw.py
  │      │
muon.py (all above)
  │
__init__.py
```