# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Optimizer is a PyTorch package implementing the **Muon optimizer** with N-D sharding parallelism for large-scale distributed training, based on the paper at https://arxiv.org/abs/2511.07464. It supports general N-D sharding configurations, from plain FSDP2 to hybrid setups such as 2 TP + 2 DP-Replicate + 2 DP-Shard.

## Commands

### Lint & Format

```bash
pre-commit run --all-files          # Run all pre-commit hooks
pre-commit run isort --all-files    # Run a specific hook (e.g., isort)
```

Hooks: yapf (Python formatter), isort (import sorter), typos (spell checker), clang-format (C++/CUDA), pymarkdown (Markdown linter), actionlint (GitHub Actions).

### Tests

Tests require **8 GPUs**, access to `Motif-Technologies/Motif-2.6B-4layer-random` on HuggingFace (`HF_TOKEN` env var), and PyTorch >= 2.8.0.

```bash
cd test && ./run_test.sh
# Equivalent to:
cd test && torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon.py
```

Useful pytest flags: `--measure-perf` (timing/memory), `--do-profile` (profiling, requires `--measure-perf`), `--skip-verify` (skip correctness check against sequential implementation).

### Build

Uses kernel-builder infrastructure (`build.toml`, `flake.nix`). Pre-built binaries for various PyTorch/CUDA/ROCm combinations are stored in `build/`.

### Commit Convention

**Always append `[skip-build]` to every commit message.** This prevents CI from triggering unnecessary build jobs on development branches.

## Architecture

### Source Layout

```
torch-ext/optimizer/
├── __init__.py                    # Public API: exports Muon
├── muon.py                        # Muon optimizer class (~430 lines)
├── newton_schulz.py               # Newton-Schulz iteration (~50 lines)
├── qk_clip.py                     # QK clipping for attention heads (~130 lines)
├── core.py                        # Shared state, helpers, param grouping (~110 lines)
├── pipeline.py                    # Async generator pipeline for parallel mode (~290 lines)
├── async_utils.py                 # AsyncTask / AsyncRuntime scheduling (~75 lines)
├── adamw.py                       # Fused AdamW for non-Muon parameters (~160 lines)
├── matmul_transpose_triton.py     # Triton kernel for X @ X.T (~130 lines)
└── distributed/
    └── utils.py                   # Shard mesh construction, DTensor slicing (~175 lines)
```

### Optimizer Modes

The `Muon` optimizer has three execution paths, selected per parameter based on the tensor's type and mesh structure:

1. **Base mode** (`base()`) — Single-device / non-sharded tensors. Standard Muon with Newton-Schulz orthogonalization.
2. **Distributed mode** (`distributed_muon()`) — Gathers full tensors via all-gather, computes updates, then redistributes. Used for small parameters or as a fallback.
3. **Parallel mode** (`parallel()`) — Pipelined all2all communication overlapped with compute, using an async generator pipeline scheduled by `run_pipeline()`. This is the main advanced feature.
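
As a rough illustration, the per-parameter dispatch can be pictured as below. This is a hypothetical sketch: the `ParamInfo` fields and the threshold constant are illustrative stand-ins, not the actual criteria in `muon.py`.

```python
from dataclasses import dataclass


@dataclass
class ParamInfo:
    is_dtensor: bool  # sharded DTensor vs. plain single-device tensor
    numel: int        # parameter size


SMALL_PARAM_THRESHOLD = 1_000_000  # illustrative cutoff, not the repo's value


def select_mode(p: ParamInfo) -> str:
    """Pick one of the three Muon execution paths for a parameter."""
    if not p.is_dtensor:
        return "base"         # single-device Newton-Schulz
    if p.numel < SMALL_PARAM_THRESHOLD:
        return "distributed"  # all-gather fallback for small params
    return "parallel"         # pipelined all2all path
```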

### Parallel Mode Pipeline

The parallel pipeline is implemented as a single generator function `muon_chunk_pipeline()` in `pipeline.py`. Parameters are split into chunks, and each chunk flows through:

```
build bufs + async all2all_gather → yield → wait + Newton-Schulz compute + async all2all_scatter → yield → wait + update_param
```

The generator yields twice, once after launching the async gather and once after launching the async scatter (both with `async_op=True`), which lets `run_pipeline()` interleave multiple chunks to overlap communication with compute. After each yield, `work.wait()` completes the pending async operation.

`warmup_step` maps to `max_concurrent_tasks = warmup_step + 1` in `run_pipeline()`.
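
The scheduling pattern can be sketched in plain Python. This is a toy model only: where the real chunks launch async collectives and call `work.wait()`, this sketch appends log entries, and the actual scheduler bookkeeping lives in `async_utils.py`.

```python
def chunk_pipeline(chunk_id, log):
    # stage 1: build buffers, launch async all2all gather (simulated)
    log.append(f"gather[{chunk_id}]")
    yield  # hand control back to the scheduler while comms run
    # stage 2: wait, Newton-Schulz compute, launch async all2all scatter
    log.append(f"compute[{chunk_id}]")
    yield
    # stage 3: wait, write the update back into the parameter
    log.append(f"update[{chunk_id}]")


def run_pipeline(gens, max_concurrent_tasks):
    # round-robin over at most max_concurrent_tasks live generators,
    # advancing each to its next yield point
    active, pending = [], list(gens)
    while active or pending:
        while pending and len(active) < max_concurrent_tasks:
            active.append(pending.pop(0))
        gen = active.pop(0)
        try:
            next(gen)
            active.append(gen)  # not finished: requeue for the next stage
        except StopIteration:
            pass  # chunk fully processed


log = []
run_pipeline([chunk_pipeline(i, log) for i in range(3)],
             max_concurrent_tasks=2)
```

With two concurrent tasks, the gathers for chunks 0 and 1 are both in flight before either chunk computes, which is exactly the communication overlap the yield points enable.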

For detailed implementation documentation (pipeline internals, distributed utilities, QK clipping with strided sharding, etc.), see [`docs/implementation.md`](docs/implementation.md).

### Key Abstractions

- **`get_default_muon_param_groups(model, is_muon_func)`** (`core.py`) — Separates parameters into Muon-optimizable (2D+) and AdamW groups. Skips embeddings and output layers by default.
- **`_muon_state` dataclass** (`core.py`) — Per-parameter config: rank ownership (`worker_rank`), process group, precomputed shard indices (`rank_indices`, `rank_numels`), and optional QK clip state. Config-only; no transient pipeline state.
- **`muon_chunk_pipeline()` generator** (`pipeline.py`) — Processes one chunk through the full gather→compute→scatter→update pipeline. Uses `async_op=True` for non-blocking all-to-all and yields to allow chunk interleaving. All intermediate buffers are generator-local variables.
- **`run_pipeline()`** (`async_utils.py`) — Generator-based pipeline scheduling with bounded concurrency. Interleaves multiple chunk pipelines at yield points.
- **`construct_shard_mesh()` / `get_slices_of_dtensor()`** (`distributed/utils.py`) — Utilities for building shard meshes from DTensor placements and computing per-rank local slices. Handles both `Shard` and `_StridedShard` (PyTorch 2.10+).
- **Newton-Schulz iteration** (`newton_schulz.py`) — `_zeropower_via_newtonschulz5()`: 5 quintic iterations in bfloat16 with pre-optimized coefficients for gradient orthogonalization. Uses the Triton kernel `matmul_transpose_assign` for efficient X @ X.T.
- **QK Clipping** (`qk_clip.py`) — Optional dynamic clipping of attention head projections when QK logits exceed a threshold. Configured via `q_indices`, `k_indices`, `head_dim`, `threshold`.
- **Fused AdamW** (`adamw.py`) — Uses PyTorch's `torch._fused_adamw_` for non-Muon parameters, grouping tensors by device/dtype and DTensor placement.
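
For intuition, here is a dependency-free sketch of quintic Newton-Schulz orthogonalization on a tiny square matrix. The coefficients are the commonly published Muon values and may differ from this repo's; the real `newton_schulz.py` runs in bfloat16 on GPU with the Triton `matmul_transpose_assign` kernel, so treat this purely as a numerical illustration.

```python
import math


def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def transpose(A):
    return [list(row) for row in zip(*A)]


def newton_schulz5(G, steps=5):
    # Quintic iteration X <- a*X + (b*A + c*A^2) @ X with A = X @ X.T,
    # driving the singular values of X toward 1 (approximate orthogonalization).
    a, b, c = 3.4445, -4.7750, 2.0315  # commonly published Muon coefficients
    norm = math.sqrt(sum(v * v for row in G for v in row)) + 1e-7
    X = [[v / norm for v in row] for row in G]  # bounds spectral norm by 1
    n = len(X)
    for _ in range(steps):
        A = matmul(X, transpose(X))
        A2 = matmul(A, A)
        B = [[b * A[i][j] + c * A2[i][j] for j in range(n)] for i in range(n)]
        BX = matmul(B, X)
        X = [[a * X[i][j] + BX[i][j] for j in range(n)] for i in range(n)]
    return X
```

After 5 steps the result is only approximately orthogonal (these coefficients trade exact convergence for speed), which is sufficient for the optimizer's purposes.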

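The clipping rule can be sketched as follows. This is an illustrative sketch of the general QK-clip idea of splitting the logit correction between the q and k projections; the exact formula and per-head bookkeeping in `qk_clip.py` may differ.

```python
def qk_clip_scale(max_logit: float, threshold: float) -> float:
    """Per-side scale factor for q and k projection weights (illustrative).

    When the observed max QK logit exceeds the threshold, shrink the
    q*k product by threshold/max_logit; the square root splits the
    correction evenly between the two projections.
    """
    if max_logit <= threshold:
        return 1.0  # logits in range: leave weights untouched
    return (threshold / max_logit) ** 0.5
```

Scaling both projections by this factor rescales every QK logit by threshold/max_logit, capping the new maximum at the threshold.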
### Dependency Graph

```
matmul_transpose_triton.py       (leaf)
         │
    newton_schulz.py              (leaf + triton)
         │
      core.py ──── qk_clip.py    (leaf, distributed/utils)
       │    │         │
       │  pipeline.py ─── async_utils.py
       │       │
       │   adamw.py
       │       │
      muon.py                     (all above)
         │
    __init__.py
```