# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
Optimizer is a PyTorch package implementing the **Muon optimizer** with support for N-D sharding parallelism for large-scale distributed training. It is based on the paper at https://arxiv.org/abs/2511.07464 and supports general N-D sharding configurations, from FSDP2 to hybrid setups such as 2 TP + 2 DP-Replicate + 2 DP-Shard.
## Commands
### Lint & Format
```bash
pre-commit run --all-files # Run all pre-commit hooks
pre-commit run isort --all-files # Run a specific hook (e.g., isort)
```
Hooks: yapf (Python formatter), isort (import sorter), typos (spell checker), clang-format (C++/CUDA), pymarkdown (Markdown linter), actionlint (GitHub Actions).
### Tests
Tests require **8 GPUs**, access to `Motif-Technologies/Motif-2.6B-4layer-random` on HuggingFace (`HF_TOKEN` env var), and PyTorch >= 2.8.0.
```bash
cd test && ./run_test.sh
# Equivalent to:
cd test && torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon.py
```
Useful pytest flags: `--measure-perf` (timing/memory), `--do-profile` (profiling, requires `--measure-perf`), `--skip-verify` (skip correctness check against sequential implementation).
### Build
Uses kernel-builder infrastructure (`build.toml`, `flake.nix`). Pre-built binaries for various PyTorch/CUDA/ROCm combinations are stored in `build/`.
### Commit Convention
**Always append `[skip-build]` to every commit message.** This prevents CI from triggering unnecessary build jobs on development branches.
## Architecture
### Source Layout
```
torch-ext/optimizer/
├── __init__.py # Public API: exports Muon
├── muon.py # Muon optimizer class (~430 lines)
├── newton_schulz.py # Newton-Schulz iteration (~50 lines)
├── qk_clip.py # QK clipping for attention heads (~130 lines)
├── core.py # Shared state, helpers, param grouping (~110 lines)
├── pipeline.py # Async generator pipeline for parallel mode (~290 lines)
├── async_utils.py # AsyncTask / AsyncRuntime scheduling (~75 lines)
├── adamw.py # Fused AdamW for non-Muon parameters (~160 lines)
├── matmul_transpose_triton.py # Triton kernel for X @ X.T (~130 lines)
└── distributed/
    └── utils.py # Shard mesh construction, DTensor slicing (~175 lines)
```
### Optimizer Modes
The `Muon` optimizer has three execution paths, selected per parameter based on the parameter's tensor type and mesh structure:
1. **Base mode** (`base()`): Single-device / non-sharded tensors. Standard Muon with Newton-Schulz orthogonalization.
2. **Distributed mode** (`distributed_muon()`): Gathers full tensors via all-gather, computes updates, and redistributes them. Used for small parameters or as a fallback.
3. **Parallel mode** (`parallel()`): Pipelined all2all communication overlapped with compute. Uses an async generator pipeline scheduled by `run_pipeline()`. This is the main advanced feature.
### Parallel Mode Pipeline
The parallel pipeline is implemented as a single generator function `muon_chunk_pipeline()` in `pipeline.py`. Parameters are split into chunks, and each chunk flows through:
```
build bufs + async all2all_gather → yield → wait + Newton-Schulz compute + async all2all_scatter → yield → wait + update_param
```
The generator yields twice (after launching the async gather and the async scatter with `async_op=True`), allowing `run_pipeline()` to interleave multiple chunks for communication overlap. After each yield, `work.wait()` completes the corresponding async operation.
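The two-yield structure can be illustrated with a minimal pure-Python sketch. It is not the code in `pipeline.py`: `FakeWork` is a hypothetical stand-in for the handle a collective returns with `async_op=True`, the stage bodies only append to a log, and the simple round-robin driver `run_round_robin` is an assumed stand-in for `run_pipeline()`. What it shows is the ordering: both chunks launch their gather before either one waits on it.

```python
from typing import Generator, List


class FakeWork:
    """Hypothetical stand-in for the handle returned by a collective with async_op=True."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def wait(self) -> None:
        # A real handle would block here until the communication finishes.
        self.log.append(f"wait {self.name}")


def chunk_pipeline(chunk: int, log: List[str]) -> Generator[None, None, None]:
    # Stage 1: build buffers and launch the (simulated) async all-to-all gather.
    log.append(f"chunk{chunk}: launch gather")
    gather_work = FakeWork(f"chunk{chunk} gather", log)
    yield  # let other chunks launch their communication meanwhile

    # Stage 2: finish the gather, orthogonalize, launch the async scatter.
    gather_work.wait()
    log.append(f"chunk{chunk}: newton-schulz")
    log.append(f"chunk{chunk}: launch scatter")
    scatter_work = FakeWork(f"chunk{chunk} scatter", log)
    yield

    # Stage 3: finish the scatter and apply the update.
    scatter_work.wait()
    log.append(f"chunk{chunk}: update_param")


def run_round_robin(num_chunks: int) -> List[str]:
    """Advance every live generator by one stage per round until all finish."""
    log: List[str] = []
    active = [chunk_pipeline(i, log) for i in range(num_chunks)]
    while active:
        for gen in list(active):
            try:
                next(gen)
            except StopIteration:
                active.remove(gen)
    return log


print(run_round_robin(2)[:3])
# ['chunk0: launch gather', 'chunk1: launch gather', 'wait chunk0 gather']
```

Because each `yield` sits right after an async launch, the communication of one chunk is in flight while the scheduler advances another chunk.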
`warmup_step` maps to `max_concurrent_tasks = warmup_step + 1` in `run_pipeline()`.
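A hedged sketch of bounded-concurrency scheduling and the `warmup_step` mapping follows. `run_bounded` and `chunk_task` are hypothetical names, and the real `run_pipeline()` may admit and step tasks differently; this only demonstrates the `max_concurrent_tasks = warmup_step + 1` relationship. Each task is a generator with two yields, mirroring the gather and scatter waits.

```python
from collections import deque
from typing import Deque, Generator, List, Tuple

Event = Tuple[str, int]
Task = Generator[None, None, None]


def chunk_task(chunk: int, trace: List[Event]) -> Task:
    trace.append(("start", chunk))
    yield  # async gather in flight
    yield  # async scatter in flight
    trace.append(("done", chunk))


def run_bounded(tasks: List[Task], warmup_step: int) -> int:
    """Round-robin over at most warmup_step + 1 live tasks; return the peak count."""
    max_concurrent_tasks = warmup_step + 1
    pending: Deque[Task] = deque(tasks)
    running: List[Task] = []
    peak = 0
    while pending or running:
        # Admit queued chunks until the concurrency bound is reached.
        while pending and len(running) < max_concurrent_tasks:
            running.append(pending.popleft())
        peak = max(peak, len(running))
        # Advance each running chunk to its next yield point.
        for task in list(running):
            try:
                next(task)
            except StopIteration:
                running.remove(task)
    return peak


trace: List[Event] = []
print(run_bounded([chunk_task(i, trace) for i in range(4)], warmup_step=1))  # 2
```

With `warmup_step = 0` only one chunk is live at a time, so execution is fully sequential; each increment of `warmup_step` lets one more chunk overlap.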
For detailed implementation documentation (pipeline internals, distributed utilities, QK clipping with strided sharding, etc.), see [`docs/implementation.md`](docs/implementation.md).
### Key Abstractions
- **`get_default_muon_param_groups(model, is_muon_func)`** (`core.py`): Separates parameters into Muon-optimizable (2D+) and AdamW groups. Skips embeddings and output layers by default.
- **`_muon_state` dataclass** (`core.py`): Per-parameter config: rank ownership (`worker_rank`), process group, precomputed shard indices (`rank_indices`, `rank_numels`), and optional QK clip state. Config-only; no transient pipeline state.
- **`muon_chunk_pipeline()` generator** (`pipeline.py`): Processes one chunk through the full gather→compute→scatter→update pipeline. Uses `async_op=True` for non-blocking all-to-all and yields to allow chunk interleaving. All intermediate buffers are generator-local variables.
- **`run_pipeline()`** (`async_utils.py`): Generator-based pipeline scheduling with bounded concurrency. Interleaves multiple chunk pipelines at yield points.
- **`construct_shard_mesh()` / `get_slices_of_dtensor()`** (`distributed/utils.py`): Utilities for building shard meshes from DTensor placements and computing per-rank local slices. Handles both `Shard` and `_StridedShard` (PyTorch 2.10+).
- **Newton-Schulz iteration** (`newton_schulz.py`): `_zeropower_via_newtonschulz5()` runs 5 quintic iterations in bfloat16 with pre-optimized coefficients for gradient orthogonalization. Uses the Triton kernel `matmul_transpose_assign` for efficient X @ X.T.
- **QK Clipping** (`qk_clip.py`): Optional dynamic clipping of attention head projections when QK logits exceed a threshold. Configured via `q_indices`, `k_indices`, `head_dim`, `threshold`.
- **Fused AdamW** (`adamw.py`): Uses PyTorch's `torch._fused_adamw_` for non-Muon parameters, grouping tensors by device/dtype and DTensor placement.
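The Newton-Schulz entry can be made concrete with a dependency-free sketch of the quintic iteration `X = a*X + b*(X @ X.T) @ X + c*(X @ X.T) @ (X @ X.T) @ X`. Assumptions, stated plainly: the coefficients `(3.4445, -4.7750, 2.0315)` are the ones published with the reference Muon implementation and may not match this repo's pre-optimized values; the sketch runs in float64 on nested lists rather than bfloat16 tensors; and a plain `X @ X.T` product stands in for the Triton `matmul_transpose_assign` kernel. `zeropower_newton_schulz5` is a hypothetical name, not `_zeropower_via_newtonschulz5()` itself.

```python
from typing import List

Matrix = List[List[float]]

# Coefficients from the public reference Muon implementation (assumption:
# this repository's "pre-optimized coefficients" may differ).
A, B, C = 3.4445, -4.7750, 2.0315


def transpose(x: Matrix) -> Matrix:
    return [list(col) for col in zip(*x)]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    cols = list(zip(*y))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in x]


def axpby(alpha: float, x: Matrix, beta: float, y: Matrix) -> Matrix:
    """Elementwise alpha * x + beta * y."""
    return [[alpha * a + beta * b for a, b in zip(rx, ry)] for rx, ry in zip(x, y)]


def zeropower_newton_schulz5(g: Matrix, steps: int = 5, eps: float = 1e-7) -> Matrix:
    """Approximately orthogonalize g (push its singular values toward 1)."""
    # Work on the wide orientation so the Gram matrix X @ X.T is the smaller one.
    transposed = len(g) > len(g[0])
    x = transpose(g) if transposed else [row[:] for row in g]
    # Normalize by the Frobenius norm so every singular value starts at or below 1.
    norm = sum(v * v for row in x for v in row) ** 0.5
    x = [[v / (norm + eps) for v in row] for row in x]
    for _ in range(steps):
        gram = matmul(x, transpose(x))                # G = X @ X.T
        poly = axpby(B, gram, C, matmul(gram, gram))  # b*G + c*G@G
        x = axpby(A, x, 1.0, matmul(poly, x))         # a*X + (b*G + c*G@G) @ X
    return transpose(x) if transposed else x


out = zeropower_newton_schulz5([[3.0, 0.0], [0.0, 4.0]])
print([round(out[i][i], 2) for i in range(2)])
```

On a diagonal input each diagonal entry evolves independently under the scalar map `s -> a*s + b*s**3 + c*s**5`, so the result lands near 1 rather than exactly at 1; the quintic coefficients trade exact orthogonality for fast convergence in few steps.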
### Dependency Graph
```
matmul_transpose_triton.py (leaf)
│
newton_schulz.py (leaf + triton)
│
core.py ──── qk_clip.py (leaf, distributed/utils)
│ │ │
│ pipeline.py ─── async_utils.py
│ │
│ adamw.py
│ │
muon.py (all above)
│
__init__.py
```