
CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

Project Overview

Optimizer is a PyTorch package implementing the Muon optimizer with support for N-D sharding parallelism in large-scale distributed training. It is based on the paper at https://arxiv.org/abs/2511.07464 and supports general N-D sharding configurations, from plain FSDP2 to hybrid setups such as 2 TP + 2 DP-Replicate + 2 DP-Shard.

Commands

Lint & Format

pre-commit run --all-files          # Run all pre-commit hooks
pre-commit run isort --all-files    # Run a specific hook (e.g., isort)

Hooks: yapf (Python formatter), isort (import sorter), typos (spell checker), clang-format (C++/CUDA), pymarkdown (Markdown linter), actionlint (GitHub Actions).

Tests

Tests require 8 GPUs, access to Motif-Technologies/Motif-2.6B-4layer-random on Hugging Face (via the HF_TOKEN environment variable), and PyTorch >= 2.8.0.

cd test && ./run_test.sh
# Equivalent to:
cd test && torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon.py

Useful pytest flags: --measure-perf (measure timing and memory), --do-profile (enable profiling; requires --measure-perf), --skip-verify (skip the correctness check against the sequential implementation).

Build

Uses kernel-builder infrastructure (build.toml, flake.nix). Pre-built binaries for various PyTorch/CUDA/ROCm combinations are stored in build/.

Commit Convention

Always append [skip-build] to every commit message. This prevents CI from triggering unnecessary build jobs on development branches.
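If you want this convention enforced locally, a Git commit-msg hook can append the tag for you. The sketch below is hypothetical and not part of the repository. It assumes you save it as .git/hooks/commit-msg and make it executable. Git passes the path of the proposed message file as the hook's single argument. The hook leaves messages that already contain the tag unchanged and ignores lines starting with `#`, Git's default comment character.

```python
#!/usr/bin/env python3
"""Hypothetical Git commit-msg hook: append [skip-build] to the commit message."""
import sys

TAG = "[skip-build]"


def append_skip_build(message: str) -> str:
    """Return `message` with TAG appended to its last non-comment line."""
    lines = message.split("\n")
    # Indices of lines Git keeps: non-empty and not starting with the '#' comment char.
    content = [i for i, line in enumerate(lines)
               if line.strip() and not line.startswith("#")]
    if not content or any(TAG in lines[i] for i in content):
        return message  # empty message (Git aborts anyway) or already tagged
    last = content[-1]
    lines[last] = lines[last].rstrip() + " " + TAG
    return "\n".join(lines)


def main(path: str) -> int:
    with open(path, encoding="utf-8") as f:
        message = f.read()
    with open(path, "w", encoding="utf-8") as f:
        f.write(append_skip_build(message))
    return 0


# Git invokes the hook with exactly one argument: the message file path.
if __name__ == "__main__" and len(sys.argv) == 2:
    sys.exit(main(sys.argv[1]))
```

Because the function is idempotent, amending a commit whose message already carries the tag does not duplicate it.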

Architecture

Source Layout

torch-ext/optimizer/
├── __init__.py                    # Public API: exports Muon
├── muon.py                        # Muon optimizer class (~430 lines)
├── newton_schulz.py               # Newton-Schulz iteration (~50 lines)
├── qk_clip.py                     # QK clipping for attention heads (~130 lines)
├── core.py                        # Shared state, helpers, param grouping (~110 lines)
├── pipeline.py                    # Async generator pipeline for parallel mode (~290 lines)
├── async_utils.py                 # AsyncTask / AsyncRuntime scheduling (~75 lines)
├── adamw.py                       # Fused AdamW for non-Muon parameters (~160 lines)
├── matmul_transpose_triton.py     # Triton kernel for X @ X.T (~130 lines)
└── distributed/
    └── utils.py                   # Shard mesh construction, DTensor slicing (~175 lines)

Optimizer Modes

The Muon optimizer has three execution paths, selected per parameter based on the parameter's tensor type and mesh structure:

  1. Base mode (base()) - Single-device or non-sharded tensors. Standard Muon with Newton-Schulz orthogonalization.
  2. Distributed mode (distributed_muon()) - Gathers full tensors via all-gather, computes updates, then redistributes the results. Used for small parameters or as a fallback.
  3. Parallel mode (parallel()) - Pipelined all2all communication overlapped with compute, using an async generator pipeline scheduled by run_pipeline(). This is the main advanced feature.
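The per-parameter choice among these three paths can be pictured as a small decision function. This is a hypothetical sketch, not the repository's selection logic. The enum, the `parallel_supported` flag, and the `small_numel_threshold` cutoff are illustrative assumptions. Here `parallel_supported` stands in for "the mesh and placements are ones the parallel path can handle".

```python
from enum import Enum


class MuonPath(Enum):
    """The three execution paths named in this document."""
    BASE = "base"
    DISTRIBUTED = "distributed_muon"
    PARALLEL = "parallel"


def select_path(is_dtensor: bool, numel: int, parallel_supported: bool,
                small_numel_threshold: int = 1 << 16) -> MuonPath:
    """Hypothetical per-parameter dispatch; the flag and threshold are illustrative."""
    if not is_dtensor:
        return MuonPath.BASE          # single-device / non-sharded tensor
    if numel < small_numel_threshold or not parallel_supported:
        return MuonPath.DISTRIBUTED   # small parameter, or fallback
    return MuonPath.PARALLEL          # pipelined all2all path


print(select_path(is_dtensor=True, numel=4096 * 4096, parallel_supported=True))
```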

Parallel Mode Pipeline

The parallel pipeline is implemented as a single generator function muon_chunk_pipeline() in pipeline.py. Parameters are split into chunks, and each chunk flows through:

build bufs + async all2all_gather → yield → wait + Newton-Schulz compute + async all2all_scatter → yield → wait + update_param

The generator yields twice: once after launching the async gather and once after launching the async scatter, both issued with async_op=True. This lets run_pipeline() interleave multiple chunks so that communication overlaps with compute. After each yield, work.wait() completes the pending async operation.
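The following minimal sketch makes the yield structure concrete with a three-stage chunk generator. `FakeWork` is a hypothetical stand-in for the handle that a torch.distributed collective returns when called with `async_op=True`; only its `wait()` method is modeled. The stages log strings instead of moving tensors. Driving two chunks round-robin shows that the second chunk's gather is launched before the first chunk's gather is waited on.

```python
class FakeWork:
    """Hypothetical stand-in for the handle returned by a collective with async_op=True."""

    def __init__(self, log: list, name: str):
        self.log, self.name = log, name
        log.append(f"launch {name}")  # the collective is issued immediately

    def wait(self) -> None:
        self.log.append(f"wait {self.name}")  # block until it completes


def chunk_pipeline_sketch(chunk_id: int, log: list):
    # Stage 1: build buffers and launch the async gather.
    work = FakeWork(log, f"gather[{chunk_id}]")
    yield
    # Stage 2: finish the gather, run Newton-Schulz, launch the async scatter.
    work.wait()
    log.append(f"newton_schulz[{chunk_id}]")
    work = FakeWork(log, f"scatter[{chunk_id}]")
    yield
    # Stage 3: finish the scatter and apply the update.
    work.wait()
    log.append(f"update_param[{chunk_id}]")


def drive_round_robin(generators) -> None:
    """Advance every generator one step per round until all are exhausted."""
    active = list(generators)
    while active:
        still_running = []
        for gen in active:
            try:
                next(gen)
                still_running.append(gen)
            except StopIteration:
                pass
        active = still_running


log: list = []
drive_round_robin(chunk_pipeline_sketch(i, log) for i in range(2))
# log begins: "launch gather[0]", "launch gather[1]", "wait gather[0]", ...
```

Each stage's buffers live as local variables of its generator, so no shared per-parameter state is needed between stages.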

The warmup_step setting maps to max_concurrent_tasks = warmup_step + 1 in run_pipeline(). This bounds how many chunk pipelines are in flight at once.
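A bounded round-robin scheduler illustrates the effect of warmup_step. `run_pipeline_sketch` below is a hypothetical stand-in, not the actual run_pipeline() from async_utils.py. It keeps at most `max_concurrent_tasks` generators in flight and advances them in turn. Whenever a generator finishes, it admits a new one.

```python
from collections import deque


def run_pipeline_sketch(pipelines, max_concurrent_tasks: int) -> None:
    """Hypothetical scheduler: step generators round-robin, at most N alive at once."""
    if max_concurrent_tasks < 1:
        raise ValueError("max_concurrent_tasks must be >= 1")
    pending = iter(pipelines)
    running: deque = deque()

    def admit() -> None:
        # Start new pipelines until the concurrency bound is reached.
        while len(running) < max_concurrent_tasks:
            gen = next(pending, None)
            if gen is None:
                return
            running.append(gen)

    admit()
    while running:
        gen = running.popleft()
        try:
            next(gen)            # run the pipeline up to its next yield
        except StopIteration:
            admit()              # a slot freed up
            continue
        running.append(gen)      # re-queue behind the others


def stage_pipeline(chunk_id: int, log: list):
    log.append(("gather", chunk_id))
    yield
    log.append(("compute+scatter", chunk_id))
    yield
    log.append(("update", chunk_id))


def trace(num_chunks: int, warmup_step: int) -> list:
    log: list = []
    chunks = (stage_pipeline(i, log) for i in range(num_chunks))
    run_pipeline_sketch(chunks, max_concurrent_tasks=warmup_step + 1)
    return log


print(trace(3, warmup_step=1))
```

With warmup_step=0, the trace is strictly sequential. With warmup_step=1, chunk 1's gather is issued before chunk 0 reaches its compute stage, which mirrors the overlap the real pipeline relies on.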

For detailed implementation documentation (pipeline internals, distributed utilities, QK clipping with strided sharding, etc.), see docs/implementation.md.

Key Abstractions

  • get_default_muon_param_groups(model, is_muon_func) (core.py) - Separates parameters into Muon-optimizable (2D+) and AdamW groups. Skips embeddings and output layers by default.
  • _muon_state dataclass (core.py) - Per-parameter config: rank ownership (worker_rank), process group, precomputed shard indices (rank_indices, rank_numels), and optional QK clip state. Config-only; no transient pipeline state.
  • muon_chunk_pipeline() generator (pipeline.py) - Processes one chunk through the full gather→compute→scatter→update pipeline. Uses async_op=True for non-blocking all-to-all and yields to allow chunk interleaving. All intermediate buffers are generator-local variables.
  • run_pipeline() (async_utils.py) - Generator-based pipeline scheduling with bounded concurrency. Interleaves multiple chunk pipelines at yield points.
  • construct_shard_mesh() / get_slices_of_dtensor() (distributed/utils.py) - Utilities for building shard meshes from DTensor placements and computing per-rank local slices. Handles both Shard and _StridedShard (PyTorch 2.10+).
  • Newton-Schulz iteration (newton_schulz.py) - _zeropower_via_newtonschulz5(): 5 quintic iterations in bfloat16 with pre-optimized coefficients for gradient orthogonalization. Uses the Triton kernel matmul_transpose_assign for efficient X @ X.T.
  • QK Clipping (qk_clip.py) - Optional dynamic clipping of attention head projections when QK logits exceed a threshold. Configured via q_indices, k_indices, head_dim, threshold.
  • Fused AdamW (adamw.py) - Uses PyTorch's torch._fused_adamw_ for non-Muon parameters, grouping tensors by device/dtype and DTensor placement.
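The grouping rule described for get_default_muon_param_groups() can be sketched as follows. The sketch assumes that embeddings and output layers are recognized by parameter name and routed to the AdamW group. The substrings "embed" and "lm_head", the `(name, param)` predicate signature, and the "use_muon" group key are hypothetical choices made for illustration.

```python
import torch
from torch import nn


def is_muon_param_sketch(name: str, param: torch.Tensor) -> bool:
    # Hypothetical default: 2D+ tensors, excluding embeddings and the output head.
    # The "embed" / "lm_head" substrings are naming assumptions.
    return param.ndim >= 2 and "embed" not in name and "lm_head" not in name


def split_param_groups_sketch(model: nn.Module, is_muon_func=is_muon_param_sketch):
    muon, adamw = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (muon if is_muon_func(name, param) else adamw).append(param)
    # "use_muon" is an illustrative key, not necessarily the repository's.
    return [{"params": muon, "use_muon": True},
            {"params": adamw, "use_muon": False}]


class TinyLM(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)
        self.proj = nn.Linear(4, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)


model = TinyLM()
muon_group, adamw_group = split_param_groups_sketch(model)
# muon_group holds only proj.weight; the embedding, the bias and lm_head go to AdamW.
```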
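For plain Shard placements, torch.chunk-style arithmetic reproduces the per-rank slice that get_slices_of_dtensor() computes:

- Each piece has ceil(size / n) elements.
- Trailing ranks may get shorter or empty pieces.
- When several mesh dimensions shard the same tensor dimension, the splits nest in mesh-dimension order, as DTensor does for default Shard placements.

The helpers below are a hypothetical pure-Python sketch of that arithmetic. They do not handle _StridedShard, which the real utility also supports.

```python
# Hypothetical helpers: plain Shard placements only (no _StridedShard).


def chunk_bounds(size: int, num_chunks: int, index: int) -> tuple[int, int]:
    """[start, stop) of piece `index` when `size` is split like torch.chunk."""
    chunk = -(-size // num_chunks)        # ceil(size / num_chunks)
    start = min(index * chunk, size)
    stop = min(start + chunk, size)       # trailing pieces may be short or empty
    return start, stop


def local_slice(dim_size: int, shard_layout: list[tuple[int, int]]) -> slice:
    """Slice of one tensor dim owned by a rank.

    shard_layout lists (mesh_dim_size, rank_coordinate) for every mesh dim that
    applies a plain Shard to this tensor dim, in mesh-dim order.
    """
    start, stop = 0, dim_size
    for num_shards, coord in shard_layout:
        lo, hi = chunk_bounds(stop - start, num_shards, coord)
        start, stop = start + lo, start + hi
    return slice(start, stop)


# 10 rows over 4 ranks -> sizes 3, 3, 3, 1
print([local_slice(10, [(4, r)]) for r in range(4)])
# Shard(0) on two 2-way mesh dims: rank coordinate (1, 0) owns rows 4..5 of 8
print(local_slice(8, [(2, 1), (2, 0)]))
```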
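The quintic Newton-Schulz iteration can be sketched in plain PyTorch. The function name below is hypothetical. It is not the repository's _zeropower_via_newtonschulz5(). The sketch assumes the coefficients (3.4445, -4.7750, 2.0315) published with the original Muon reference implementation; the repository's pre-optimized coefficients may differ. Where the package calls the Triton matmul_transpose_assign kernel for X @ X.T, the sketch uses an ordinary matmul. The sketch first normalizes the input by its Frobenius norm, so every singular value starts at most 1. It transposes tall matrices so that the Gram matrix X @ X.T stays small.

```python
import torch


def zeropower_via_newtonschulz5_sketch(G: torch.Tensor, steps: int = 5,
                                       eps: float = 1e-7) -> torch.Tensor:
    """Approximately orthogonalize a 2D matrix with a quintic Newton-Schulz iteration."""
    assert G.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315   # assumed coefficients (Muon reference impl.)
    X = G.to(torch.bfloat16)
    transposed = X.size(0) > X.size(1)
    if transposed:                      # use the wide orientation so X @ X.T is small
        X = X.T
    X = X / (X.norm() + eps)            # Frobenius norm bounds the spectral norm
    for _ in range(steps):
        A = X @ X.T                     # the product the Triton kernel accelerates
        B = b * A + c * (A @ A)
        X = a * X + B @ X               # X <- aX + b(XX^T)X + c(XX^T)^2 X
    return X.T if transposed else X


torch.manual_seed(0)
G = torch.randn(64, 128)
ortho = zeropower_via_newtonschulz5_sketch(G)
print(torch.linalg.svdvals(ortho.float()))
```

The result is only approximately orthogonal. With these coefficients, the singular values settle in a band around 1 (roughly 0.7 to 1.2) rather than converging to exactly 1.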
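Below is one plausible form of the clipping rule. It assumes a QK-Clip-style rescaling like the one described for MuonClip in the Kimi K2 report. For every head whose maximum observed logit exceeds `threshold`, the sketch multiplies that head's query and key projection rows by sqrt(threshold / max_logit). This scales the head's q·k logits by threshold / max_logit, bringing the maximum down to `threshold`. Several details are assumptions: the weight layout (num_heads * head_dim rows), equal query and key head counts, and the `max_logits` input. The repository's use of q_indices and k_indices is not reproduced here.

```python
import torch


@torch.no_grad()
def qk_clip_sketch(w_q: torch.Tensor, w_k: torch.Tensor, max_logits: torch.Tensor,
                   head_dim: int, threshold: float) -> torch.Tensor:
    """Rescale query/key projection rows of heads whose max logit exceeds threshold.

    Assumed layout: w_q and w_k are (num_heads * head_dim, hidden), with the same
    number of heads; max_logits is (num_heads,) of observed max pre-softmax logits.
    """
    num_heads = max_logits.numel()
    max_logits = max_logits.float()
    exceeded = max_logits > threshold
    # Logit scale per head: threshold / max_logit where exceeded, else 1.
    ratio = torch.where(exceeded, threshold / max_logits, torch.ones_like(max_logits))
    # Split it evenly between W_q and W_k so q.k is scaled by exactly `ratio`.
    scale = ratio.sqrt()
    for w in (w_q, w_k):
        w.view(num_heads, head_dim, -1).mul_(scale.view(num_heads, 1, 1).to(w.dtype))
    return scale


w_q = torch.ones(2 * 4, 3)
w_k = torch.ones(2 * 4, 3)
qk_clip_sketch(w_q, w_k, torch.tensor([50.0, 200.0]), head_dim=4, threshold=100.0)
# head 0 is untouched; head 1's q.k products are halved (200 -> 100)
```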
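Fused multi-tensor kernels operate on lists of tensors that generally need to share a device and dtype, which is why the optimizer buckets parameters before the call. The sketch below shows only the bucketing and deliberately does not call torch._fused_adamw_. The key layout, including reading `placements` off DTensors, is an assumption.

```python
from collections import defaultdict

import torch


def bucket_for_fused_adamw_sketch(params):
    """Group tensors by (device, dtype, placements) so each bucket can go to one fused call."""
    buckets = defaultdict(list)
    for p in params:
        # Assumed key layout: DTensors expose `placements`; plain tensors do not (None).
        placements = getattr(p, "placements", None)
        key = (p.device, p.dtype, tuple(placements) if placements is not None else None)
        buckets[key].append(p)
    return dict(buckets)


params = [torch.zeros(4), torch.zeros(2, 2), torch.zeros(3, dtype=torch.bfloat16)]
buckets = bucket_for_fused_adamw_sketch(params)
# two buckets on CPU: float32 (2 tensors) and bfloat16 (1 tensor)
```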

Dependency Graph

matmul_transpose_triton.py       (leaf)
         │
    newton_schulz.py              (leaf + triton)
         │
      core.py ──── qk_clip.py    (leaf, distributed/utils)
       │    │         │
       │  pipeline.py ─── async_utils.py
       │       │
       │   adamw.py
       │       │
      muon.py                     (all above)
         │
    __init__.py