diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000000000000000000000000000000000..cf12b9a281680b32e8247ba103bfe86e53d8b7ff --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,108 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Optimizer is a PyTorch package implementing the **Muon optimizer** with support for N-D sharding parallelism for large-scale distributed training. Based on the paper at https://arxiv.org/abs/2511.07464. It supports general N-D sharding configurations (FSDP2 through hybrid setups like 2 TP + 2 DP-Replicate + 2 DP-Shard). + +## Commands + +### Lint & Format + +```bash +pre-commit run --all-files # Run all pre-commit hooks +pre-commit run isort --all-files # Run a specific hook (e.g., isort) +``` + +Hooks: yapf (Python formatter), isort (import sorter), typos (spell checker), clang-format (C++/CUDA), pymarkdown (Markdown linter), actionlint (GitHub Actions). + +### Tests + +Tests require **8 GPUs**, access to `Motif-Technologies/Motif-2.6B-4layer-random` on HuggingFace (`HF_TOKEN` env var), and PyTorch >= 2.8.0. + +```bash +cd test && ./run_test.sh +# Equivalent to: +cd test && torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon.py +``` + +Useful pytest flags: `--measure-perf` (timing/memory), `--do-profile` (profiling, requires `--measure-perf`), `--skip-verify` (skip correctness check against sequential implementation). + +### Build + +Uses kernel-builder infrastructure (`build.toml`, `flake.nix`). Pre-built binaries for various PyTorch/CUDA/ROCm combinations are stored in `build/`. + +### Commit Convention + +**Always append `[skip-build]` to every commit message.** This prevents CI from triggering unnecessary build jobs on development branches. 
+ +## Architecture + +### Source Layout + +``` +torch-ext/optimizer/ +├── __init__.py # Public API: exports Muon +├── muon.py # Muon optimizer class (~430 lines) +├── newton_schulz.py # Newton-Schulz iteration (~50 lines) +├── qk_clip.py # QK clipping for attention heads (~130 lines) +├── core.py # Shared state, helpers, param grouping (~110 lines) +├── pipeline.py # Async generator pipeline for parallel mode (~290 lines) +├── async_utils.py # AsyncTask / AsyncRuntime scheduling (~75 lines) +├── adamw.py # Fused AdamW for non-Muon parameters (~160 lines) +├── matmul_transpose_triton.py # Triton kernel for X @ X.T (~130 lines) +└── distributed/ + └── utils.py # Shard mesh construction, DTensor slicing (~175 lines) +``` + +### Optimizer Modes + +The `Muon` optimizer has three execution paths selected per-parameter based on its tensor type and mesh structure: + +1. **Base mode** (`base()`) — Single-device / non-sharded tensors. Standard Muon with Newton-Schulz orthogonalization. +2. **Distributed mode** (`distributed_muon()`) — Gathers full tensors via all-gather, computes updates, redistributes. Used for small parameters or fallback. +3. **Parallel mode** (`parallel()`) — Pipelined all2all communication overlapped with compute. Uses an async generator pipeline scheduled by `run_pipeline()`. This is the main advanced feature. + +### Parallel Mode Pipeline + +The parallel pipeline is implemented as a single generator function `muon_chunk_pipeline()` in `pipeline.py`. Parameters are split into chunks, and each chunk flows through: + +``` +build bufs + async all2all_gather → yield → wait + Newton-Schulz compute + async all2all_scatter → yield → wait + update_param +``` + +The generator yields 2 times (after launching async gather and async scatter via `async_op=True`), allowing `run_pipeline()` to interleave multiple chunks for communication overlap. `work.wait()` completes each async operation after the yield. 
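The two-yield pattern and the scheduler's staggered admission can be sketched in plain Python. This is a simplified model of `muon_chunk_pipeline()` / `run_pipeline()`, not the package's actual code: stage bodies are replaced by log entries so the interleaving order is visible, and the function and variable names are illustrative.

```python
# Toy model of the two-yield chunk pipeline and its scheduler.
# Illustrative only: the real pipeline launches NCCL all-to-alls and
# Newton-Schulz kernels; here each stage just records its name.
from typing import Generator, Iterator, List


def chunk_pipeline(chunk_id: int, log: List[str]) -> Generator[None, None, None]:
    log.append(f"chunk{chunk_id}: launch gather")  # async all2all_gather (async_op=True)
    yield                                          # other chunks may run here
    log.append(f"chunk{chunk_id}: compute")        # work.wait() + Newton-Schulz
    yield                                          # async all2all_scatter in flight
    log.append(f"chunk{chunk_id}: update_param")   # work.wait() + update_param


def run_pipeline(pipelines: Iterator[Generator[None, None, None]],
                 max_concurrent: int) -> None:
    """Admit at most one new pipeline per round (staggered admission),
    then advance every in-flight pipeline by one yield."""
    have_new = True
    running: List[Generator[None, None, None]] = []
    while have_new or running:
        nxt: List[Generator[None, None, None]] = []
        if have_new and len(running) < max_concurrent:
            try:
                gen = next(pipelines)
            except StopIteration:
                have_new = False          # no more chunks to admit
            else:
                next(gen)                 # run to first yield: gather launched
                nxt.append(gen)
        for gen in running:
            try:
                next(gen)                 # advance one stage
                nxt.append(gen)
            except StopIteration:
                pass                      # pipeline finished
        running = nxt


log: List[str] = []
run_pipeline((chunk_pipeline(i, log) for i in range(3)), max_concurrent=2)
```

With this toy scheduler, three chunks and `max_concurrent=2`, `chunk1: launch gather` lands before `chunk0: compute` in the log — the slot where, on real hardware, chunk 1's all-to-all on the NCCL stream overlaps chunk 0's Newton-Schulz on the compute stream.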
+ +`warmup_step` maps to `max_concurrent_tasks = warmup_step + 1` in `run_pipeline()`. + +For detailed implementation documentation (pipeline internals, distributed utilities, QK clipping with strided sharding, etc.), see [`docs/implementation.md`](docs/implementation.md). + +### Key Abstractions + +- **`get_default_muon_param_groups(model, is_muon_func)`** (`core.py`) — Separates parameters into Muon-optimizable (2D+) and AdamW groups. Skips embeddings and output layers by default. +- **`_muon_state` dataclass** (`core.py`) — Per-parameter config: rank ownership (`worker_rank`), process group, precomputed shard indices (`rank_indices`, `rank_numels`), and optional QK clip state. Config-only; no transient pipeline state. +- **`muon_chunk_pipeline()` generator** (`pipeline.py`) — Processes one chunk through the full gather→compute→scatter→update pipeline. Uses `async_op=True` for non-blocking all-to-all and yields to allow chunk interleaving. All intermediate buffers are generator-local variables. +- **`run_pipeline()`** (`async_utils.py`) — Generator-based pipeline scheduling with bounded concurrency. Interleaves multiple chunk pipelines at yield points. +- **`construct_shard_mesh()` / `get_slices_of_dtensor()`** (`distributed/utils.py`) — Utilities for building shard meshes from DTensor placements and computing per-rank local slices. Handles both `Shard` and `_StridedShard` (PyTorch 2.10+). +- **Newton-Schulz iteration** (`newton_schulz.py`) — `_zeropower_via_newtonschulz5()`: 5 quintic iterations in bfloat16 with pre-optimized coefficients for gradient orthogonalization. Uses Triton kernel `matmul_transpose_assign` for efficient X @ X.T. +- **QK Clipping** (`qk_clip.py`) — Optional dynamic clipping of attention head projections when QK logits exceed a threshold. Configured via `q_indices`, `k_indices`, `head_dim`, `threshold`. 
+- **Fused AdamW** (`adamw.py`) — Uses PyTorch's `torch._fused_adamw_` for non-Muon parameters, grouping tensors by device/dtype and DTensor placement. + +### Dependency Graph + +``` +matmul_transpose_triton.py (leaf) + │ + newton_schulz.py (leaf + triton) + │ + core.py ──── qk_clip.py (leaf, distributed/utils) + │ │ │ + │ pipeline.py ─── async_utils.py + │ │ + │ adamw.py + │ │ + muon.py (all above) + │ + __init__.py +``` diff --git a/README.md b/README.md index 59d1c3567a56a6978c0a714c956cf845e443fda8..5416fc8ec3ca3469a807f04d2622d88a39a1c64c 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,13 @@ optim = optimizer.Muon( ) ``` +## Documentation + +- [Implementation Guide](./docs/implementation.md) — Detailed walkthrough of the internal architecture, parallel pipeline, distributed utilities, and QK clipping. Recommended for code reviewers and new contributors. +- [PyTorch 2.10 TP Fix](./docs/pytorch-2.10-tp-fix.md) — Root cause analysis and fixes for `_StridedShard` compatibility with PyTorch 2.10+. + ## Test + - Check [test/README.md](./test/README.md) for how to run the tests. ## Pre-commit Hooks diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 6015e5b4ea5da27e0002b298d9a1ab55142f88ab..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5384da54f22f488e0646e09915b821b3235cb404b163a570aa377967f853e3cf -size 1940944 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..96a6868d0ec423b37d2097f2a60061a3b90efc70 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f095be87ff6185010a3cff4175abbde0b2e50fe1e435dc1db4eaf5bf1f6199ca +size 1940944 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/adamw.py b/build/torch210-cxx11-cu126-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not 
params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/async_utils.py b/build/torch210-cxx11-cu126-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch210-cxx11-cu126-x86_64-linux/core.py b/build/torch210-cxx11-cu126-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. 
+ + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
- placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) + Example — 8 GPUs, mesh shape (2, 2, 2), + placements ``[Shard(0), Replicate, _StridedShard(0)]``:: - sorted_indices, sorted_placements = zip(*placements_with_index) + Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)] + Permutation: [1, 2, 0] - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) + Step 2 — Permute mesh dims by [1, 2, 0]: + Original: Permuted: + [[[0,1],[2,3]], [[[0,2],[1,3]], + [[4,5],[6,7]]] [[4,6],[5,7]]] - # 3. Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes: + sub-mesh 0 = [[0,2],[1,3]] (replica group 0) + sub-mesh 1 = [[4,6],[5,7]] (replica group 1) + shard_placements = (_StridedShard(0), Shard(0)) - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + Step 4 — Rank 0 → ProcessGroup([0,1,4,5]) + Rank 2 → ProcessGroup([2,3,6,7]) + + Returns: + ``(shard_mesh, process_group, shard_placements)`` + """ + my_rank = dist.get_rank() + assert mesh.mesh.device.type == 'cpu' + + # -- Fast path: 1D all-shard mesh → reuse existing PG. ---------------- + # This avoids a non-collective dist.new_group() call, which would + # deadlock when only a subset of ranks call this function (e.g. expert + # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately). 
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py b/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json +++ b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-x86_64-linux/muon.py b/build/torch210-cxx11-cu126-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/muon.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline 
+from .core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch210-cxx11-cu126-x86_64-linux/pipeline.py b/build/torch210-cxx11-cu126-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+ + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def 
_complete_gather( + recv_buf: torch.Tensor, + recv_counts: list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
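The split-size bookkeeping that `_launch_gather` and `_launch_scatter` share can be simulated in a single process, without a process group. This sketch (toy sizes and fill values, purely illustrative) emulates what `dist.all_to_all_single` does with `input_split_sizes`/`output_split_sizes`: each rank's send buffer is a concatenation of per-destination flats, and each receive buffer is the concatenation, in source-rank order, of the blocks addressed to it.

```python
import torch

# Three toy "ranks"; per_dst[src][dst] is the flat shard that src sends to
# dst (sizes dst+1 and fill value src*10+dst are made up for illustration).
num_ranks = 3
per_dst = [[torch.full((dst + 1,), float(src * 10 + dst))
            for dst in range(num_ranks)] for src in range(num_ranks)]

send_bufs, send_counts = [], []
for src in range(num_ranks):
    send_counts.append([t.numel() for t in per_dst[src]])
    send_bufs.append(torch.cat(per_dst[src]))  # one flat send buffer per rank

# Emulate all_to_all_single: rank dst receives, in source-rank order, the
# block each src addressed to it; block sizes are dst's output_split_sizes.
recv_bufs = []
for dst in range(num_ranks):
    blocks = []
    for src in range(num_ranks):
        off = sum(send_counts[src][:dst])
        blocks.append(send_bufs[src].narrow(0, off, send_counts[src][dst]))
    recv_bufs.append(torch.cat(blocks))
```

The `_complete_*` helpers then walk the received flat buffer with exactly these per-source offsets to scatter values back into per-parameter views.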
+ """ + # Allocate scattered-u buffers + scattered_us: dict[int, torch.Tensor] = {} + for p in params: + scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + + # Build send buffer (from computed_us on owner ranks) + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + + assert computed_us[id(p)] is not None + u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + total_sent = 0 + for dst_rank in range(num_ranks): + indices = state.rank_indices[dst_rank] + su = u_full[indices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst_rank].append(su) + send_counts[dst_rank] += n + total_sent += n + + assert total_sent == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += state.rank_numels[rank] + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, scattered_us, recv_counts + + +def _complete_scatter( + recv_buf: torch.Tensor, + recv_counts: list[int], + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], +) -> 
None: + """Copy recv buffer into scattered_us (in-place).""" + off = 0 + for src in range(len(recv_counts)): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = state.rank_numels[rank] + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + scattered_us[id(p)].copy_(flat_local) + + inner_off += n + + assert inner_off == block + off += block + + +def _update_params( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], + lr: float, + weight_decay: float, +) -> None: + """Apply weight decay, Muon update, and optional QK clipping.""" + for p in params: + state = param_to_state[id(p)] + u_dtensor = DTensor.from_local( + scattered_us[id(p)], + placements=p.placements, + device_mesh=p.device_mesh, + ) + + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + scales_full = compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = state.rank_indices[rank][0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + +# ====================================================================== +# Main generator – thin orchestrator that wires stages together. 
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py b/build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index a2b4992c68bd2d564fa8ac804bce7a9f9d0a09d9..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:976df6a1ec3ec4c462dea18477b56dfb75bcff76f504d55b592ce417931597c0 -size 2004144 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..1d1806041a1fdcea027e6aa31eb8b774c6c797d0 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4919c48c77c6223dbf668f1461bcec175ef1bd6ea4cec8c2509de12ca7200a62 +size 2004144 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/adamw.py b/build/torch210-cxx11-cu128-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not 
params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/async_utils.py b/build/torch210-cxx11-cu128-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch210-cxx11-cu128-x86_64-linux/core.py b/build/torch210-cxx11-cu128-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. 
+ + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
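Worked numerically, the rule multiplies the base lr by `0.2 * sqrt(max(A, B))`. The function is restated here as a standalone sketch so the arithmetic can be checked in isolation:

```python
import math

def adjust_lr_for_muon(lr: float, param_shape) -> float:
    # Same scaling rule as core.adjust_lr_for_muon.
    A, B = param_shape[:2]
    return lr * 0.2 * math.sqrt(max(A, B))

# A 4096 x 1024 projection at base lr 0.02:
# 0.02 * 0.2 * sqrt(4096) = 0.02 * 0.2 * 64 = 0.256
adjusted = adjust_lr_for_muon(0.02, (4096, 1024))
```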
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
- placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) + Example — 8 GPUs, mesh shape (2, 2, 2), + placements ``[Shard(0), Replicate, _StridedShard(0)]``:: - sorted_indices, sorted_placements = zip(*placements_with_index) + Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)] + Permutation: [1, 2, 0] - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) + Step 2 — Permute mesh dims by [1, 2, 0]: + Original: Permuted: + [[[0,1],[2,3]], [[[0,2],[1,3]], + [[4,5],[6,7]]] [[4,6],[5,7]]] - # 3. Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes: + sub-mesh 0 = [[0,2],[1,3]] (replica group 0) + sub-mesh 1 = [[4,6],[5,7]] (replica group 1) + shard_placements = (_StridedShard(0), Shard(0)) - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + Step 4 — Rank 0 → ProcessGroup([0,1,4,5]) + Rank 2 → ProcessGroup([2,3,6,7]) + + Returns: + ``(shard_mesh, process_group, shard_placements)`` + """ + my_rank = dist.get_rank() + assert mesh.mesh.device.type == 'cpu' + + # -- Fast path: 1D all-shard mesh → reuse existing PG. ---------------- + # This avoids a non-collective dist.new_group() call, which would + # deadlock when only a subset of ranks call this function (e.g. expert + # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately). 
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py b/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json +++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/muon.py b/build/torch210-cxx11-cu128-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/muon.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline 
+from .core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
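The per-rank numel precomputation for the all-to-all relies on `slice.indices`, which normalizes a Python `slice` (including `None` bounds and step) against a dimension size. A standalone sketch — the function name, shapes, and slices here are made-up examples, not values from the repository:

```python
# Standalone sketch: count how many elements a rank's slice tuple selects
# from a parameter, the way per-rank all-to-all split sizes can be derived.
# `slice_numel` is a hypothetical helper, not an identifier from the repo.

def slice_numel(indices, shape):
    numel = 1
    for idx, dim_size in zip(indices, shape):
        if isinstance(idx, slice):
            # slice.indices normalizes None/negative bounds to (start, stop, step)
            start, stop, step = idx.indices(dim_size)
            numel *= max(0, (stop - start + (step - 1)) // step)
        else:  # explicit index list
            numel *= len(idx)
    return numel

# A rank owning rows 0..3 of an (8, 16) weight sends/receives 4 * 16 elements:
print(slice_numel((slice(0, 4), slice(None)), (8, 16)))  # 64
```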
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
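The staggered admission that `run_pipeline(pipelines(), max_concurrent=...)` performs can be sketched with plain generators. These are simplified stand-ins, not the repository's `run_pipeline`/`muon_chunk_pipeline`: each "chunk" yields twice (after gather, after scatter), and the scheduler admits one new chunk per pass so one chunk's compute overlaps another chunk's communication window.

```python
# Illustrative sketch (hypothetical stand-ins, not the repo's code): a
# round-robin scheduler with staggered admission driving two-yield pipelines.
from collections import deque

def chunk_pipeline(name, log):
    log.append(f"{name}:gather")    # launch async all-to-all gather
    yield                           # --- yield 1: let other chunks run ---
    log.append(f"{name}:compute")   # wait + Newton-Schulz
    yield                           # --- yield 2: let other chunks run ---
    log.append(f"{name}:update")    # wait + parameter update

def run_pipeline(gens, max_concurrent):
    gens = iter(gens)
    active = deque()
    while True:
        if len(active) < max_concurrent:   # admit one new chunk per pass
            g = next(gens, None)
            if g is not None:
                active.append(g)
        if not active:
            break
        g = active.popleft()
        try:
            next(g)                        # advance to this chunk's next yield
            active.append(g)               # not finished: requeue
        except StopIteration:
            pass

log = []
run_pipeline((chunk_pipeline(n, log) for n in "AB"), max_concurrent=2)
print(log)  # B's gather is launched while A is still in flight
```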
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch210-cxx11-cu128-x86_64-linux/pipeline.py b/build/torch210-cxx11-cu128-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+ + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def 
_complete_gather( + recv_buf: torch.Tensor, + recv_counts: list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
+ """ + # Allocate scattered-u buffers + scattered_us: dict[int, torch.Tensor] = {} + for p in params: + scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + + # Build send buffer (from computed_us on owner ranks) + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + + assert computed_us[id(p)] is not None + u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + total_sent = 0 + for dst_rank in range(num_ranks): + indices = state.rank_indices[dst_rank] + su = u_full[indices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst_rank].append(su) + send_counts[dst_rank] += n + total_sent += n + + assert total_sent == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += state.rank_numels[rank] + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, scattered_us, recv_counts + + +def _complete_scatter( + recv_buf: torch.Tensor, + recv_counts: list[int], + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], +) -> 
None: + """Copy recv buffer into scattered_us (in-place).""" + off = 0 + for src in range(len(recv_counts)): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = state.rank_numels[rank] + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + scattered_us[id(p)].copy_(flat_local) + + inner_off += n + + assert inner_off == block + off += block + + +def _update_params( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], + lr: float, + weight_decay: float, +) -> None: + """Apply weight decay, Muon update, and optional QK clipping.""" + for p in params: + state = param_to_state[id(p)] + u_dtensor = DTensor.from_local( + scattered_us[id(p)], + placements=p.placements, + device_mesh=p.device_mesh, + ) + + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + scales_full = compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = state.rank_indices[rank][0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + +# ====================================================================== +# Main generator – thin orchestrator that wires stages together. 
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py b/build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
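The per-head rule in `compute_scales` reduces to a small scalar formula: a head is only touched when its logit exceeds the threshold, and the scale is `sqrt(threshold / logit)`. A scalar sketch (assuming, as in the code, that the square-root form is chosen so that scaling a projection shrinks the quadratic QK logit toward the threshold):

```python
import math

# Scalar sketch of the per-head clipping rule in compute_scales:
# heads at or below the threshold keep scale 1.0; heads above it are
# scaled by sqrt(threshold / logit), which is < 1.

def head_scale(logit: float, threshold: float) -> float:
    if logit <= threshold:
        return 1.0
    return math.sqrt(threshold / logit)
```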
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 62bbc727da9606819a23c43dda20add2be7c1fe3..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:330aaa6cb247ba3b5df7a13ced6ef7eff3e5d7a72a0b88f674f948aeaed66ee2 -size 2004728 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..08caf42e7e7b1f311490df8058ed06d87ea79358 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9c7bb12bc030d4959e880a959b39ea07eb03e16175d7cf03829f9860f52525d +size 2004728 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/adamw.py b/build/torch210-cxx11-cu130-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not 
params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
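The fused kernel invoked above applies the standard AdamW update per element: decoupled weight decay, exponential moving averages of the gradient and its square, and bias correction. A scalar sketch of one step (illustrative only; the package delegates this math to `torch._fused_adamw_`):

```python
import math

# Scalar sketch of one AdamW step with decoupled weight decay.

def adamw_step(p, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    step += 1
    p *= 1.0 - lr * weight_decay              # decoupled weight decay
    m = beta1 * m + (1.0 - beta1) * g         # first moment (EMA of grad)
    v = beta2 * v + (1.0 - beta2) * g * g     # second moment (EMA of grad^2)
    m_hat = m / (1.0 - beta1 ** step)         # bias correction
    v_hat = v / (1.0 - beta2 ** step)
    p -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v, step
```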
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/async_utils.py b/build/torch210-cxx11-cu130-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch210-cxx11-cu130-x86_64-linux/core.py b/build/torch210-cxx11-cu130-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. 
+ + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
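The index composition performed by `get_slices_of_dtensor` can be illustrated with plain lists: each mesh level selects a subset of the indices produced by the previous level on the same tensor dim — a contiguous `Shard` takes a block, while `_StridedShard`-style sharding takes a strided subset. A pure-Python sketch under that simplification (helper names are hypothetical; the order in which levels apply follows the sorted placements):

```python
# Sketch of nested shard-index composition on one tensor dim.

def shard_block(size: int, num_chunks: int, coord: int) -> list[int]:
    """Contiguous Shard: chunk `coord` of `num_chunks` equal blocks."""
    assert size % num_chunks == 0
    n = size // num_chunks
    return list(range(coord * n, (coord + 1) * n))

def shard_strided(size: int, num_chunks: int, coord: int) -> list[int]:
    """Strided shard: every num_chunks-th index starting at coord."""
    return list(range(coord, size, num_chunks))

def compose(indices: list[int], inner: list[int]) -> list[int]:
    """Apply an inner level's selection to the outer level's indices."""
    return [indices[i] for i in inner]
```

For example, a dim of size 8 sharded strided over 2 chunks (coord 0) gives `[0, 2, 4, 6]`; sub-sharding that contiguously over 2 chunks (coord 1) composes to `[4, 6]` — a non-contiguous shard that must be returned as an index tensor rather than a slice, which is exactly why the function's return type was widened.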
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
- placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) + Example — 8 GPUs, mesh shape (2, 2, 2), + placements ``[Shard(0), Replicate, _StridedShard(0)]``:: - sorted_indices, sorted_placements = zip(*placements_with_index) + Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)] + Permutation: [1, 2, 0] - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) + Step 2 — Permute mesh dims by [1, 2, 0]: + Original: Permuted: + [[[0,1],[2,3]], [[[0,2],[1,3]], + [[4,5],[6,7]]] [[4,6],[5,7]]] - # 3. Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes: + sub-mesh 0 = [[0,2],[1,3]] (replica group 0) + sub-mesh 1 = [[4,6],[5,7]] (replica group 1) + shard_placements = (_StridedShard(0), Shard(0)) - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + Step 4 — Rank 0 → ProcessGroup([0,1,4,5]) + Rank 2 → ProcessGroup([2,3,6,7]) + + Returns: + ``(shard_mesh, process_group, shard_placements)`` + """ + my_rank = dist.get_rank() + assert mesh.mesh.device.type == 'cpu' + + # -- Fast path: 1D all-shard mesh → reuse existing PG. ---------------- + # This avoids a non-collective dist.new_group() call, which would + # deadlock when only a subset of ranks call this function (e.g. expert + # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately). 
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
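The sort key in Step 1 can be demonstrated with simple tuples standing in for placement objects (`'R'` for Replicate, `'S'` for Shard, with `split_factor` 0 meaning a plain Shard). This toy version mirrors the `_sort_key` above: Replicate sorts first, then shards by dim, and on the same dim a strided shard (negative split term) sorts before a plain shard:

```python
# Toy model of construct_shard_mesh's placement sort key.
# A placement is (kind, dim, split_factor); split_factor == 0 models a
# plain Shard, > 0 models a _StridedShard.

def sort_key(item):
    index, (kind, dim, split_factor) = item
    if kind == 'R':
        return (-1, 0, index)          # Replicate always first
    split = -1 / split_factor if split_factor else 0
    return (dim, split, index)         # strided (< 0) before plain (0)
```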
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py b/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json +++ b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/muon.py b/build/torch210-cxx11-cu130-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/muon.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline 
+from .core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch210-cxx11-cu130-x86_64-linux/pipeline.py b/build/torch210-cxx11-cu130-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+ + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def 
_complete_gather( + recv_buf: torch.Tensor, + recv_counts: list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
+ """ + # Allocate scattered-u buffers + scattered_us: dict[int, torch.Tensor] = {} + for p in params: + scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + + # Build send buffer (from computed_us on owner ranks) + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + + assert computed_us[id(p)] is not None + u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + total_sent = 0 + for dst_rank in range(num_ranks): + indices = state.rank_indices[dst_rank] + su = u_full[indices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst_rank].append(su) + send_counts[dst_rank] += n + total_sent += n + + assert total_sent == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += state.rank_numels[rank] + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, scattered_us, recv_counts + + +def _complete_scatter( + recv_buf: torch.Tensor, + recv_counts: list[int], + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], +) -> 
None: + """Copy recv buffer into scattered_us (in-place).""" + off = 0 + for src in range(len(recv_counts)): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = state.rank_numels[rank] + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + scattered_us[id(p)].copy_(flat_local) + + inner_off += n + + assert inner_off == block + off += block + + +def _update_params( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], + lr: float, + weight_decay: float, +) -> None: + """Apply weight decay, Muon update, and optional QK clipping.""" + for p in params: + state = param_to_state[id(p)] + u_dtensor = DTensor.from_local( + scattered_us[id(p)], + placements=p.placements, + device_mesh=p.device_mesh, + ) + + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + scales_full = compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = state.rank_indices[rank][0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + +# ====================================================================== +# Main generator – thin orchestrator that wires stages together. 
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py b/build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index a2bbc913106abe6d784d7634ad119d969ff23a3c..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3562c68e8ee85fc5b268e079150ffff69d52860092d59e44fb9b3c4526c5d497 -size 1866400 diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..49889967591405cc5266af4e0911e0895d7b309b --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00e9d9e1c2306badb97c3b8f2454a47d6335a302101a38c804ad3c7b075168cc +size 1866400 diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/adamw.py b/build/torch210-cxx11-rocm70-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: 
+ if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/async_utils.py b/build/torch210-cxx11-rocm70-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/core.py b/build/torch210-cxx11-rocm70-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. 
+ + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
-    placements_with_index: list[tuple[int,
-                                      Placement]] = list(enumerate(placements))
-    placements_with_index = sorted(placements_with_index,
-                                   key=placement_sort_key)
+    Example — 8 GPUs, mesh shape (2, 2, 2),
+    placements ``[Shard(0), Replicate, _StridedShard(0)]``::
-    sorted_indices, sorted_placements = zip(*placements_with_index)
+        Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)]
+        Permutation: [1, 2, 0]
-    # 2. Permute mesh according to sorted placements.
-    sorted_mesh = mesh.permute(sorted_indices)
+        Step 2 — Permute mesh dims by [1, 2, 0]:
+            Original:           Permuted:
+            [[[0,1],[2,3]],     [[[0,4],[1,5]],
+             [[4,5],[6,7]]]      [[2,6],[3,7]]]
-    # 3. Collect list of shard meshes by removing replicate dims
-    # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
-    # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
-    num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
+        Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
+            sub-mesh 0 = [[0,4],[1,5]] (replica group 0)
+            sub-mesh 1 = [[2,6],[3,7]] (replica group 1)
+            shard_placements = (_StridedShard(0), Shard(0))
-    # merge replicate dims
-    # shard_meshes became a list of shard meshes with a length of replicate degree
-    if num_replicates > 0:
-        sorted_mesh = sorted_mesh.flatten(
-            0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
+        Step 4 — Rank 0 → ProcessGroup([0,1,4,5])
+                 Rank 2 → ProcessGroup([2,3,6,7])
+
+    Returns:
+        ``(shard_mesh, process_group, shard_placements)``
+    """
+    my_rank = dist.get_rank()
+    assert mesh.mesh.device.type == 'cpu'
+
+    # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
+    # This avoids a non-collective dist.new_group() call, which would
+    # deadlock when only a subset of ranks call this function (e.g. expert
+    # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py b/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json b/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json +++ b/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/muon.py b/build/torch210-cxx11-rocm70-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/muon.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import 
run_pipeline +from .core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
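+ # Hypothetical example: a (1024, 512) param row-sharded across 4
+ # ranks gives rank r
+ # indices = (slice(256 * r, 256 * (r + 1)), slice(None))
+ # numel = 256 * 512 = 131072
+ # so every all-to-all split size is known up front, with no extra
+ # communication needed at gather/scatter time.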
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
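+ # update_g applies standard momentum with a per-param buffer:
+ # buf = grad + momentum * buf
+ # and returns grad + momentum * buf when nesterov=True, else buf.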
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/pipeline.py b/build/torch210-cxx11-rocm70-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+ + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def 
_complete_gather( + recv_buf: torch.Tensor, + recv_counts: list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
+ """ + # Allocate scattered-u buffers + scattered_us: dict[int, torch.Tensor] = {} + for p in params: + scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + + # Build send buffer (from computed_us on owner ranks) + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + + assert computed_us[id(p)] is not None + u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + total_sent = 0 + for dst_rank in range(num_ranks): + indices = state.rank_indices[dst_rank] + su = u_full[indices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst_rank].append(su) + send_counts[dst_rank] += n + total_sent += n + + assert total_sent == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += state.rank_numels[rank] + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, scattered_us, recv_counts + + +def _complete_scatter( + recv_buf: torch.Tensor, + recv_counts: list[int], + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], +) -> 
None: + """Copy recv buffer into scattered_us (in-place).""" + off = 0 + for src in range(len(recv_counts)): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = state.rank_numels[rank] + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + scattered_us[id(p)].copy_(flat_local) + + inner_off += n + + assert inner_off == block + off += block + + +def _update_params( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], + lr: float, + weight_decay: float, +) -> None: + """Apply weight decay, Muon update, and optional QK clipping.""" + for p in params: + state = param_to_state[id(p)] + u_dtensor = DTensor.from_local( + scattered_us[id(p)], + placements=p.placements, + device_mesh=p.device_mesh, + ) + + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + scales_full = compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = state.rank_indices[rank][0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + +# ====================================================================== +# Main generator – thin orchestrator that wires stages together. 
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py b/build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
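    The clipping rule itself can be stated in isolation. A pure-Python
    sketch (illustrative values; ``head_scales`` is not part of this
    module, and it assumes the scale is applied so the clipped logit
    lands back at the threshold — logits are bilinear in W_q and W_k,
    hence the square root):

```python
import math

def head_scales(logits, threshold):
    # A head whose max QK logit exceeds the threshold gets scale
    # sqrt(threshold / logit); all other heads are left at 1.0.
    return [
        math.sqrt(threshold / v) if v > threshold else 1.0
        for v in logits
    ]

scales = head_scales([50.0, 120.0, 100.0], threshold=100.0)
# head 1 is scaled by sqrt(100 / 120); heads at or below the
# threshold (strict comparison) are untouched
```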
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index ed70a8ee48aca9da47db195b5e73c86aca32b153..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d804ba4d3ed9716c80e9819ba16a2bef300fb23fa4c456c550f4a96167a2eb00 -size 1866112 diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..10d8f0e7de3adaf54aa7478421c25a02e409544e --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e67022789ddd9296552fc5ab4075ce96b8b00b75bce057c707e5b5076bbde734 +size 1866112 diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/adamw.py b/build/torch210-cxx11-rocm71-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: 
+ if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
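    For reference, the per-element update that the fused kernel performs
    can be written out scalar-wise. This is a plain-Python sketch of the
    standard decoupled-weight-decay AdamW recurrence (``adamw_scalar`` is
    illustrative, not this module's API), matching ``amsgrad=False`` and
    ``maximize=False``:

```python
import math

def adamw_scalar(p, g, m, v, t, lr, beta1, beta2, eps, weight_decay):
    # One AdamW step on a single scalar (step count t starts at 1).
    m = beta1 * m + (1 - beta1) * g          # first moment
    v = beta2 * v + (1 - beta2) * g * g      # second moment
    m_hat = m / (1 - beta1 ** t)             # bias correction
    v_hat = v / (1 - beta2 ** t)
    p = p * (1 - lr * weight_decay)          # decoupled weight decay
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v

p, m, v = 1.0, 0.0, 0.0
p, m, v = adamw_scalar(p, g=0.5, m=m, v=v, t=1,
                       lr=1e-3, beta1=0.9, beta2=0.999,
                       eps=1e-8, weight_decay=0.01)
```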
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/async_utils.py b/build/torch210-cxx11-rocm71-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/core.py b/build/torch210-cxx11-rocm71-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. 
+ + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
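    Numerically, for example (a pure-Python restatement of the formula in
    the function body; ``adjust_lr`` is illustrative, not this module's
    API):

```python
import math

# Muon scales the base lr by 0.2 * sqrt(max(A, B)) for an (A, B) matrix.
def adjust_lr(lr, shape):
    A, B = shape[:2]
    return lr * 0.2 * math.sqrt(max(A, B))

adjusted = adjust_lr(0.02, (4096, 1024))  # 0.02 * 0.2 * sqrt(4096) = 0.256
```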
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
- placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) + Example — 8 GPUs, mesh shape (2, 2, 2), + placements ``[Shard(0), Replicate, _StridedShard(0)]``:: - sorted_indices, sorted_placements = zip(*placements_with_index) + Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)] + Permutation: [1, 2, 0] - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) + Step 2 — Permute mesh dims by [1, 2, 0]: + Original: Permuted: + [[[0,1],[2,3]], [[[0,2],[1,3]], + [[4,5],[6,7]]] [[4,6],[5,7]]] - # 3. Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes: + sub-mesh 0 = [[0,2],[1,3]] (replica group 0) + sub-mesh 1 = [[4,6],[5,7]] (replica group 1) + shard_placements = (_StridedShard(0), Shard(0)) - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + Step 4 — Rank 0 → ProcessGroup([0,1,4,5]) + Rank 2 → ProcessGroup([2,3,6,7]) + + Returns: + ``(shard_mesh, process_group, shard_placements)`` + """ + my_rank = dist.get_rank() + assert mesh.mesh.device.type == 'cpu' + + # -- Fast path: 1D all-shard mesh → reuse existing PG. ---------------- + # This avoids a non-collective dist.new_group() call, which would + # deadlock when only a subset of ranks call this function (e.g. expert + # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately). 
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py b/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json b/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json +++ b/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/muon.py b/build/torch210-cxx11-rocm71-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/muon.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import 
run_pipeline +from .core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
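The quintic Newton-Schulz iteration being deleted here (it now lives in `newton_schulz.py`) admits a compact illustration. A minimal NumPy sketch of the same update, using the coefficient schedule from the diff; the buffer reuse and the Triton `matmul_transpose_assign` kernel are omitted, so this is a readability sketch, not the repo's implementation:

```python
import numpy as np

# Per-iteration (a, b, c) coefficients copied from the diff; they are jointly
# tuned so the 5-step composition pushes every singular value toward ~1.
NS5_COEFFS = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]

def newton_schulz5(G: np.ndarray) -> np.ndarray:
    """Approximately orthogonalize G: for G = U S V^T, return roughly U V^T."""
    tall = G.shape[0] > G.shape[1]
    X = G.T if tall else G                      # always work on a wide matrix
    X = X / (np.linalg.norm(X) + 1e-7)          # spectral norm <= Frobenius <= 1
    for a, b, c in NS5_COEFFS:
        A = X @ X.T                             # eigenvalues = singular values squared
        B = b * A + c * (A @ A)
        X = a * X + B @ X                       # sigma -> a*s + b*s^3 + c*s^5
    return X.T if tall else X
```

Each step maps every singular value through the odd polynomial `a*s + b*s^3 + c*s^5`, so after five steps the spectrum clusters near 1 without ever computing an SVD, which is what makes the update cheap enough to run per parameter per step.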
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
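The comment above states the key invariant of expert expansion: indexing away dim 0 shifts every remaining shard dimension down by one. A dependency-free sketch of that placement remap, using a hypothetical `ShardDim` stand-in for `torch.distributed.tensor.Shard` so it runs without torch:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ShardDim:
    """Hypothetical stand-in for torch.distributed.tensor.Shard(dim)."""
    dim: int

def remap_after_expert_split(placements):
    """Drop dim-0 (expert) shards and shift the rest: Shard(k) -> Shard(k-1).

    Returns (kept_mesh_indices, new_placements) so a caller could build the
    TP submesh from the surviving mesh dimensions, mirroring the diff's
    tp_dim_indices / tp_placements_2d bookkeeping.
    """
    kept, new_placements = [], []
    for i, pl in enumerate(placements):
        if not isinstance(pl, ShardDim):
            continue          # Replicate-like placements need no remap here
        if pl.dim == 0:
            continue          # expert dim: consumed by the split on dim 0
        kept.append(i)
        new_placements.append(ShardDim(pl.dim - 1))
    return kept, new_placements
```

For example, an `(E, out, in)` weight placed as `[Shard(0), Shard(1)]` (EP x TP) keeps only mesh dim 1, and its per-expert 2D slices carry `Shard(0)` (the old `Shard(1)` shifted down).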
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
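The precomputation introduced above turns each rank's slice tuple into an element count for the all-to-all split sizes. The arithmetic is plain Python `slice.indices`; a standalone sketch of the same counting logic:

```python
def numel_of_slices(slices, shape):
    """Count elements selected by a tuple of slices (or index lists) on `shape`.

    Mirrors the rank_numels computation in the diff: slice objects are
    resolved against the dimension size via slice.indices, explicit index
    lists contribute their length.
    """
    numel = 1
    for idx, dim_size in zip(slices, shape):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(dim_size)
            numel *= max(0, (stop - start + (step - 1)) // step)
        else:
            numel *= len(idx)  # explicit index list
    return numel
```

Precomputing these once per parameter (instead of recomputing inside every gather/scatter, as the deleted `numel_for_rank` did) keeps the hot path of the pipeline free of slice arithmetic.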
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
+        with record_function("muon::momentum"):
+            for n, p in zip(names, params):
+                g = p.grad
+                if g is None:
+                    continue
+                g = update_g(self.state, p, g, group, momentum)
+                p.grad = g
+
+        # Expand expert params by splitting on dim 0.
+        names, params = _expand_expert_params(names, params, self.expert_keys)
+
         param_dtensors = []
         name_dtensors = []
 
@@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer):
                     group=group,
                     lr=lr,
                     weight_decay=weight_decay,
-                    momentum=momentum,
                     qk_logits=qk_logits)
                 return
 
@@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer):
             # and run parallel Muon on each group.
             placement_to_params = defaultdict(lambda: ([], []))
-            # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
 
             assert len(dtensors) == len(names)
             for p, n in zip(dtensors, names):
@@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer):
                     group=group,
                     lr=lr,
                     weight_decay=weight_decay,
-                    momentum=momentum,
                     qk_logits=qk_logits,
                 )
 
@@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer):
                 group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits,
             )
 
@@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer):
             group,
             lr=lr,
             weight_decay=weight_decay,
-            momentum=momentum,
             qk_logits=qk_logits,
         )
 
-    def _step_adamw_params(self, params, group):
-        params_with_grads = []
-        grads = []
-        moment1 = []
-        moment2 = []
-        max_exp_avg_sqs = []
-        state_steps = []
-        lr = group["lr"]
-        beta1, beta2 = group["adamw_betas"]
-        eps = group["adamw_eps"]
-        weight_decay = group["weight_decay"]
-
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-            state = self.state[p]
-            params_with_grads.append(p)
-            grads.append(g)
-            if "step" not in state:
-                state["step"] = (torch.zeros((),
-                                             dtype=torch.float32,
-                                             device=p.device))
-                state["moment1"] = torch.zeros_like(g)
-                state["moment2"] = torch.zeros_like(g)
-            moment1.append(state["moment1"])
-            moment2.append(state["moment2"])
-            if not isinstance(state["step"], torch.Tensor):
-                step_tensor = torch.tensor(state["step"],
-                                           dtype=torch.float32,
-                                           device=p.device)
-            else:
-                step_tensor = state["step"]
-            state_steps.append(step_tensor)
-
-        self._fused_adamw(
-            params_with_grads,
-            grads,
-            moment1,
-            moment2,
-            max_exp_avg_sqs,
-            state_steps,
-            amsgrad=False,
-            beta1=beta1,
-            beta2=beta2,
-            lr=lr,
-            weight_decay=weight_decay,
-            eps=eps,
-            maximize=False,
-        )
-
-    def _step_adamw(self, group):
-        params = group["params"]
-
-        # group params with it's type and placement
-        placement_to_params: dict[tuple[Placement | type,
-                                        DeviceMesh | None]] = defaultdict(list)
-        for p in params:
-            match p:
-                case DTensor():
-                    placement_to_params[tuple([p.placements,
-                                               p.device_mesh])].append(p)
-                case torch.Tensor():
-                    placement_to_params[tuple([torch.Tensor, None])].append(p)
-
-        for params in placement_to_params.values():
-            self._step_adamw_params(params, group)
-
     @torch.no_grad
     def step(self, closure=None, qk_logits=None):
         """Perform a single optimization step.
@@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer):
         Args:
             closure (Callable, optional): A closure that reevaluates the model
                 and returns the loss.
-            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices 
-                to 1D tensors of shape (num_heads,), representing the maximum 
-                QK logits across all tokens, computed as 
+            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
+                to 1D tensors of shape (num_heads,), representing the maximum
+                QK logits across all tokens, computed as
                 (1 / sqrt(head_dim)) * (Q @ K^T).
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/pipeline.py b/build/torch210-cxx11-rocm71-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_gather``).
+        gathered_grads: ``{id(p): empty_tensor}`` for owned params,
+            ``None`` for non-owned.
+        recv_counts: Per-source-rank element counts.
+    """
+    # Allocate gathered-grad buffers
+    gathered_grads: dict[int, torch.Tensor | None] = {}
+    for p in params:
+        state = param_to_state[id(p)]
+        if rank == state.worker_rank:
+            gathered_grads[id(p)] = torch.empty(p.shape,
+                                                dtype=COMM_DTYPE,
+                                                device="cuda")
+        else:
+            gathered_grads[id(p)] = None
+
+    # Build send buffer
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    for p in params:
+        state = param_to_state[id(p)]
+        dst = state.worker_rank
+        assert dst < num_ranks
+        shard_elems = state.rank_numels[rank]
+        g = p.grad
+        g = g.to_local().to(COMM_DTYPE).contiguous()
+        assert g.numel() == shard_elems
+        per_dst[dst].append(g.view(-1))
+        send_counts[dst] += shard_elems
+
+    assert any(
+        len(v) > 0 for v in
+        per_dst), "At least one destination rank must receive a sharded tensor"
+    per_dst_flat = [t for dst in per_dst for t in dst]
+    send_buf = torch.cat(per_dst_flat, dim=0)
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+            total += state.rank_numels[src]
+        recv_counts[src] = total
+
+    recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    logger.debug(f"send_buf size: {send_buf.numel()}, "
+                 f"recv_buf size: {recv_buf.numel()}, "
+                 f"recv_counts: {recv_counts}, "
+                 f"send_counts: {send_counts}, "
+                 f"process_group: {str(process_group)}")
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, gathered_grads, recv_counts
+
+
+def _complete_gather(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+) -> None:
+    """Reconstruct gathered grads from the recv buffer (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        if recv_counts[src] == 0:
+            continue
+
+        block = recv_counts[src]
+        inner_off = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+
+            indices = state.rank_indices[src]
+
+            shard_view = gathered_grads[id(p)][indices]
+            n = shard_view.numel()
+            assert n > 0
+
+            sg = recv_buf.narrow(0, off + inner_off, n)
+            sg = sg.reshape(shard_view.shape)
+            gathered_grads[id(p)][indices] = sg
+
+            inner_off += n
+        assert inner_off == block
+        off += block
+
+
+def _compute_ns(
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    ns_steps: int,
+) -> dict[int, torch.Tensor | None]:
+    """Run Newton-Schulz orthogonalization on owned parameters.
+
+    Returns:
+        computed_us: ``{id(p): orthogonalized_update}`` for owned params.
+    """
+    computed_us: dict[int, torch.Tensor | None] = {}
+    for p in owned_params:
+        u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
+        gathered_grads[id(p)] = None  # free gathered grad
+        computed_us[id(p)] = u
+    return computed_us
+
+
+def _launch_scatter(
+    params: list[DTensor],
+    owned_params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    num_ranks: int,
+    process_group: dist.ProcessGroup,
+    computed_us: dict[int, torch.Tensor | None],
+) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
+    """Allocate scatter buffers, build send/recv, and launch async all-to-all.
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
+        scattered_us: ``{id(p): empty_local_tensor}`` for all params.
+        recv_counts: Per-source-rank element counts.
+    """
+    # Allocate scattered-u buffers
+    scattered_us: dict[int, torch.Tensor] = {}
+    for p in params:
+        scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
+
+    # Build send buffer (from computed_us on owner ranks)
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    if owned_params:
+        for p in owned_params:
+            state = param_to_state[id(p)]
+
+            assert computed_us[id(p)] is not None
+            u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
+
+            total_sent = 0
+            for dst_rank in range(num_ranks):
+                indices = state.rank_indices[dst_rank]
+                su = u_full[indices].flatten()
+
+                n = su.numel()
+                assert n > 0
+
+                per_dst[dst_rank].append(su)
+                send_counts[dst_rank] += n
+                total_sent += n
+
+            assert total_sent == u_full.numel()
+
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensors"
+        per_dst_flat = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst_flat, dim=0)
+    else:
+        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            total += state.rank_numels[rank]
+        recv_counts[src] = total
+
+    recv_total = sum(recv_counts)
+    assert recv_total > 0
+    recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, scattered_us, recv_counts
+
+
+def _complete_scatter(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+) -> None:
+    """Copy recv buffer into scattered_us (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        block = recv_counts[src]
+        if block == 0:
+            continue
+
+        inner_off = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            n = state.rank_numels[rank]
+            assert n > 0
+
+            flat_local = recv_buf.narrow(0, off + inner_off,
+                                         n).view_as(p.to_local())
+            scattered_us[id(p)].copy_(flat_local)
+
+            inner_off += n
+
+        assert inner_off == block
+        off += block
+
+
+def _update_params(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+    lr: float,
+    weight_decay: float,
+) -> None:
+    """Apply weight decay, Muon update, and optional QK clipping."""
+    for p in params:
+        state = param_to_state[id(p)]
+        u_dtensor = DTensor.from_local(
+            scattered_us[id(p)],
+            placements=p.placements,
+            device_mesh=p.device_mesh,
+        )
+
+        adjusted_lr = adjust_lr_for_muon(lr, p.shape)
+        update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
+
+        # QK clipping – applied directly on the local tensor to
+        # avoid DTensor sharding-propagation issues with _StridedShard.
+        scales_full = compute_scales(
+            p,
+            state.qk_clip_state) if state.qk_clip_state is not None else None
+        if scales_full is not None:
+            ratio = p.shape[0] // scales_full.shape[0]
+            idx0 = state.rank_indices[rank][0]
+            if isinstance(idx0, slice):
+                start = idx0.start or 0
+                idx0 = torch.arange(start,
+                                    idx0.stop,
+                                    device=scales_full.device)
+            row_scales = scales_full[idx0 // ratio]
+            p._local_tensor.mul_(row_scales.view(-1, 1))
+
+
+# ======================================================================
+# Main generator – thin orchestrator that wires stages together.
+# ======================================================================
+
+
+@torch.no_grad()
+def muon_chunk_pipeline(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    ns_steps: int,
+    lr: float,
+    weight_decay: float,
+    none_grad: bool,
+) -> Generator[None, None, None]:
+    """Process one chunk of parameters through the full Muon pipeline.
+
+    Stages: gather -> compute (Newton-Schulz) -> scatter -> update.
+
+    Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
+    that communication and computation overlap across chunks. Async
+    communication is launched via ``async_op=True`` and completed after
+    the yield with ``work.wait()``.
+
+    Overlap happens because :func:`run_pipeline` admits one new chunk
+    per iteration (staggered admission). While chunk *N* does NS
+    compute on the default CUDA stream, chunk *N+1*'s async all-to-all
+    runs concurrently on the NCCL stream — no separate ``comm_stream``
+    is required.
+
+    Yields exactly **2** times:
+
+    1. After launching async all-to-all gather.
+    2. After launching async all-to-all scatter.
+    """
+    process_group = param_to_state[id(params[0])].process_group
+    num_ranks = dist.get_world_size(group=process_group)
+    owned_params = [
+        p for p in params if param_to_state[id(p)].worker_rank == rank
+    ]
+
+    # Stages 1-2: launch async gather.
+    with record_function("muon::launch_gather"):
+        work, recv_buf, gathered_grads, recv_counts = _launch_gather(
+            params, owned_params, param_to_state, rank, num_ranks,
+            process_group)
+
+    if none_grad:
+        for p in params:
+            p.grad = None
+
+    yield  # --- YIELD 1: other chunks can launch their gather ---
+
+    with record_function("muon::wait_gather"):
+        work.wait()
+        _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
+                         param_to_state, rank)
+        del recv_buf
+
+    # Stage 3: Newton-Schulz orthogonalization.
+    with record_function("muon::newton_schulz"):
+        computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
+        gathered_grads.clear()
+
+    # Stages 4-5: launch async scatter.
+    with record_function("muon::launch_scatter"):
+        work, recv_buf, scattered_us, recv_counts = _launch_scatter(
+            params, owned_params, param_to_state, rank, num_ranks,
+            process_group, computed_us)
+        computed_us.clear()
+
+    yield  # --- YIELD 2: other chunks can launch their scatter ---
+
+    with record_function("muon::wait_scatter"):
+        work.wait()
+        _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
+                          scattered_us)
+        del recv_buf
+
+    # Stage 6: apply parameter updates.
+    with record_function("muon::update_params"):
+        _update_params(params, param_to_state, rank, scattered_us, lr,
+                       weight_decay)
+        scattered_us.clear()
diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py b/build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7
--- /dev/null
+++ b/build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py
@@ -0,0 +1,129 @@
+import logging
+import math
+from dataclasses import dataclass
+
+import torch
+from torch.distributed.tensor import DTensor
+
+logger = logging.getLogger(__name__)
+
+
+def parse_qk_layer(name: str) -> tuple[str | None, int]:
+    """
+    Parse a parameter name to check if it is a query/key projection layer
+    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+
+    Returns:
+        (kind, layer_idx) or (None, -1) if not matched.
+
+    Example:
+        'model.3.attn.wq.weight' -> ('wq', 3)
+        'model.5.attn.wk.weight' -> ('wk', 5)
+        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
+        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.4.attn.v_proj.weight' -> (None, -1)
+    """
+    parts = name.split('.')
+    if len(parts) < 3:
+        return None, -1
+
+    kind = parts[-2]
+
+    layer_idx = -1
+    for part in reversed(parts):
+        if part.isdigit():
+            layer_idx = int(part)
+            break
+
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+        return kind, layer_idx
+
+    return None, -1
+
+
+@dataclass
+class QKClipInfo:
+    """Per-parameter dynamic info computed from config + runtime logits."""
+    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    indices: list[int]  # which heads to consider for clipping
+    head_dim: int  # from config
+    threshold: float  # from config
+    logit: torch.Tensor | None
+
+
+def get_qk_clip_info(clip_config, n, qk_logits):
+    """Extract QK clipping info for a named parameter.
+
+    Args:
+        clip_config: QK clipping configuration dict (or None).
+        n: Parameter name string.
+        qk_logits: Dict mapping layer indices to logit tensors (or None).
+
+    Returns:
+        QKClipInfo instance with clipping configuration for this parameter.
+    """
+    if clip_config is None:
+        return None
+
+    head_dim = clip_config.get('head_dim')
+    threshold = clip_config.get('threshold')
+    kind, layer_idx = parse_qk_layer(n)
+
+    logit, indices = None, []
+    if qk_logits is not None and kind is not None:
+        logit = qk_logits[layer_idx]
+        indices_key = 'q_indices' if 'q' in kind else 'k_indices'
+        indices = clip_config.get(indices_key, []) or []
+
+    if isinstance(logit, DTensor):
+        # In TP settings, qk_logits may be DTensor
+        # We convert it to full tensor here for simplicity
+        logit = logit.full_tensor()
+
+    return QKClipInfo(
+        kind=kind,
+        indices=indices,
+        head_dim=head_dim,
+        threshold=threshold,
+        logit=logit,
+    )
+
+
+def compute_scales(p, qk_clip_state):
+    """Compute per-head scaling factors for QK clipping.
+
+    Returns scales tensor if any head exceeds threshold, else None.
+    """
+    kind = qk_clip_state.kind
+    indices = qk_clip_state.indices
+    head_dim = qk_clip_state.head_dim
+    threshold = qk_clip_state.threshold
+    logit = qk_clip_state.logit
+
+    H_global = p.shape[0] // head_dim
+    scales_full = torch.ones(H_global, device=p.data.device)
+    scaling = 0
+
+    for logit_idx, head_idx in enumerate(indices):
+        v_ele = float(logit[logit_idx])
+        if v_ele > threshold:
+            new_scale = math.sqrt(threshold / v_ele)
+            if new_scale < scales_full[head_idx]:
+                scales_full[head_idx] = new_scale
+                logger.info(
+                    f"[{kind}] Head {head_idx} exceeded threshold "
+                    f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
+                )
+                scaling += 1
+
+    return scales_full if scaling > 0 else None
+
+
+def qk_clip(p, scales, head_dim):
+    """Apply per-head scaling to a Q/K projection weight matrix."""
+    if isinstance(p, torch.nn.Parameter):
+        W = p.data.view(-1, head_dim, p.data.shape[1])
+        W.mul_(scales.view(-1, 1, 1))
+    else:
+        W = p.view(-1, head_dim, p.shape[1])
+        W.mul_(scales.view(-1, 1, 1))
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644
--- a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_06a260a_dirty
-ops = torch.ops._optimizer_06a260a_dirty
+from . import _optimizer_7aef62f_dirty
+ops = torch.ops._optimizer_7aef62f_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index a218cd77694938fb0914270a5c6416a684d50cb3..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:222315672693e6d4544b1eee4772dc7be744b3794cfd6ff370a6f46d782386a1 -size 1936664 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..1ccf0dbda4220efff722d4b971b23b40592c3a81 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee0ac60d2f40d1feb67e804e6b1024844d8cbbf5c62d6d014621a40dc6b3afc3 +size 1936664 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/adamw.py b/build/torch28-cxx11-cu126-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not params: + 
return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+    """
+    params_with_grads = []
+    grads = []
+    moment1 = []
+    moment2 = []
+    max_exp_avg_sqs = []
+    state_steps = []
+    lr = group["lr"]
+    beta1, beta2 = group["adamw_betas"]
+    eps = group["adamw_eps"]
+    weight_decay = group["weight_decay"]
+
+    for p in params:
+        g = p.grad
+        if g is None:
+            continue
+        state = optimizer_state[p]
+        params_with_grads.append(p)
+        grads.append(g)
+        if "step" not in state:
+            state["step"] = (torch.zeros((),
+                                         dtype=torch.float32,
+                                         device=p.device))
+            state["moment1"] = torch.zeros_like(g)
+            state["moment2"] = torch.zeros_like(g)
+        moment1.append(state["moment1"])
+        moment2.append(state["moment2"])
+        if not isinstance(state["step"], torch.Tensor):
+            step_tensor = torch.tensor(state["step"],
+                                       dtype=torch.float32,
+                                       device=p.device)
+        else:
+            step_tensor = state["step"]
+        state_steps.append(step_tensor)
+
+    fused_adamw(
+        params_with_grads,
+        grads,
+        moment1,
+        moment2,
+        max_exp_avg_sqs,
+        state_steps,
+        amsgrad=False,
+        beta1=beta1,
+        beta2=beta2,
+        lr=lr,
+        weight_decay=weight_decay,
+        eps=eps,
+        maximize=False,
+    )
+
+
+def step_adamw(optimizer_state, group):
+    """Dispatch AdamW step, grouping parameters by type and placement.
+
+    Args:
+        optimizer_state: The optimizer's state dict (self.state in Muon).
+        group: Parameter group dict.
+    """
+    params = group["params"]
+
+    # group params with its type and placement
+    placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list)
+    for p in params:
+        match p:
+            case DTensor():
+                placement_to_params[tuple([p.placements,
+                                           p.device_mesh])].append(p)
+            case torch.Tensor():
+                placement_to_params[tuple([torch.Tensor, None])].append(p)
+
+    for group_params in placement_to_params.values():
+        step_adamw_params(optimizer_state, group_params, group)
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/async_utils.py b/build/torch28-cxx11-cu126-x86_64-linux/async_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/async_utils.py
@@ -0,0 +1,77 @@
+import logging
+from typing import Generator
+
+logger = logging.getLogger(__name__)
+
+
+class _Task:
+    """Internal: wraps a generator, advances one yield at a time."""
+
+    def __init__(self, generator: Generator[None, None, None], index: int):
+        self._generator = generator
+        self._index = index
+        self._steps_completed = 0
+        self.step()  # run to first yield
+
+    def step(self) -> bool:
+        try:
+            next(self._generator)
+            self._steps_completed += 1
+            logger.debug("pipeline[%d] completed stage %d", self._index,
+                         self._steps_completed)
+            return True
+        except StopIteration:
+            logger.debug("pipeline[%d] finished after %d stages", self._index,
+                         self._steps_completed)
+            return False
+
+    def close(self):
+        self._generator.close()
+
+
+def run_pipeline(
+    pipelines: Generator[Generator[None, None, None], None, None],
+    max_concurrent: int,
+) -> None:
+    """Run generator-based pipelines with bounded concurrency.
+
+    Each pipeline is a generator that yields at stage boundaries.
+    The runtime interleaves pipelines so communication and computation
+    overlap across chunks.
+    """
+    if max_concurrent <= 0:
+        raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}")
+
+    have_new = True
+    task_index = 0
+    previous_tasks: list[_Task] = []
+
+    try:
+        while have_new or previous_tasks:
+            running_tasks: list[_Task] = []
+
+            # Admit one new pipeline per iteration (staggered admission).
+            # Admitting one at a time ensures that while chunk N does NS
+            # compute on the default stream, chunk N+1's NCCL all-to-all
+            # runs concurrently on the NCCL stream — creating real
+            # communication/computation overlap on the GPU.
+            if have_new and len(previous_tasks) < max_concurrent:
+                try:
+                    gen = next(pipelines)
+                    task = _Task(gen, task_index)
+                    task_index += 1
+                    running_tasks.append(task)
+                except StopIteration:
+                    have_new = False
+
+            # Advance every previously-yielded task by one step.
+            for task in previous_tasks:
+                if task.step():
+                    running_tasks.append(task)
+
+            previous_tasks = running_tasks
+    except BaseException:
+        # Clean up all in-flight generators to release GPU resources.
+        for task in previous_tasks:
+            task.close()
+        raise
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/core.py b/build/torch28-cxx11-cu126-x86_64-linux/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/core.py
@@ -0,0 +1,116 @@
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from torch.distributed.tensor import DTensor
+
+
+@dataclass
+class _muon_state:
+    worker_rank: int
+    process_group: ProcessGroup
+    rank_indices: dict[int, tuple]  # local_rank -> per-dim indices
+    rank_numels: dict[int, int]  # local_rank -> numel
+    name: str
+    qk_clip_state: torch.Tensor | None = None
+
+
+def update_g(optimizer_state, p, g, group, momentum):
+    """Apply momentum update to gradient.
+
+    Args:
+        optimizer_state: The optimizer's state dict (self.state in Muon).
+        p: Parameter tensor.
+        g: Gradient tensor.
+        group: Parameter group dict.
+        momentum: Momentum coefficient.
+
+    Returns:
+        Momentum-updated gradient tensor.
+    """
+    state = optimizer_state[p]
+    buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
+    torch.add(g, buf, alpha=momentum, out=buf)
+    if group["nesterov"]:
+        g.add_(buf, alpha=momentum)
+        return g
+    return buf
+
+
+def update_p(p, u, lr, adjusted_lr, weight_decay):
+    """Apply weight decay and orthogonalized update to parameter.
+
+    Args:
+        p: Parameter (torch.nn.Parameter or DTensor).
+        u: Orthogonalized update tensor.
+        lr: Base learning rate.
+        adjusted_lr: Size-adjusted learning rate.
+        weight_decay: Weight decay coefficient.
+    """
+    if isinstance(p, torch.nn.Parameter):
+        # apply weight decay
+        p.data.mul_(1 - lr * weight_decay)
+        # apply update
+        p.data.add_(u, alpha=-adjusted_lr)
+    else:
+        p.mul_(1 - lr * weight_decay)
+        p.add_(u, alpha=-adjusted_lr)
+
+
+def adjust_lr_for_muon(lr, param_shape):
+    """Scale learning rate based on parameter matrix dimensions.
+
+    Args:
+        lr: Base learning rate.
+        param_shape: Shape of the parameter tensor.
+
+    Returns:
+        Adjusted learning rate.
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
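The per-dim index composition that `get_slices_of_dtensor` performs — when the same tensor dim is sharded at several mesh levels, each new level selects positions *within* the indices chosen by the previous level — can be sketched with plain Python lists. `contiguous_shard` and `compose` below are hypothetical helpers for illustration, not part of the package:

```python
def contiguous_shard(size, num_chunks, coord):
    """Indices a regular Shard assigns to mesh coordinate `coord`."""
    assert size % num_chunks == 0
    chunk = size // num_chunks
    return list(range(coord * chunk, (coord + 1) * chunk))


def compose(prev_indices, new_indices):
    """Apply an inner sharding level on top of an outer one
    (mirrors `dim_indices[shard_dim][new_indices]`)."""
    return [prev_indices[i] for i in new_indices]


# A dim of size 8, sharded twice over a 2 x 2 mesh, viewed from
# outer coord 1 / inner coord 0:
outer = contiguous_shard(8, 2, coord=1)            # [4, 5, 6, 7]
inner = contiguous_shard(len(outer), 2, coord=0)   # positions [0, 1]
final = compose(outer, inner)                      # [4, 5]

# A strided inner level selects non-contiguous positions, e.g. [0, 2]:
strided = compose(outer, [0, 2])                   # [4, 6]
```

Chaining through the previous level's indices (rather than recomputing global offsets) is what lets contiguous and strided levels mix on the same dim.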
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
-    placements_with_index: list[tuple[int,
-                                      Placement]] = list(enumerate(placements))
-    placements_with_index = sorted(placements_with_index,
-                                   key=placement_sort_key)
+    Example — 8 GPUs, mesh shape (2, 2, 2),
+    placements ``[Shard(0), Replicate, _StridedShard(0)]``::

-    sorted_indices, sorted_placements = zip(*placements_with_index)
+        Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)]
+                 Permutation: [1, 2, 0]

-    # 2. Permute mesh according to sorted placements.
-    sorted_mesh = mesh.permute(sorted_indices)
+        Step 2 — Permute mesh dims by [1, 2, 0]:
+            Original:           Permuted:
+            [[[0,1],[2,3]],     [[[0,4],[1,5]],
+             [[4,5],[6,7]]]      [[2,6],[3,7]]]

-    # 3. Collect list of shard meshes by removing replicate dims
-    # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
-    # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
-    num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
+        Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
+            sub-mesh 0 = [[0,4],[1,5]]   (replica group 0)
+            sub-mesh 1 = [[2,6],[3,7]]   (replica group 1)
+            shard_placements = (_StridedShard(0), Shard(0))

-    # merge replicate dims
-    # shard_meshes became a list of shard meshes with a length of replicate degree
-    if num_replicates > 0:
-        sorted_mesh = sorted_mesh.flatten(
-            0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
+        Step 4 — Rank 0 → ProcessGroup([0,1,4,5])
+                 Rank 2 → ProcessGroup([2,3,6,7])
+
+    Returns:
+        ``(shard_mesh, process_group, shard_placements)``
+    """
+    my_rank = dist.get_rank()
+    assert mesh.mesh.device.type == 'cpu'
+
+    # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
+    # This avoids a non-collective dist.new_group() call, which would
+    # deadlock when only a subset of ranks call this function (e.g. expert
+    # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
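The permute/unbind steps can be reproduced without torch. The sketch below assumes `torch.permute` semantics (result dim k is input dim `perm[k]`) and uses a hypothetical nested-list `permute3` on the 2x2x2 mesh where rank = 4i + 2j + k:

```python
def permute3(mesh, perm):
    """result[a][b][c] = mesh[idx] with idx[perm[0]] = a, idx[perm[1]] = b,
    idx[perm[2]] = c -- the torch.Tensor.permute convention.
    Assumes a cubic mesh for brevity."""
    n = len(mesh)
    out = [[[None] * n for _ in range(n)] for _ in range(n)]
    for a in range(n):
        for b in range(n):
            for c in range(n):
                idx = [0, 0, 0]
                idx[perm[0]], idx[perm[1]], idx[perm[2]] = a, b, c
                out[a][b][c] = mesh[idx[0]][idx[1]][idx[2]]
    return out


mesh = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]   # rank = 4i + 2j + k
permuted = permute3(mesh, (1, 2, 0))          # Replicate dim moved to front
sub_meshes = permuted                         # unbind(dim=0): one per replica
```

Flattening and sorting `sub_meshes[0]` gives [0, 1, 4, 5] — the shard group that rank 0 joins — and `sub_meshes[1]` gives [2, 3, 6, 7].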
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py b/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json +++ b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/muon.py b/build/torch28-cxx11-cu126-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/muon.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline +from 
.core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+            tp_dim_indices = []
+            tp_placements_2d = []
+            for i, pl in enumerate(p.placements):
+                if _is_shard(pl) and pl.dim != 0:
+                    tp_dim_indices.append(i)
+                    tp_placements_2d.append(Shard(pl.dim - 1))
+
+            if tp_dim_indices:
+                tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i]
+                                     for i in tp_dim_indices)
+                if len(tp_dim_names) == 1:
+                    tp_mesh = p.device_mesh[tp_dim_names[0]]
+                else:
+                    tp_mesh = p.device_mesh[tp_dim_names]
+        else:
+            local_data = p.data
+            local_grad = g
+
+        # Expand: split dim 0, reshape each slice to 2D.
+        num_local_experts = local_data.shape[0]
+        for i in range(num_local_experts):
+            slice_data = local_data[i]
+            slice_grad = local_grad[i]
+
+            if tp_mesh is not None:
+                # Wrap as DTensor on TP submesh so the pipeline handles
+                # TP communication (gather/scatter across TP ranks).
+                dt_data = DTensor.from_local(slice_data,
+                                             device_mesh=tp_mesh,
+                                             placements=tp_placements_2d)
+                dt_grad = DTensor.from_local(slice_grad,
+                                             device_mesh=tp_mesh,
+                                             placements=tp_placements_2d)
+                expert_param = torch.nn.Parameter(dt_data, requires_grad=False)
+                expert_param.grad = dt_grad
+            else:
+                expert_param = torch.nn.Parameter(slice_data,
+                                                  requires_grad=False)
+                expert_param.grad = slice_grad
-    return None, -1
+            expanded_names.append(f"{n}[{i}]")
+            expanded_params.append(expert_param)
+        p.grad = None  # allow expert grad storage to be freed after pipeline
-
-
-@dataclass
-class QKClipInfo:
-    """Per-parameter dynamic info computed from config + runtime logits."""
-    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
-    indices: list[int]  # which heads to consider for clipping
-    head_dim: int  # from config
-    threshold: float  # from config
-    logit: torch.Tensor | None
+
+    return expanded_names, expanded_params
 
 
 class Muon(torch.optim.Optimizer):
@@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer):
         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
         ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
         weight_decay: The weight decay for Muon and AdamW.
-            {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
+            Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead.
         adamw_lr: The learning rate for the internal AdamW.
         adamw_betas: The betas for the internal AdamW.
         adamw_eps: The epsilon for the internal AdamW.
@@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer):
             - "q_indices" (list[int]): Indices of query heads to consider.
            - "k_indices" (list[int]): Indices of key heads to consider.
            - "head_dim" (int): Dimensionality of each attention head.
-            - "threshold" (float): Threshold value; heads whose QK logits exceed 
+            - "threshold" (float): Threshold value; heads whose QK logits exceed
              this value will be scaled down.
            Default is:
            {
@@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer):
        use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only.
        small_param_numel_threshold: Threshold for classifying parameters as small
            and falling back to distributed Muon
+        expert_keys: List of strings to identify expert-parallel parameters.
+            If any key appears in a parameter's name, its outermost
+            dimension is treated as the expert dimension and expanded
+            into per-expert 2D params for Muon. For example,
+            ``expert_keys=["experts"]`` matches any param whose name
+            contains "experts". 3D+ params not matched by any key
+            will raise an error.
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer):
         shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
             p.placements, p.device_mesh)
 
-        # set rank with the local rank in the shard process group
-        self.set_rank_once(dist.get_rank(group=shard_pg))
-
         return shard_mesh, shard_pg, shard_placements
 
     def init_state_and_assign_params(self, names, params, group, qk_logits):
@@ -694,8 +267,8 @@
             total_flops += flops
 
         if self.debug:
-            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
-                  flush=True)
+            logger.debug("Total TFLOPs for Muon: %.2f TFLOPs",
+                         total_flops / 1e12)
 
         paired = list(zip(names, params))
 
@@ -724,44 +297,54 @@
             worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
             round_robin = (round_robin + 1) % len(shard_mesh_flattened)
 
-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+            qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
+
+            # Precompute per-rank indices and numels for all-to-all.
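The per-rank element counts precomputed here reduce to a small, self-contained calculation. `slice_numel` below is a hypothetical helper for illustration that mirrors the arithmetic in the diff's loop: slices are normalized with `slice.indices`, and integer index lists contribute their length:

```python
def slice_numel(indices, shape):
    """Number of elements selected by a tuple of slices / index lists."""
    numel = 1
    for idx, dim_size in zip(indices, shape):
        if isinstance(idx, slice):
            # indices() clamps start/stop/step to the dimension size
            start, stop, step = idx.indices(dim_size)
            numel *= max(0, (stop - start + (step - 1)) // step)
        else:
            numel *= len(idx)
    return numel


# a rank owning rows [0:2) of an (8, 16) weight sharded on dim 0:
print(slice_numel((slice(0, 2), slice(None)), (8, 16)))  # 32
```

Precomputing these numels once per parameter lets `_launch_gather`/`_launch_scatter` size the flat all-to-all send/recv buffers without touching tensor data again.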
+            rank_indices: dict[int, tuple] = {}
+            rank_numels: dict[int, int] = {}
+            for r in range(num_ranks):
+                indices = get_slices_of_dtensor(p, r, shard_mesh,
+                                                shard_placements)
+                rank_indices[r] = indices
+                numel = 1
+                for idx, dim_size in zip(indices, p.shape):
+                    if isinstance(idx, slice):
+                        start, stop, step = idx.indices(dim_size)
+                        numel *= max(0, (stop - start + (step - 1)) // step)
+                    else:
+                        numel *= len(idx)
+                rank_numels[r] = numel
 
             param_to_state[id(p)] = _muon_state(
                 worker_rank=worker_rank,
                 process_group=shard_pg,
-                shard_mesh=shard_mesh,
-                shard_placements=shard_placements,
+                rank_indices=rank_indices,
+                rank_numels=rank_numels,
                 name=n,
                 qk_clip_state=qk_clip_state,
             )
 
         return param_to_state, ordered_params
 
-    def base(self, names, params, group, lr, weight_decay, momentum,
-             qk_logits):
-        # generate weight updates in distributed fashion
+    def base(self, names, params, group, lr, weight_decay, qk_logits):
+        # Momentum is already applied by _step_muon before this method.
         for n, p in zip(names, params):
             g = p.grad
             if g is None:
                 continue
-            if g.ndim > 2:
-                g = g.view(g.size(0), -1)
-            assert g is not None
-
-            g = self._update_g(p, g, group, momentum)
 
             u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
                                              steps=group["ns_steps"])
 
-            adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-            Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
+            adjusted_lr = adjust_lr_for_muon(lr, p.shape)
+            update_p(p, u, lr, adjusted_lr, weight_decay)
 
-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+            qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
 
-            scales_full = self._compute_scales(
+            scales_full = compute_scales(
                 p, qk_clip_state) if qk_clip_state is not None else None
             if scales_full is not None:
-                Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
+                qk_clip(p, scales_full, qk_clip_state.head_dim)
 
     def distributed_muon(
         self,
@@ -770,20 +353,15 @@
         group: dict[str, Any],
         lr: float,
         weight_decay: float,
-        momentum: float,
         qk_logits: list[torch.Tensor | DTensor] | None,
    ):
        """
        Implementation of Distributed Muon by Liu et al.
        """
+        # Momentum is already applied by _step_muon before this method.
        for n, p in zip(names, params):
            g = p.grad
            if g is None:
                continue
-            if g.ndim > 2:
-                g = g.view(g.size(0), -1)
-            assert g is not None
-
-            g = self._update_g(p, g, group, momentum)
 
            # Gather G
            if isinstance(p.data, DTensor):
@@ -796,16 +374,16 @@
 
            u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
                                                  steps=group["ns_steps"])
 
-            adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
-            Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
+            adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
+            update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
 
-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+            qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
 
-            scales_full = self._compute_scales(
+            scales_full = compute_scales(
                p_full, qk_clip_state) if qk_clip_state is not None else None
            if scales_full is not None:
-                Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
+                qk_clip(p_full, scales_full, qk_clip_state.head_dim)
 
            if isinstance(p.data, DTensor):
                ndims = len(p.device_mesh.mesh.shape)
@@ -822,244 +400,53 @@
 
            p.copy_(p_sharded)
 
-    def _update_g(self, p, g, group, momentum):
-        # calc update
-        state = self.state[p]
-        buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
-        torch.add(g, buf, alpha=momentum, out=buf)
-        if group["nesterov"]:
-            g.add_(buf, alpha=momentum)
-            return g
-        return buf
-
-    @staticmethod
-    def _update_p(p, u, lr, adjusted_lr, weight_decay):
-        if isinstance(p, torch.nn.Parameter):
-            # apply weight decay
-            p.data.mul_(1 - lr * weight_decay)
-            # apply update
-            p.data.add_(u, alpha=-adjusted_lr)
-        else:
-            p.mul_(1 - lr * weight_decay)
-            p.add_(u, alpha=-adjusted_lr)
-
-    def get_qk_clip_info(self, n, qk_logits):
-        if self.clip_config is None:
-            return None
-
-        head_dim = self.clip_config.get('head_dim')
-        threshold = self.clip_config.get('threshold')
-        kind, layer_idx = parse_qk_layer(n)
-
-        logit, indices = None, []
-        if qk_logits is not None and kind is not None:
-            logit = qk_logits[layer_idx]
-            indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-            indices = self.clip_config.get(indices_key, []) or []
-
-            if isinstance(logit, DTensor):
-                # In TP settings, qk_logits may be DTensor
-                # We convert it to full tensor here for simplicity
-                logit = logit.full_tensor()
-
-        return QKClipInfo(
-            kind=kind,
-            indices=indices,
-            head_dim=head_dim,
-            threshold=threshold,
-            logit=logit,
-        )
-
-    @staticmethod
-    def _compute_scales(p, qk_clip_state):
-        kind = qk_clip_state.kind
-        indices = qk_clip_state.indices
-        head_dim = qk_clip_state.head_dim
-        threshold = qk_clip_state.threshold
-        logit = qk_clip_state.logit
-
-        H_global = p.shape[0] // head_dim
-        scales_full = torch.ones(H_global, device=p.data.device)
-        scaling = 0
-
-        for logit_idx, head_idx in enumerate(indices):
-            v_ele = float(logit[logit_idx])
-            if v_ele > threshold:
-                new_scale = math.sqrt(threshold / v_ele)
-                if new_scale < scales_full[head_idx]:
-                    scales_full[head_idx] = new_scale
-                    logger.info(
-                        f"[{kind}] Head {head_idx} exceeded threshold "
-                        f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
-                    )
-                    scaling += 1
-
-        return scales_full if scaling > 0 else None
-
-    @staticmethod
-    def _qk_clip(p, scales, head_dim):
-        if isinstance(p, torch.nn.Parameter):
-            W = p.data.view(-1, head_dim, p.data.shape[1])
-            W.mul_(scales.view(-1, 1, 1))
-        else:
-            W = p.view(-1, head_dim, p.shape[1])
-            W.mul_(scales.view(-1, 1, 1))
-
-    def parallel(self, names, params, group, lr, weight_decay, momentum,
-                 qk_logits):
+    def parallel(self, names, params, group, lr, weight_decay, qk_logits):
         """
         Perform a parallel optimization step using Muon.
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+        shard_pg = param_to_state[id(ordered_params[0])].process_group
+        rank = dist.get_rank(group=shard_pg)
 
         if self.chunk_size == -1:
             shard_ranks = dist.get_world_size(param_to_state[id(
-                params[0])].process_group)
+                ordered_params[0])].process_group)
             chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
         elif self.chunk_size > 0:
             chunk_size = self.chunk_size
         else:
             raise ValueError("chunk_size must be -1 or a positive integer.")
 
-        # Wait grad update
-        self.comm_stream.wait_stream(torch.cuda.current_stream())
-
-        warmup_step = self.warmup_step
-        for i in range(0, warmup_step):
-            enqueue_all2all_gather(i * chunk_size, chunk_size)
-            enqueue_computes(i * chunk_size, chunk_size)
-
-        for i in range(0, len(params) + chunk_size - 1, chunk_size):
-            enqueue_all2all_scatter(i, chunk_size)
-            enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
-            enqueue_update_param(i, chunk_size)
-            enqueue_computes(i + warmup_step * chunk_size, chunk_size)
-
-        # Wait the last update_param to finish
-        torch.cuda.current_stream().wait_stream(self.compute_stream)
-
-    @staticmethod
-    def _fused_adamw(
-        params: list[torch.Tensor],
-        grads: list[torch.Tensor],
-        exp_avgs: list[torch.Tensor],
-        exp_avg_sqs: list[torch.Tensor],
-        max_exp_avg_sqs: list[torch.Tensor],
-        state_steps: list[torch.Tensor],
-        amsgrad: bool,
-        beta1: float,
-        beta2: float,
-        lr: float | torch.Tensor,
-        weight_decay: float,
-        eps: float,
-        maximize: bool,
-    ) -> None:
-        if not params:
-            return
+        def pipelines():
+            for start in range(0, len(ordered_params), chunk_size):
+                chunk = ordered_params[start:start + chunk_size]
+                if chunk:
+                    yield muon_chunk_pipeline(
+                        params=chunk,
+                        param_to_state=param_to_state,
+                        rank=rank,
+                        ns_steps=group["ns_steps"],
+                        lr=lr,
+                        weight_decay=weight_decay,
+                        none_grad=group["none_grad"],
+                    )
 
-        # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
-        # treating it as a scalar.
-        lr_dict: DeviceDict | None = ({
-            lr.device: lr
-        } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
-                                      None)
-        grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
-            [
-                params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
-                state_steps
-            ]  # type: ignore[list-item]
-        )
-        for (device, _), (
-            (
-                device_params_,
-                device_grads_,
-                device_exp_avgs_,
-                device_exp_avg_sqs_,
-                device_max_exp_avg_sqs,
-                device_state_steps_,
-            ),
-            _,
-        ) in grouped_tensors.items():
-            device_params = cast(list[torch.Tensor], device_params_)
-            device_grads = cast(list[torch.Tensor], device_grads_)
-            device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
-            device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
-            device_state_steps = cast(list[torch.Tensor], device_state_steps_)
-
-            if lr_dict is not None and device not in lr_dict:
-                lr_dict[device] = lr.to(
-                    device=device,
-                    non_blocking=True)  # type: ignore[union-attr]
-                lr = lr_dict[device]
-            torch._foreach_add_(device_state_steps, 1)
-            func = torch._fused_adamw_
-            func(
-                device_params,
-                device_grads,
-                device_exp_avgs,
-                device_exp_avg_sqs,
-                device_max_exp_avg_sqs,  # type: ignore[arg-type]
-                device_state_steps,
-                amsgrad=amsgrad,
-                lr=lr,  # type: ignore[arg-type]
-                beta1=beta1,
-                beta2=beta2,
-                weight_decay=weight_decay,
-                eps=eps,
-                maximize=maximize,
-            )
+        with record_function("muon::barrier"):
+            dist.barrier()
+        with record_function("muon::pipeline"):
+            run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
 
     def _step_muon(self, group, qk_logits=None):
         params = group["params"]
@@ -1068,6 +455,18 @@
         momentum = group["momentum"]
         names = group["names"]
 
+        # Apply momentum to all params before routing/expansion.
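The momentum pre-pass applied here (via `update_g` in `core.py`) follows the standard Muon SGD-momentum rule. Below is a scalar sketch of that rule, reduced by assumption from the tensor ops in the diff (`buf` stands in for the per-param `momentum_buffer` state):

```python
def update_g_sketch(g, buf, momentum, nesterov):
    """buf <- g + momentum * buf; Nesterov update is g + momentum * buf."""
    buf = g + momentum * buf            # torch.add(g, buf, alpha=momentum, out=buf)
    if nesterov:
        return g + momentum * buf, buf  # g.add_(buf, alpha=momentum); return g
    return buf, buf                     # plain momentum: return buf


update, buf = update_g_sketch(1.0, 0.0, momentum=0.9, nesterov=True)
# first step: buf == 1.0, update == 1.0 + 0.9 * 1.0 == 1.9
```

Applying this once up front (rather than inside `base`/`distributed_muon`/`parallel`) keeps the momentum state keyed by the original parameters, before expert params are expanded into per-slice views.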
+        with record_function("muon::momentum"):
+            for n, p in zip(names, params):
+                g = p.grad
+                if g is None:
+                    continue
+                g = update_g(self.state, p, g, group, momentum)
+                p.grad = g
+
+        # Expand expert params by splitting on dim 0.
+        names, params = _expand_expert_params(names, params, self.expert_keys)
+
         param_dtensors = []
         name_dtensors = []
@@ -1083,7 +482,6 @@
                 group=group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits)
             return
 
@@ -1119,7 +517,6 @@
         # and run parallel Muon on each group.
 
         placement_to_params = defaultdict(lambda: ([], []))
-        # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
 
         assert len(dtensors) == len(names)
         for p, n in zip(dtensors, names):
@@ -1141,7 +538,6 @@
                 group=group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits,
             )
 
@@ -1159,7 +555,6 @@
                 group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits,
             )
 
@@ -1170,78 +565,9 @@
                 group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits,
             )
 
-    def _step_adamw_params(self, params, group):
-        params_with_grads = []
-        grads = []
-        moment1 = []
-        moment2 = []
-        max_exp_avg_sqs = []
-        state_steps = []
-        lr = group["lr"]
-        beta1, beta2 = group["adamw_betas"]
-        eps = group["adamw_eps"]
-        weight_decay = group["weight_decay"]
-
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-            state = self.state[p]
-            params_with_grads.append(p)
-            grads.append(g)
-            if "step" not in state:
-                state["step"] = (torch.zeros((),
-                                             dtype=torch.float32,
-                                             device=p.device))
-                state["moment1"] = torch.zeros_like(g)
-                state["moment2"] = torch.zeros_like(g)
-            moment1.append(state["moment1"])
-            moment2.append(state["moment2"])
-            if not isinstance(state["step"], torch.Tensor):
-                step_tensor = torch.tensor(state["step"],
-                                           dtype=torch.float32,
-                                           device=p.device)
-            else:
-                step_tensor = state["step"]
-            state_steps.append(step_tensor)
-
-        self._fused_adamw(
-            params_with_grads,
-            grads,
-            moment1,
-            moment2,
-            max_exp_avg_sqs,
-            state_steps,
-            amsgrad=False,
-            beta1=beta1,
-            beta2=beta2,
-            lr=lr,
-            weight_decay=weight_decay,
-            eps=eps,
-            maximize=False,
-        )
-
-    def _step_adamw(self, group):
-        params = group["params"]
-
-        # group params with it's type and placement
-        placement_to_params: dict[tuple[Placement | type,
-                                        DeviceMesh | None]] = defaultdict(list)
-        for p in params:
-            match p:
-                case DTensor():
-                    placement_to_params[tuple([p.placements,
-                                               p.device_mesh])].append(p)
-                case torch.Tensor():
-                    placement_to_params[tuple([torch.Tensor, None])].append(p)
-
-        for params in placement_to_params.values():
-            self._step_adamw_params(params, group)
-
     @torch.no_grad
     def step(self, closure=None, qk_logits=None):
         """Perform a single optimization step.
 
@@ -1249,9 +575,9 @@
         Args:
             closure (Callable, optional): A closure that reevaluates the model
                 and returns the loss.
-            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices 
-                to 1D tensors of shape (num_heads,), representing the maximum 
-                QK logits across all tokens, computed as 
+            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
+                to 1D tensors of shape (num_heads,), representing the maximum
+                QK logits across all tokens, computed as
                 (1 / sqrt(head_dim)) * (Q @ K^T).
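Given those per-layer max logits, the QK-clip rule reduces to a per-head scale factor. `compute_head_scales` below is a simplified scalar sketch of `qk_clip.compute_scales` (head index bookkeeping, DTensor handling, and logging omitted):

```python
import math


def compute_head_scales(max_logits, threshold):
    """Heads whose max QK logit exceeds `threshold` get a weight scale of
    sqrt(threshold / logit). Since the logit is bilinear in the q and k
    projections, scaling both by this factor pulls the logit back to
    roughly `threshold`; other heads are left untouched (scale 1.0)."""
    return [
        math.sqrt(threshold / v) if v > threshold else 1.0
        for v in max_logits
    ]


print(compute_head_scales([50.0, 400.0], threshold=100.0))  # [1.0, 0.5]
```

In the real optimizer these scales are computed from the `qk_logits` passed to `step()` and multiplied into the per-head rows of the `wq`/`wk` weights after the Muon update.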
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch28-cxx11-cu126-x86_64-linux/pipeline.py b/build/torch28-cxx11-cu126-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_gather``).
+        gathered_grads: ``{id(p): empty_tensor}`` for owned params,
+            ``None`` for non-owned.
+        recv_counts: Per-source-rank element counts.
+    """
+    # Allocate gathered-grad buffers
+    gathered_grads: dict[int, torch.Tensor | None] = {}
+    for p in params:
+        state = param_to_state[id(p)]
+        if rank == state.worker_rank:
+            gathered_grads[id(p)] = torch.empty(p.shape,
+                                                dtype=COMM_DTYPE,
+                                                device="cuda")
+        else:
+            gathered_grads[id(p)] = None
+
+    # Build send buffer
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    for p in params:
+        state = param_to_state[id(p)]
+        dst = state.worker_rank
+        assert dst < num_ranks
+        shard_elems = state.rank_numels[rank]
+        g = p.grad
+        g = g.to_local().to(COMM_DTYPE).contiguous()
+        assert g.numel() == shard_elems
+        per_dst[dst].append(g.view(-1))
+        send_counts[dst] += shard_elems
+
+    assert any(
+        len(v) > 0 for v in
+        per_dst), "At least one destination rank must receive a sharded tensor"
+    per_dst_flat = [t for dst in per_dst for t in dst]
+    send_buf = torch.cat(per_dst_flat, dim=0)
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+            total += state.rank_numels[src]
+        recv_counts[src] = total
+
+    recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    logger.debug(f"send_buf size: {send_buf.numel()}, "
+                 f"recv_buf size: {recv_buf.numel()}, "
+                 f"recv_counts: {recv_counts}, "
+                 f"send_counts: {send_counts}, "
+                 f"process_group: {str(process_group)}")
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, gathered_grads, recv_counts
+
+
+def _complete_gather(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+) -> None:
+    """Reconstruct gathered grads from the recv buffer (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        if recv_counts[src] == 0:
+            continue
+
+        block = recv_counts[src]
+        inner_off = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+
+            indices = state.rank_indices[src]
+
+            shard_view = gathered_grads[id(p)][indices]
+            n = shard_view.numel()
+            assert n > 0
+
+            sg = recv_buf.narrow(0, off + inner_off, n)
+            sg = sg.reshape(shard_view.shape)
+            gathered_grads[id(p)][indices] = sg
+
+            inner_off += n
+        assert inner_off == block
+        off += block
+
+
+def _compute_ns(
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    ns_steps: int,
+) -> dict[int, torch.Tensor | None]:
+    """Run Newton-Schulz orthogonalization on owned parameters.
+
+    Returns:
+        computed_us: ``{id(p): orthogonalized_update}`` for owned params.
+    """
+    computed_us: dict[int, torch.Tensor | None] = {}
+    for p in owned_params:
+        u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
+        gathered_grads[id(p)] = None  # free gathered grad
+        computed_us[id(p)] = u
+    return computed_us
+
+
+def _launch_scatter(
+    params: list[DTensor],
+    owned_params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    num_ranks: int,
+    process_group: dist.ProcessGroup,
+    computed_us: dict[int, torch.Tensor | None],
+) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
+    """Allocate scatter buffers, build send/recv, and launch async all-to-all.
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
+        scattered_us: ``{id(p): empty_local_tensor}`` for all params.
+        recv_counts: Per-source-rank element counts.
+ """ + # Allocate scattered-u buffers + scattered_us: dict[int, torch.Tensor] = {} + for p in params: + scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + + # Build send buffer (from computed_us on owner ranks) + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + + assert computed_us[id(p)] is not None + u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + total_sent = 0 + for dst_rank in range(num_ranks): + indices = state.rank_indices[dst_rank] + su = u_full[indices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst_rank].append(su) + send_counts[dst_rank] += n + total_sent += n + + assert total_sent == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += state.rank_numels[rank] + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, scattered_us, recv_counts + + +def _complete_scatter( + recv_buf: torch.Tensor, + recv_counts: list[int], + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], +) -> 
None: + """Copy recv buffer into scattered_us (in-place).""" + off = 0 + for src in range(len(recv_counts)): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = state.rank_numels[rank] + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + scattered_us[id(p)].copy_(flat_local) + + inner_off += n + + assert inner_off == block + off += block + + +def _update_params( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], + lr: float, + weight_decay: float, +) -> None: + """Apply weight decay, Muon update, and optional QK clipping.""" + for p in params: + state = param_to_state[id(p)] + u_dtensor = DTensor.from_local( + scattered_us[id(p)], + placements=p.placements, + device_mesh=p.device_mesh, + ) + + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + scales_full = compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = state.rank_indices[rank][0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + +# ====================================================================== +# Main generator – thin orchestrator that wires stages together. 
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py b/build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
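As a reference point, the name-parsing rule implemented by `parse_qk_layer` above can be exercised outside the diff with a self-contained copy (same logic, no torch dependency):

```python
# Self-contained copy of parse_qk_layer for illustration: the second-to-last
# dotted component must be a Q/K projection name, and the layer index is the
# last numeric component anywhere in the name.
def parse_qk_layer(name):
    parts = name.split('.')
    if len(parts) < 3:
        return None, -1
    kind = parts[-2]
    layer_idx = -1
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break
    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
        return kind, layer_idx
    return None, -1

print(parse_qk_layer('model.3.attn.wq.weight'))      # ('wq', 3)
print(parse_qk_layer('model.2.attn.q_proj.weight'))  # ('q_proj', 2)
print(parse_qk_layer('model.4.attn.v_proj.weight'))  # (None, -1)
```

Note that non-Q/K layers return `(None, -1)` even when a layer index is present in the name.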
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
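Returning to `compute_scales` above: the scale `sqrt(threshold / v)` is chosen so that, assuming the attention logit is bilinear in the Q and K projections (each of which receives the same factor), the clipped logit lands exactly on the threshold. A toy numeric check of that assumption:

```python
import math

# Toy check of the clipping rule in compute_scales: a head whose max QK
# logit v exceeds threshold t gets scale sqrt(t / v); if both the Q and K
# rows are scaled by it, the logit (bilinear in Q and K) returns to t.
threshold = 100.0
v = 400.0
scale = math.sqrt(threshold / v)      # 0.5
clipped_logit = v * scale * scale     # both projections carry the factor
print(scale, clipped_logit)           # 0.5 100.0
```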
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 1cf60567b59ce1b343c5a44301e443953b674f78..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:119adc22cd57de6d6d78c1f5094310b57083050f40836a5455bdb6c35bed104b -size 1999872 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..2a7b540994e8d72dfccead970e2fe685f976d2ae --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89137de30694bc0ad3165d1a998c801151370290ed1217f343409b11a8f74908 +size 1999872 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/adamw.py b/build/torch28-cxx11-cu128-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not params: + 
return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/async_utils.py b/build/torch28-cxx11-cu128-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
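The interleaving that `run_pipeline` provides can be reproduced with plain generators (a toy sketch: an event log stands in for real CUDA/NCCL work, and `chunk` mimics the two-yield shape of `muon_chunk_pipeline`):

```python
# Toy re-creation of the staggered-admission loop in run_pipeline: each
# chunk is a generator yielding at its two stage boundaries; the scheduler
# admits at most one new chunk per iteration while stepping every
# previously admitted one.
def chunk(i, log):
    log.append(f"gather{i}")   # stage 1: launch async all-to-all gather
    yield
    log.append(f"compute{i}")  # stage 2: Newton-Schulz + launch scatter
    yield
    log.append(f"update{i}")   # stage 3: apply parameter update

def run_pipeline(pipelines, max_concurrent):
    have_new, previous = True, []
    while have_new or previous:
        running = []
        # Admit one new pipeline per iteration (staggered admission).
        if have_new and len(previous) < max_concurrent:
            try:
                gen = next(pipelines)
            except StopIteration:
                have_new = False
            else:
                next(gen)            # run to first yield
                running.append(gen)
        # Advance every previously yielded pipeline by one step.
        for gen in previous:
            try:
                next(gen)
                running.append(gen)
            except StopIteration:
                pass
        previous = running

log = []
run_pipeline((chunk(i, log) for i in range(3)), max_concurrent=2)
print(log)
```

In the resulting log, `gather1` appears before `compute0`: chunk 1's communication is in flight while chunk 0 computes, which is the communication/computation overlap the real pipeline gets from the NCCL stream.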
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch28-cxx11-cu128-x86_64-linux/core.py b/build/torch28-cxx11-cu128-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). 
+ p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
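Per element, the momentum recursion in `update_g` above reduces to `buf = g + momentum * buf`, and with Nesterov the returned gradient is `g + momentum * buf` evaluated after the buffer update (otherwise `buf` itself is returned). A scalar trace:

```python
# Scalar trace of update_g: two steps with a constant gradient of 1.0.
momentum = 0.9
buf = 0.0
for _ in range(2):
    g = 1.0
    buf = g + momentum * buf          # momentum buffer update
    nesterov_g = g + momentum * buf   # value returned when nesterov=True
print(buf)         # 1.9
print(nesterov_g)  # ~2.71
```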
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
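The lr adjustment in `adjust_lr_for_muon` above is a pure function of the matrix shape; a standalone numeric check (self-contained copy of the formula):

```python
import math

# lr is scaled by 0.2 * sqrt(max(A, B)) for an (A, B) weight matrix,
# as in adjust_lr_for_muon above.
def adjust_lr_for_muon(lr, param_shape):
    A, B = param_shape[:2]
    adjusted_ratio = 0.2 * math.sqrt(max(A, B))
    return lr * adjusted_ratio

print(adjust_lr_for_muon(0.02, (4096, 1024)))  # 0.02 * 0.2 * 64 = 0.256
```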
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
- placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) + Example — 8 GPUs, mesh shape (2, 2, 2), + placements ``[Shard(0), Replicate, _StridedShard(0)]``:: - sorted_indices, sorted_placements = zip(*placements_with_index) + Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)] + Permutation: [1, 2, 0] - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) + Step 2 — Permute mesh dims by [1, 2, 0]: + Original: Permuted: + [[[0,1],[2,3]], [[[0,2],[1,3]], + [[4,5],[6,7]]] [[4,6],[5,7]]] - # 3. Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes: + sub-mesh 0 = [[0,2],[1,3]] (replica group 0) + sub-mesh 1 = [[4,6],[5,7]] (replica group 1) + shard_placements = (_StridedShard(0), Shard(0)) - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + Step 4 — Rank 0 → ProcessGroup([0,1,4,5]) + Rank 2 → ProcessGroup([2,3,6,7]) + + Returns: + ``(shard_mesh, process_group, shard_placements)`` + """ + my_rank = dist.get_rank() + assert mesh.mesh.device.type == 'cpu' + + # -- Fast path: 1D all-shard mesh → reuse existing PG. ---------------- + # This avoids a non-collective dist.new_group() call, which would + # deadlock when only a subset of ranks call this function (e.g. expert + # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately). 
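The sort order described in Step 1 of this docstring (Replicate first, then shards by dim ascending, with strided shards ahead of regular shards on the same dim) can be sketched with minimal stand-in placement classes. These classes are illustrative only, not the real torch placement types:

```python
# Stand-in placement classes; the real sort key operates on
# torch.distributed.tensor placement types.
class Replicate:
    pass

class Shard:
    def __init__(self, dim):
        self.dim = dim

class StridedShard(Shard):
    def __init__(self, dim, split_factor):
        super().__init__(dim)
        self.split_factor = split_factor

def sort_key(item):
    index, p = item
    if isinstance(p, Replicate):
        return (-1, 0, index)
    # Strided shards sort before regular shards on the same dim.
    split = -1 / p.split_factor if isinstance(p, StridedShard) else 0
    return (p.dim, split, index)

placements = [Shard(0), Replicate(), StridedShard(0, split_factor=2)]
order = [i for i, _ in sorted(enumerate(placements), key=sort_key)]
print(order)  # [1, 2, 0]
```

This reproduces the permutation `[1, 2, 0]` from the docstring's 8-GPU example: Replicate first, then the strided shard, then the regular `Shard(0)`.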
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py b/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json +++ b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/muon.py b/build/torch28-cxx11-cu128-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/muon.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline +from 
.core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
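The `construct_shard_mesh` hunk earlier relies on every rank calling `dist.new_group` in the same order, even for groups the rank does not join, since group creation is collective. A minimal pure-Python sketch of that pattern (the `new_group` stub and mesh lists are hypothetical stand-ins for `dist.new_group` and the real shard meshes):

```python
# Stand-in for dist.new_group: a collective call that every rank must
# issue in the same order (the real call would deadlock otherwise).
created_order = []

def new_group(ranks):
    created_order.append(tuple(ranks))
    return ("pg", tuple(ranks))

def build_groups(shard_meshes, my_rank, cache):
    """Every rank loops over *all* shard meshes so the new_group calls
    line up across ranks, but each rank keeps only its own group."""
    my_key = None
    for ranks in shard_meshes:
        key = tuple(ranks)
        if my_rank in ranks:
            assert my_key is None, "rank appears in multiple shard groups"
            my_key = key
        if key not in cache:
            cache[key] = new_group(ranks)
    return cache[my_key]

# Ranks 0 and 3 replay the same mesh list in the same creation order.
meshes = [[0, 1], [2, 3]]
pg0 = build_groups(meshes, my_rank=0, cache={})
pg3 = build_groups(meshes, my_rank=3, cache={})
```

Each rank creates both groups but returns only the one containing itself, mirroring the `_ranks_to_dist_cache` logic.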
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
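The Newton–Schulz docstring removed above notes that the quintic iteration drives singular values into roughly (0.5, 1.5) rather than exactly to 1. Because the update X ← aX + b(XXᵀ)X + c(XXᵀ)²X acts on each singular value s independently as s ← a·s + b·s³ + c·s⁵, its behavior can be sketched on scalars (coefficients copied from the removed hunk; the helper name is illustrative):

```python
# The five tuned coefficient triples from _zeropower_via_newtonschulz5.
COEFFS = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]

def ns_singular_value(s):
    """Trace one singular value (pre-scaled into (0, 1] by the spectral
    norm division) through the five quintic iterations; it lands near 1
    rather than exactly on 1, matching the US'V^T description."""
    for a, b, c in COEFFS:
        s = a * s + b * s**3 + c * s**5
    return s

small = ns_singular_value(0.05)  # tiny singular value is pulled up
large = ns_singular_value(0.9)   # near-unit singular value stays close
```

Both endpoints converge into the documented band, which is why the iteration approximates orthogonalization without computing an SVD.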
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
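The reconstruction loop above walks `recv_buf` in two levels: one block per source rank, then one contiguous chunk per owned parameter inside each block. A pure-Python sketch of that offset arithmetic (function name and toy sizes are hypothetical; the real buffers are flat CUDA tensors):

```python
def unpack_recv_buf(recv_buf, recv_counts, param_numels):
    """Walk the flat all-to-all receive buffer: outer loop over source
    ranks (blocks of recv_counts[src] elements), inner loop over the
    params owned by this rank (param_numels[name][src] elements each).
    Returns {param_name: {src: chunk}} for illustration."""
    out = {name: {} for name in param_numels}
    off = 0
    for src, block in enumerate(recv_counts):
        inner = 0
        for name, numels in param_numels.items():
            n = numels[src]
            out[name][src] = recv_buf[off + inner:off + inner + n]
            inner += n
        assert inner == block  # every block is fully consumed
        off += block
    return out

# Toy layout: 2 source ranks; p1 sends 2 elements per rank, p2 sends 1.
# recv_buf = | p1_0 p2_0 | p1_1 p2_1 |
buf = [10, 11, 20, 12, 13, 21]
chunks = unpack_recv_buf(buf, recv_counts=[3, 3],
                         param_numels={"p1": [2, 2], "p2": [1, 1]})
```

The real code copies each chunk into the slice of `gathered_grad` returned by `get_slices_of_dtensor` instead of building a dict.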
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
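The removed `_compute_scales` (now in `qk_clip.py`) scales an over-threshold head by `sqrt(threshold / logit)`; since the QK logit is bilinear in the query and key weights, scaling both projections this way brings the logit back to roughly the threshold. A scalar sketch (the function name is illustrative, not the module's API):

```python
import math

def compute_scales_sketch(logits, indices, num_heads, threshold):
    """Per-head clip scales: heads whose max QK logit exceeds the
    threshold get sqrt(threshold / logit); returns None when no head
    needs clipping, mirroring the early-out in the real code."""
    scales = [1.0] * num_heads
    clipped = 0
    for logit_idx, head_idx in enumerate(indices):
        v = logits[logit_idx]
        if v > threshold:
            scales[head_idx] = min(scales[head_idx],
                                   math.sqrt(threshold / v))
            clipped += 1
    return scales if clipped else None

# Head 1's logit (400) is 4x the threshold, so its weights are scaled
# by sqrt(1/4) = 0.5; head 0 (logit 50) is left alone.
scales = compute_scales_sketch([50.0, 400.0], indices=[0, 1],
                               num_heads=2, threshold=100.0)
```

Applying 0.5 to both `wq` and `wk` multiplies that head's logit by 0.25, pulling 400 back down to the threshold of 100.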
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
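`run_pipeline` above keeps several chunk generators in flight so one chunk's communication overlaps another chunk's compute. A toy scheduler sketch (names mirror the real API, but this simple round-robin policy is an assumption about, not a copy of, the actual `async_utils` implementation):

```python
def run_pipeline(pipelines, max_concurrent):
    """Keep up to max_concurrent chunk pipelines active, advancing each
    one yield at a time in round-robin order and recording the order of
    pipeline stages for illustration."""
    active, trace = [], []
    pipelines = iter(pipelines)
    while True:
        while len(active) < max_concurrent:  # refill the active window
            nxt = next(pipelines, None)
            if nxt is None:
                break
            active.append(nxt)
        if not active:
            break
        gen = active.pop(0)
        try:
            trace.append(next(gen))
            active.append(gen)  # re-queue after its yield point
        except StopIteration:
            pass  # pipeline finished its final wait + update_param
    return trace

def chunk_pipeline(name):
    yield f"{name}:gather"   # async all2all gather launched
    yield f"{name}:scatter"  # async all2all scatter launched
    # final segment (wait + update_param) runs before StopIteration

order = run_pipeline((chunk_pipeline(c) for c in "AB"), max_concurrent=2)
```

With two chunks in flight, chunk B's gather is launched while chunk A sits at its first yield, which is the overlap the old warmup + index-scheduling loop achieved by hand.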
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch28-cxx11-cu128-x86_64-linux/pipeline.py b/build/torch28-cxx11-cu128-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+ + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def 
_complete_gather( + recv_buf: torch.Tensor, + recv_counts: list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
+ """ + # Allocate scattered-u buffers + scattered_us: dict[int, torch.Tensor] = {} + for p in params: + scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + + # Build send buffer (from computed_us on owner ranks) + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + + assert computed_us[id(p)] is not None + u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + total_sent = 0 + for dst_rank in range(num_ranks): + indices = state.rank_indices[dst_rank] + su = u_full[indices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst_rank].append(su) + send_counts[dst_rank] += n + total_sent += n + + assert total_sent == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += state.rank_numels[rank] + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, scattered_us, recv_counts + + +def _complete_scatter( + recv_buf: torch.Tensor, + recv_counts: list[int], + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], +) -> 
None: + """Copy recv buffer into scattered_us (in-place).""" + off = 0 + for src in range(len(recv_counts)): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = state.rank_numels[rank] + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + scattered_us[id(p)].copy_(flat_local) + + inner_off += n + + assert inner_off == block + off += block + + +def _update_params( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], + lr: float, + weight_decay: float, +) -> None: + """Apply weight decay, Muon update, and optional QK clipping.""" + for p in params: + state = param_to_state[id(p)] + u_dtensor = DTensor.from_local( + scattered_us[id(p)], + placements=p.placements, + device_mesh=p.device_mesh, + ) + + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + scales_full = compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = state.rank_indices[rank][0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + +# ====================================================================== +# Main generator – thin orchestrator that wires stages together. 
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py b/build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index e996c45edb033c93ec8a41716764cdcbbd04593d..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7e8463be5f48aba32d645183945d258cdb532b238ef40665db396b459367cad1 -size 1999872 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..9e281900c03ffb5f3513aa19cc4f0f48e8a90cae --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f5d04e35a6d7a64d44ba42590c3ef930535c6100498d9c4bc28deb50c819a8d +size 1999872 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/adamw.py b/build/torch28-cxx11-cu129-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not params: + 
return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/async_utils.py b/build/torch28-cxx11-cu129-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch28-cxx11-cu129-x86_64-linux/core.py b/build/torch28-cxx11-cu129-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). 
+ p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
+ """ + A, B = param_shape[:2] + # We adjust the learning rate based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + + def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + + def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) + +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
- placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) + Example — 8 GPUs, mesh shape (2, 2, 2), + placements ``[Shard(0), Replicate, _StridedShard(0)]``:: - sorted_indices, sorted_placements = zip(*placements_with_index) + Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)] + Permutation: [1, 2, 0] - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) + Step 2 — Permute mesh dims by [1, 2, 0]: + Original: Permuted: + [[[0,1],[2,3]], [[[0,4],[1,5]], + [[4,5],[6,7]]] [[2,6],[3,7]]] - # 3. Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes: + sub-mesh 0 = [[0,4],[1,5]] (replica group 0) + sub-mesh 1 = [[2,6],[3,7]] (replica group 1) + shard_placements = (_StridedShard(0), Shard(0)) - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + Step 4 — Rank 0 → ProcessGroup([0,1,4,5]) + Rank 2 → ProcessGroup([2,3,6,7]) + + Returns: + ``(shard_mesh, process_group, shard_placements)`` + """ + my_rank = dist.get_rank() + assert mesh.mesh.device.type == 'cpu' + + # -- Fast path: 1D all-shard mesh → reuse existing PG. ---------------- + # This avoids a non-collective dist.new_group() call, which would + # deadlock when only a subset of ranks call this function (e.g. expert + # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately). 
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py b/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json +++ b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/muon.py b/build/torch28-cxx11-cu129-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/muon.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline +from 
.core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
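The `expert_keys` expansion documented above can be sketched without DTensor machinery: a 3-D expert weight is split on its outermost (expert) dimension into per-expert 2-D entries named `name[i]`. A minimal pure-Python sketch of that splitting/naming contract — `expand_experts` and the list-of-lists stand-ins for tensors are illustrative, not the package's actual API:

```python
def expand_experts(names, params, expert_keys):
    """Split params whose name matches an expert key on dim 0."""
    expanded_names, expanded_params = [], []
    for n, p in zip(names, params):
        is_expert = bool(expert_keys) and any(k in n for k in expert_keys)
        if not is_expert:
            expanded_names.append(n)
            expanded_params.append(p)
            continue
        # p is [num_experts, out, in]; emit one 2-D slice per expert.
        for i, expert_slice in enumerate(p):
            expanded_names.append(f"{n}[{i}]")
            expanded_params.append(expert_slice)
    return expanded_names, expanded_params

# Two "experts" of shape (2, 3), plus one ordinary 2-D weight.
w_experts = [[[0.0] * 3 for _ in range(2)] for _ in range(2)]
w_dense = [[0.0] * 3 for _ in range(2)]
names, params = expand_experts(
    ["model.moe.experts.w1", "model.attn.wq.weight"],
    [w_experts, w_dense],
    expert_keys=["experts"],
)
```

The real `_expand_expert_params` additionally re-wraps TP-sharded slices as DTensors and moves each slice's gradient along with it, but the naming scheme is the same.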
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
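The per-rank numel precomputation uses `slice.indices()`, which resolves `None` start/stop/step against the full dimension size. A standalone sketch of that counting logic (the helper name and the 4-rank row-sharded example are illustrative):

```python
def numel_of_indices(indices, shape):
    """Count elements selected by a tuple of slices / index lists."""
    numel = 1
    for idx, dim_size in zip(indices, shape):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(dim_size)
            numel *= max(0, (stop - start + (step - 1)) // step)
        else:
            numel *= len(idx)
    return numel

# A (4096, 1024) weight row-sharded across 4 ranks: each rank owns 1024 rows.
shape = (4096, 1024)
rank_slices = [(slice(r * 1024, (r + 1) * 1024), slice(None)) for r in range(4)]
counts = [numel_of_indices(s, shape) for s in rank_slices]
```

These counts are exactly what later becomes the `input_split_sizes` / `output_split_sizes` of the flat `all_to_all_single` calls.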
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
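The momentum pass that `_step_muon` now applies up front follows the usual Muon recipe: accumulate into a buffer, then optionally take a Nesterov lookahead. A scalar sketch, assuming `update_g` matches the removed `_update_g` above:

```python
def update_g_sketch(g, buf, momentum, nesterov):
    """Scalar version: buf <- g + momentum*buf, then Nesterov lookahead or buf."""
    buf = g + momentum * buf            # torch.add(g, buf, alpha=momentum, out=buf)
    if nesterov:
        return g + momentum * buf, buf  # g.add_(buf, alpha=momentum)
    return buf, buf

# Constant gradient of 1.0 with momentum 0.9:
u1, buf = update_g_sketch(1.0, 0.0, 0.9, nesterov=True)  # u1 = 1.9
u2, buf = update_g_sketch(1.0, buf, 0.9, nesterov=True)  # u2 ~= 2.71
```

Hoisting this out of `base`/`distributed_muon`/`parallel` means the momentum buffer is keyed by the original (possibly 3-D expert) parameter, before any per-expert expansion.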
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch28-cxx11-cu129-x86_64-linux/pipeline.py b/build/torch28-cxx11-cu129-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+ + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def 
_complete_gather( + recv_buf: torch.Tensor, + recv_counts: list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
+ """ + # Allocate scattered-u buffers + scattered_us: dict[int, torch.Tensor] = {} + for p in params: + scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + + # Build send buffer (from computed_us on owner ranks) + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + + assert computed_us[id(p)] is not None + u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + total_sent = 0 + for dst_rank in range(num_ranks): + indices = state.rank_indices[dst_rank] + su = u_full[indices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst_rank].append(su) + send_counts[dst_rank] += n + total_sent += n + + assert total_sent == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += state.rank_numels[rank] + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, scattered_us, recv_counts + + +def _complete_scatter( + recv_buf: torch.Tensor, + recv_counts: list[int], + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], +) -> 
None: + """Copy recv buffer into scattered_us (in-place).""" + off = 0 + for src in range(len(recv_counts)): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = state.rank_numels[rank] + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + scattered_us[id(p)].copy_(flat_local) + + inner_off += n + + assert inner_off == block + off += block + + +def _update_params( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], + lr: float, + weight_decay: float, +) -> None: + """Apply weight decay, Muon update, and optional QK clipping.""" + for p in params: + state = param_to_state[id(p)] + u_dtensor = DTensor.from_local( + scattered_us[id(p)], + placements=p.placements, + device_mesh=p.device_mesh, + ) + + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + scales_full = compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = state.rank_indices[rank][0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + +# ====================================================================== +# Main generator – thin orchestrator that wires stages together. 
+# ======================================================================
+
+
+@torch.no_grad()
+def muon_chunk_pipeline(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    ns_steps: int,
+    lr: float,
+    weight_decay: float,
+    none_grad: bool,
+) -> Generator[None, None, None]:
+    """Process one chunk of parameters through the full Muon pipeline.
+
+    Stages: gather -> compute (Newton-Schulz) -> scatter -> update.
+
+    Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
+    that communication and computation overlap across chunks. Async
+    communication is launched via ``async_op=True`` and completed after
+    the yield with ``work.wait()``.
+
+    Overlap happens because :func:`run_pipeline` admits one new chunk
+    per iteration (staggered admission). While chunk *N* does NS
+    compute on the default CUDA stream, chunk *N+1*'s async all-to-all
+    runs concurrently on the NCCL stream — no separate ``comm_stream``
+    is required.
+
+    Yields exactly **2** times:
+
+    1. After launching async all-to-all gather.
+    2. After launching async all-to-all scatter.
+    """
+    process_group = param_to_state[id(params[0])].process_group
+    num_ranks = dist.get_world_size(group=process_group)
+    owned_params = [
+        p for p in params if param_to_state[id(p)].worker_rank == rank
+    ]
+
+    # Stages 1-2: launch async gather.
+    with record_function("muon::launch_gather"):
+        work, recv_buf, gathered_grads, recv_counts = _launch_gather(
+            params, owned_params, param_to_state, rank, num_ranks,
+            process_group)
+
+    if none_grad:
+        for p in params:
+            p.grad = None
+
+    yield  # --- YIELD 1: other chunks can launch their gather ---
+
+    with record_function("muon::wait_gather"):
+        work.wait()
+        _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
+                         param_to_state, rank)
+        del recv_buf
+
+    # Stage 3: Newton-Schulz orthogonalization.
+    with record_function("muon::newton_schulz"):
+        computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
+        gathered_grads.clear()
+
+    # Stages 4-5: launch async scatter.
+    with record_function("muon::launch_scatter"):
+        work, recv_buf, scattered_us, recv_counts = _launch_scatter(
+            params, owned_params, param_to_state, rank, num_ranks,
+            process_group, computed_us)
+        computed_us.clear()
+
+    yield  # --- YIELD 2: other chunks can launch their scatter ---
+
+    with record_function("muon::wait_scatter"):
+        work.wait()
+        _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
+                          scattered_us)
+        del recv_buf
+
+    # Stage 6: apply parameter updates.
+    with record_function("muon::update_params"):
+        _update_params(params, param_to_state, rank, scattered_us, lr,
+                       weight_decay)
+        scattered_us.clear()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py b/build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py
@@ -0,0 +1,129 @@
+import logging
+import math
+from dataclasses import dataclass
+
+import torch
+from torch.distributed.tensor import DTensor
+
+logger = logging.getLogger(__name__)
+
+
+def parse_qk_layer(name: str) -> tuple[str | None, int]:
+    """
+    Parse a parameter name to check if it is a query/key projection layer
+    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+
+    Returns:
+        (kind, layer_idx) or (None, -1) if not matched.
+
+    Example:
+        'model.3.attn.wq.weight' -> ('wq', 3)
+        'model.5.attn.wk.weight' -> ('wk', 5)
+        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
+        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.4.attn.v_proj.weight' -> (None, -1)
+    """
+    parts = name.split('.')
+    if len(parts) < 3:
+        return None, -1
+
+    kind = parts[-2]
+
+    layer_idx = -1
+    for part in reversed(parts):
+        if part.isdigit():
+            layer_idx = int(part)
+            break
+
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+        return kind, layer_idx
+
+    return None, -1
+
+
+@dataclass
+class QKClipInfo:
+    """Per-parameter dynamic info computed from config + runtime logits."""
+    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    indices: list[int]  # which heads to consider for clipping
+    head_dim: int  # from config
+    threshold: float  # from config
+    logit: torch.Tensor | None
+
+
+def get_qk_clip_info(clip_config, n, qk_logits):
+    """Extract QK clipping info for a named parameter.
+
+    Args:
+        clip_config: QK clipping configuration dict (or None).
+        n: Parameter name string.
+        qk_logits: Dict mapping layer indices to logit tensors (or None).
+
+    Returns:
+        QKClipInfo instance with clipping configuration for this parameter.
+    """
+    if clip_config is None:
+        return None
+
+    head_dim = clip_config.get('head_dim')
+    threshold = clip_config.get('threshold')
+    kind, layer_idx = parse_qk_layer(n)
+
+    logit, indices = None, []
+    if qk_logits is not None and kind is not None:
+        logit = qk_logits[layer_idx]
+        indices_key = 'q_indices' if 'q' in kind else 'k_indices'
+        indices = clip_config.get(indices_key, []) or []
+
+    if isinstance(logit, DTensor):
+        # In TP settings, qk_logits may be a DTensor.
+        # We convert it to a full tensor here for simplicity.
+        logit = logit.full_tensor()
+
+    return QKClipInfo(
+        kind=kind,
+        indices=indices,
+        head_dim=head_dim,
+        threshold=threshold,
+        logit=logit,
+    )
+
+
+def compute_scales(p, qk_clip_state):
+    """Compute per-head scaling factors for QK clipping.
+
+    Returns scales tensor if any head exceeds threshold, else None.
+    """
+    kind = qk_clip_state.kind
+    indices = qk_clip_state.indices
+    head_dim = qk_clip_state.head_dim
+    threshold = qk_clip_state.threshold
+    logit = qk_clip_state.logit
+
+    H_global = p.shape[0] // head_dim
+    scales_full = torch.ones(H_global, device=p.data.device)
+    scaling = 0
+
+    for logit_idx, head_idx in enumerate(indices):
+        v_ele = float(logit[logit_idx])
+        if v_ele > threshold:
+            new_scale = math.sqrt(threshold / v_ele)
+            if new_scale < scales_full[head_idx]:
+                scales_full[head_idx] = new_scale
+                logger.info(
+                    f"[{kind}] Head {head_idx} exceeded threshold "
+                    f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
+                )
+                scaling += 1
+
+    return scales_full if scaling > 0 else None
+
+
+def qk_clip(p, scales, head_dim):
+    """Apply per-head scaling to a Q/K projection weight matrix."""
+    if isinstance(p, torch.nn.Parameter):
+        W = p.data.view(-1, head_dim, p.data.shape[1])
+        W.mul_(scales.view(-1, 1, 1))
+    else:
+        W = p.view(-1, head_dim, p.shape[1])
+        W.mul_(scales.view(-1, 1, 1))
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py
index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644
--- a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py
+++ b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_06a260a_dirty
-ops = torch.ops._optimizer_06a260a_dirty
+from . import _optimizer_7aef62f_dirty
+ops = torch.ops._optimizer_7aef62f_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 19ee075424c40e1714e4ef6561d68c368e933792..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:90ac494e1381bedf95832a91c108ff18d900442203f9b0612efa5519956def2e -size 1865080 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..885ac14b4c5469770fdeaf3766d4c28aa25ada8a --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3fcf69ab6e1e6d7732b6b887350af98666ada6909773898d6b2c8efa56c4cd0 +size 1865080 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/adamw.py b/build/torch28-cxx11-rocm63-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not 
params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/async_utils.py b/build/torch28-cxx11-rocm63-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/core.py b/build/torch28-cxx11-rocm63-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. 
+ + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
- placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) + Example — 8 GPUs, mesh shape (2, 2, 2), + placements ``[Shard(0), Replicate, _StridedShard(0)]``:: - sorted_indices, sorted_placements = zip(*placements_with_index) + Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)] + Permutation: [1, 2, 0] - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) + Step 2 — Permute mesh dims by [1, 2, 0]: + Original: Permuted: + [[[0,1],[2,3]], [[[0,2],[1,3]], + [[4,5],[6,7]]] [[4,6],[5,7]]] - # 3. Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes: + sub-mesh 0 = [[0,2],[1,3]] (replica group 0) + sub-mesh 1 = [[4,6],[5,7]] (replica group 1) + shard_placements = (_StridedShard(0), Shard(0)) - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + Step 4 — Rank 0 → ProcessGroup([0,1,4,5]) + Rank 2 → ProcessGroup([2,3,6,7]) + + Returns: + ``(shard_mesh, process_group, shard_placements)`` + """ + my_rank = dist.get_rank() + assert mesh.mesh.device.type == 'cpu' + + # -- Fast path: 1D all-shard mesh → reuse existing PG. ---------------- + # This avoids a non-collective dist.new_group() call, which would + # deadlock when only a subset of ranks call this function (e.g. expert + # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately). 
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py b/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json b/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json +++ b/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/muon.py b/build/torch28-cxx11-rocm63-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/muon.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline 
+from .core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-rocm63-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/pipeline.py b/build/torch28-cxx11-rocm63-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+ + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def 
_complete_gather( + recv_buf: torch.Tensor, + recv_counts: list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
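The flat all-to-all split-size bookkeeping shared by the gather and scatter stages can be illustrated with a dependency-free sketch. `build_counts` and its arguments are hypothetical names for illustration, not part of the package: each rank sends its local shard of every parameter to that parameter's `worker_rank`, and a rank that owns a parameter receives one shard from every source rank.

```python
def build_counts(param_shard_sizes, worker_rank, my_rank, num_ranks):
    # param_shard_sizes: {param_name: {rank: shard_numel}}
    # worker_rank:       {param_name: owner rank for that parameter}
    send = [0] * num_ranks
    recv = [0] * num_ranks
    for p, shards in param_shard_sizes.items():
        dst = worker_rank[p]
        send[dst] += shards[my_rank]       # my local shard goes to the owner
        if dst == my_rank:                 # I gather this param: one shard
            for src, n in shards.items():  # arrives from every source rank
                recv[src] += n
    return send, recv

# Two ranks, two params: rank 0 owns "a", rank 1 owns "b";
# each param is split 3 + 5 elements across ranks 0 and 1.
sizes = {"a": {0: 3, 1: 5}, "b": {0: 3, 1: 5}}
owner = {"a": 0, "b": 1}
counts = build_counts(sizes, owner, my_rank=0, num_ranks=2)
print(counts)  # send 3 to self (a) and 3 to rank 1 (b); receive a's shards
```

These per-rank totals are exactly what `dist.all_to_all_single` consumes as `input_split_sizes` / `output_split_sizes`.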
+ """ + # Allocate scattered-u buffers + scattered_us: dict[int, torch.Tensor] = {} + for p in params: + scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + + # Build send buffer (from computed_us on owner ranks) + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + + assert computed_us[id(p)] is not None + u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + total_sent = 0 + for dst_rank in range(num_ranks): + indices = state.rank_indices[dst_rank] + su = u_full[indices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst_rank].append(su) + send_counts[dst_rank] += n + total_sent += n + + assert total_sent == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += state.rank_numels[rank] + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, scattered_us, recv_counts + + +def _complete_scatter( + recv_buf: torch.Tensor, + recv_counts: list[int], + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], +) -> 
None: + """Copy recv buffer into scattered_us (in-place).""" + off = 0 + for src in range(len(recv_counts)): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = state.rank_numels[rank] + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + scattered_us[id(p)].copy_(flat_local) + + inner_off += n + + assert inner_off == block + off += block + + +def _update_params( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], + lr: float, + weight_decay: float, +) -> None: + """Apply weight decay, Muon update, and optional QK clipping.""" + for p in params: + state = param_to_state[id(p)] + u_dtensor = DTensor.from_local( + scattered_us[id(p)], + placements=p.placements, + device_mesh=p.device_mesh, + ) + + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + scales_full = compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = state.rank_indices[rank][0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + +# ====================================================================== +# Main generator – thin orchestrator that wires stages together. 
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/qk_clip.py b/build/torch28-cxx11-rocm63-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
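The per-head scale arithmetic can be sketched without tensors. `qk_head_scales` is a hypothetical helper for illustration: a head whose max QK logit exceeds the threshold gets factor sqrt(threshold / logit), so applying it to both the Q and K rows shrinks the bilinear Q @ K^T logit by threshold / logit.

```python
import math

def qk_head_scales(logits, threshold):
    # One clip factor per head; heads at or below the threshold keep 1.0.
    return [math.sqrt(threshold / v) if v > threshold else 1.0
            for v in logits]

scales = qk_head_scales([50.0, 120.0, 80.0], threshold=100.0)
print(scales)  # only head 1 exceeds the threshold
```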
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index d23cf944ec31a3606755cdac0f39bae6455816d5..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5ddeadf7e678e0ff7e84b9e4f869ef45ed6840b06e9093e20210769fd15b8cad -size 1865168 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..6ec327ad391829e41a0a5dc05568e90ac77781b0 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc94ac631623c7169f42b8c21066b4cf03ef892078269fe0c4318634b9c08912 +size 1865168 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/adamw.py b/build/torch28-cxx11-rocm64-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not 
params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/async_utils.py b/build/torch28-cxx11-rocm64-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
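The staggered-admission scheduling can be demonstrated with a dependency-free sketch. This simplified `run_pipeline` mirrors the runtime's loop (admit at most one new pipeline per iteration, then advance every previously admitted one by a stage) but drops the `_Task` wrapper and error cleanup; the two-stage `chunk` generators stand in for `muon_chunk_pipeline`.

```python
def run_pipeline(pipelines, max_concurrent=2):
    # Simplified scheduler: one new pipeline per outer iteration,
    # then one step for each pipeline that yielded last iteration.
    have_new, in_flight = True, []
    while have_new or in_flight:
        nxt = []
        if have_new and len(in_flight) < max_concurrent:
            try:
                g = next(pipelines)
            except StopIteration:
                have_new = False
            else:
                next(g)          # run to first yield (async comm launched)
                nxt.append(g)
        for g in in_flight:
            try:
                next(g)
                nxt.append(g)
            except StopIteration:
                pass             # pipeline finished
        in_flight = nxt

log = []

def chunk(i):
    log.append(f"gather{i}"); yield
    log.append(f"compute{i}"); yield
    log.append(f"update{i}")

run_pipeline((chunk(i) for i in range(3)), max_concurrent=2)
print(log)
```

The log shows chunk 1's gather launched before chunk 0's compute runs — the interleaving that lets NCCL communication overlap NS compute on the GPU.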
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/core.py b/build/torch28-cxx11-rocm64-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. 
+ + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
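Numerically, the update rule that `adjust_lr_for_muon` and `update_p` implement reduces to a short sketch on plain tensors (no DTensor; `apply_muon_update` is an illustrative name, not the package API):

```python
import math
import torch

def apply_muon_update(p, u, lr=0.02, weight_decay=0.01):
    # Decoupled weight decay at the base lr, then the orthogonalized
    # update u at the size-adjusted lr = lr * 0.2 * sqrt(max(A, B)).
    adjusted_lr = lr * 0.2 * math.sqrt(max(p.shape[:2]))
    p.mul_(1 - lr * weight_decay)
    p.add_(u, alpha=-adjusted_lr)
    return p

p = torch.ones(4, 9)
u = torch.ones(4, 9)
apply_muon_update(p, u)
# adjusted_lr = 0.02 * 0.2 * sqrt(9) = 0.012, so each entry becomes
# 1 * (1 - 0.02 * 0.01) - 0.012 = 0.9878
```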
+ """ + A, B = param_shape[:2] + # We adjust the learning rate based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + + def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + + def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies.
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
- placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) + Example — 8 GPUs, mesh shape (2, 2, 2), + placements ``[Shard(0), Replicate, _StridedShard(0)]``:: - sorted_indices, sorted_placements = zip(*placements_with_index) + Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)] + Permutation: [1, 2, 0] - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) + Step 2 — Permute mesh dims by [1, 2, 0]: + Original: Permuted: + [[[0,1],[2,3]], [[[0,2],[1,3]], + [[4,5],[6,7]]] [[4,6],[5,7]]] - # 3. Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes: + sub-mesh 0 = [[0,2],[1,3]] (replica group 0) + sub-mesh 1 = [[4,6],[5,7]] (replica group 1) + shard_placements = (_StridedShard(0), Shard(0)) - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + Step 4 — Rank 0 → ProcessGroup([0,1,4,5]) + Rank 2 → ProcessGroup([2,3,6,7]) + + Returns: + ``(shard_mesh, process_group, shard_placements)`` + """ + my_rank = dist.get_rank() + assert mesh.mesh.device.type == 'cpu' + + # -- Fast path: 1D all-shard mesh → reuse existing PG. ---------------- + # This avoids a non-collective dist.new_group() call, which would + # deadlock when only a subset of ranks call this function (e.g. expert + # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately). 
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py b/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json b/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json +++ b/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/muon.py b/build/torch28-cxx11-rocm64-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/muon.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline 
+from .core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-rocm64-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/pipeline.py b/build/torch28-cxx11-rocm64-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+ + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def 
_complete_gather( + recv_buf: torch.Tensor, + recv_counts: list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
+    """
+    # Allocate scattered-u buffers
+    scattered_us: dict[int, torch.Tensor] = {}
+    for p in params:
+        scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
+
+    # Build send buffer (from computed_us on owner ranks)
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    if owned_params:
+        for p in owned_params:
+            state = param_to_state[id(p)]
+
+            assert computed_us[id(p)] is not None
+            u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
+
+            total_sent = 0
+            for dst_rank in range(num_ranks):
+                indices = state.rank_indices[dst_rank]
+                su = u_full[indices].flatten()
+
+                n = su.numel()
+                assert n > 0
+
+                per_dst[dst_rank].append(su)
+                send_counts[dst_rank] += n
+                total_sent += n
+
+            assert total_sent == u_full.numel()
+
+        lengths = [len(v) for v in per_dst]
+        if all(l > 0 for l in lengths):
+            assert all(
+                l == lengths[0] for l in lengths
+            ), "All destination ranks must have the same number of sharded tensors"
+        per_dst_flat = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst_flat, dim=0)
+    else:
+        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            total += state.rank_numels[rank]
+        recv_counts[src] = total
+
+    recv_total = sum(recv_counts)
+    assert recv_total > 0
+    recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, scattered_us, recv_counts
+
+
+def _complete_scatter(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+) -> 
None: + """Copy recv buffer into scattered_us (in-place).""" + off = 0 + for src in range(len(recv_counts)): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = state.rank_numels[rank] + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + scattered_us[id(p)].copy_(flat_local) + + inner_off += n + + assert inner_off == block + off += block + + +def _update_params( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], + lr: float, + weight_decay: float, +) -> None: + """Apply weight decay, Muon update, and optional QK clipping.""" + for p in params: + state = param_to_state[id(p)] + u_dtensor = DTensor.from_local( + scattered_us[id(p)], + placements=p.placements, + device_mesh=p.device_mesh, + ) + + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + scales_full = compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = state.rank_indices[rank][0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + +# ====================================================================== +# Main generator – thin orchestrator that wires stages together. 
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/qk_clip.py b/build/torch28-cxx11-rocm64-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index ca73c2a576e1ad27e2c5a403c459246792b9a9d1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:07135e56b4c66b79fcb062c0bd39e61dae7e4251f164638cd09f8e360075f215 -size 1936664 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..44ca420ee062544acac81ece75a66953807a4502 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb807f26eac961830776950d2bad9ef96838705fcdf5be8c5ee6dc9c18e0c3a4 +size 1936664 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/adamw.py b/build/torch29-cxx11-cu126-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not params: + 
return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/async_utils.py b/build/torch29-cxx11-cu126-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch29-cxx11-cu126-x86_64-linux/core.py b/build/torch29-cxx11-cu126-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). 
+ p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
-    placements_with_index: list[tuple[int,
-                                      Placement]] = list(enumerate(placements))
+    Example — 8 GPUs, mesh shape (2, 2, 2),
-    placements_with_index = sorted(placements_with_index,
-                                   key=placement_sort_key)
+    placements ``[Shard(0), Replicate, _StridedShard(0)]``::
-    sorted_indices, sorted_placements = zip(*placements_with_index)
+        Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)]
+        Permutation: [1, 2, 0]
-    # 2. Permute mesh according to sorted placements.
-    sorted_mesh = mesh.permute(sorted_indices)
+        Step 2 — Permute mesh dims by [1, 2, 0]:
+            Original:          Permuted:
+            [[[0,1],[2,3]],    [[[0,4],[1,5]],
+             [[4,5],[6,7]]]     [[2,6],[3,7]]]
-    # 3. Collect list of shard meshes by removing replicate dims
-    # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
-    # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
-    num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
+        Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
+            sub-mesh 0 = [[0,4],[1,5]]   (replica group 0)
+            sub-mesh 1 = [[2,6],[3,7]]   (replica group 1)
+            shard_placements = (_StridedShard(0), Shard(0))
-    # merge replicate dims
-    # shard_meshes became a list of shard meshes with a length of replicate degree
-    if num_replicates > 0:
-        sorted_mesh = sorted_mesh.flatten(
-            0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
+        Step 4 — Rank 0 → ProcessGroup([0,1,4,5])
+                 Rank 2 → ProcessGroup([2,3,6,7])
+
+    Returns:
+        ``(shard_mesh, process_group, shard_placements)``
+    """
+    my_rank = dist.get_rank()
+    assert mesh.mesh.device.type == 'cpu'
+
+    # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
+    # This avoids a non-collective dist.new_group() call, which would
+    # deadlock when only a subset of ranks call this function (e.g. expert
+    # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py b/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json +++ b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/muon.py b/build/torch29-cxx11-cu126-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/muon.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline +from 
.core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch29-cxx11-cu126-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-cu126-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch29-cxx11-cu126-x86_64-linux/pipeline.py b/build/torch29-cxx11-cu126-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+ + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def 
_complete_gather( + recv_buf: torch.Tensor, + recv_counts: list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
+ """ + # Allocate scattered-u buffers + scattered_us: dict[int, torch.Tensor] = {} + for p in params: + scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + + # Build send buffer (from computed_us on owner ranks) + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + + assert computed_us[id(p)] is not None + u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + total_sent = 0 + for dst_rank in range(num_ranks): + indices = state.rank_indices[dst_rank] + su = u_full[indices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst_rank].append(su) + send_counts[dst_rank] += n + total_sent += n + + assert total_sent == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += state.rank_numels[rank] + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, scattered_us, recv_counts + + +def _complete_scatter( + recv_buf: torch.Tensor, + recv_counts: list[int], + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], +) -> 
None: + """Copy recv buffer into scattered_us (in-place).""" + off = 0 + for src in range(len(recv_counts)): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = state.rank_numels[rank] + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + scattered_us[id(p)].copy_(flat_local) + + inner_off += n + + assert inner_off == block + off += block + + +def _update_params( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], + lr: float, + weight_decay: float, +) -> None: + """Apply weight decay, Muon update, and optional QK clipping.""" + for p in params: + state = param_to_state[id(p)] + u_dtensor = DTensor.from_local( + scattered_us[id(p)], + placements=p.placements, + device_mesh=p.device_mesh, + ) + + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + scales_full = compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = state.rank_indices[rank][0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + +# ====================================================================== +# Main generator – thin orchestrator that wires stages together. 
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch29-cxx11-cu126-x86_64-linux/qk_clip.py b/build/torch29-cxx11-cu126-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
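The docstring examples for `parse_qk_layer` above can be checked directly. This standalone mirror of the helper (pure string parsing, no torch dependency) reproduces the routing logic:

```python
def parse_qk_layer(name: str):
    """Standalone mirror of qk_clip.py's parse_qk_layer above."""
    parts = name.split('.')
    if len(parts) < 3:
        return None, -1
    kind = parts[-2]
    layer_idx = -1
    # The rightmost numeric path component is taken as the layer index.
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break
    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
        return kind, layer_idx
    return None, -1
```

Note that a non-Q/K parameter such as `v_proj` returns `(None, -1)` even when a layer index is present in the name.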
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
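The per-head rule inside `compute_scales` reduces to a scalar formula: a head whose max QK logit `v` exceeds the threshold `t` gets scale `sqrt(t / v)`, all other heads get 1.0. A minimal pure-Python sketch (a hypothetical `qk_clip_scales` helper using lists instead of tensors):

```python
import math


def qk_clip_scales(head_logits, threshold):
    """Per-head scale: sqrt(threshold / v) if v > threshold, else 1.0."""
    return [
        math.sqrt(threshold / v) if v > threshold else 1.0
        for v in head_logits
    ]
```

Because the QK logit is bilinear in the Q and K projections, applying `sqrt(t / v)` to both weight matrices rescales that head's logit by exactly `t / v`, pinning it back to the threshold.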
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index b4ccc5bd24c68e412968b43af9a352dd5ac27863..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f048516a9820c335263f335df545e404e22ee146355b49669c95a54852448542 -size 1999872 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..81304f1e72844f803f036498f2b7bad16a5d60c1 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5729d5d70fb41aa7eb7ae7fa095c6f6765a0119ac70d0c3139fc31357f4abe78 +size 1999872 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/adamw.py b/build/torch29-cxx11-cu128-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not params: + 
return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/async_utils.py b/build/torch29-cxx11-cu128-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch29-cxx11-cu128-x86_64-linux/core.py b/build/torch29-cxx11-cu128-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). 
+ p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
-    placements_with_index: list[tuple[int,
-                                      Placement]] = list(enumerate(placements))
-    placements_with_index = sorted(placements_with_index,
-                                   key=placement_sort_key)
+    Example — 8 GPUs, mesh shape (2, 2, 2),
+    placements ``[Shard(0), Replicate, _StridedShard(0)]``::

-    sorted_indices, sorted_placements = zip(*placements_with_index)
+        Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)]
+                 Permutation: [1, 2, 0]

-    # 2. Permute mesh according to sorted placements.
-    sorted_mesh = mesh.permute(sorted_indices)
+        Step 2 — Permute mesh dims by [1, 2, 0]:
+            Original:          Permuted:
+            [[[0,1],[2,3]],    [[[0,4],[1,5]],
+             [[4,5],[6,7]]]     [[2,6],[3,7]]]

-    # 3. Collect list of shard meshes by removing replicate dims
-    # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
-    # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
-    num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
+        Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
+            sub-mesh 0 = [[0,4],[1,5]]  (replica group 0)
+            sub-mesh 1 = [[2,6],[3,7]]  (replica group 1)
+            shard_placements = (_StridedShard(0), Shard(0))

-    # merge replicate dims
-    # shard_meshes became a list of shard meshes with a length of replicate degree
-    if num_replicates > 0:
-        sorted_mesh = sorted_mesh.flatten(
-            0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
+        Step 4 — Rank 0 → ProcessGroup([0, 1, 4, 5])
+                 Rank 2 → ProcessGroup([2, 3, 6, 7])
+
+    Returns:
+        ``(shard_mesh, process_group, shard_placements)``
+    """
+    my_rank = dist.get_rank()
+    assert mesh.mesh.device.type == 'cpu'
+
+    # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
+    # This avoids a non-collective dist.new_group() call, which would
+    # deadlock when only a subset of ranks call this function (e.g. expert
+    # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py b/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json +++ b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/muon.py b/build/torch29-cxx11-cu128-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/muon.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline +from 
.core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch29-cxx11-cu128-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-cu128-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch29-cxx11-cu128-x86_64-linux/pipeline.py b/build/torch29-cxx11-cu128-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+ + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def 
_complete_gather( + recv_buf: torch.Tensor, + recv_counts: list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
+ """ + # Allocate scattered-u buffers + scattered_us: dict[int, torch.Tensor] = {} + for p in params: + scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + + # Build send buffer (from computed_us on owner ranks) + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + + assert computed_us[id(p)] is not None + u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + total_sent = 0 + for dst_rank in range(num_ranks): + indices = state.rank_indices[dst_rank] + su = u_full[indices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst_rank].append(su) + send_counts[dst_rank] += n + total_sent += n + + assert total_sent == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += state.rank_numels[rank] + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, scattered_us, recv_counts + + +def _complete_scatter( + recv_buf: torch.Tensor, + recv_counts: list[int], + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], +) -> 
None:
+    """Copy recv buffer into scattered_us (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        block = recv_counts[src]
+        if block == 0:
+            continue
+
+        inner_off = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            n = state.rank_numels[rank]
+            assert n > 0
+
+            flat_local = recv_buf.narrow(0, off + inner_off,
+                                         n).view_as(p.to_local())
+            scattered_us[id(p)].copy_(flat_local)
+
+            inner_off += n
+
+        assert inner_off == block
+        off += block
+
+
+def _update_params(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+    lr: float,
+    weight_decay: float,
+) -> None:
+    """Apply weight decay, Muon update, and optional QK clipping."""
+    for p in params:
+        state = param_to_state[id(p)]
+        u_dtensor = DTensor.from_local(
+            scattered_us[id(p)],
+            placements=p.placements,
+            device_mesh=p.device_mesh,
+        )
+
+        adjusted_lr = adjust_lr_for_muon(lr, p.shape)
+        update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
+
+        # QK clipping – applied directly on the local tensor to
+        # avoid DTensor sharding-propagation issues with _StridedShard.
+        scales_full = compute_scales(
+            p,
+            state.qk_clip_state) if state.qk_clip_state is not None else None
+        if scales_full is not None:
+            ratio = p.shape[0] // scales_full.shape[0]
+            idx0 = state.rank_indices[rank][0]
+            if isinstance(idx0, slice):
+                start = idx0.start or 0
+                stop = p.shape[0] if idx0.stop is None else idx0.stop
+                idx0 = torch.arange(start, stop,
+                                    device=scales_full.device)
+            row_scales = scales_full[idx0 // ratio]
+            p._local_tensor.mul_(row_scales.view(-1, 1))
+
+
+# ======================================================================
+# Main generator – thin orchestrator that wires stages together.
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch29-cxx11-cu128-x86_64-linux/qk_clip.py b/build/torch29-cxx11-cu128-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
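The parser's contract can be exercised without torch. Below is a dependency-free copy of `parse_qk_layer` for illustration, checked against the docstring examples above:

```python
def parse_qk_layer(name: str):
    # Same parsing rule as qk_clip.parse_qk_layer: the component before
    # the trailing "weight" names the projection kind, and the last
    # numeric component (scanning right-to-left) gives the layer index.
    parts = name.split('.')
    if len(parts) < 3:
        return None, -1

    kind = parts[-2]

    layer_idx = -1
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break

    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
        return kind, layer_idx

    return None, -1
```

Non-Q/K projections such as `v_proj` fall through to `(None, -1)`, which is how `get_qk_clip_info` skips them.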
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
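The per-head rule in `compute_scales` is easy to check in isolation. The sketch below (a hypothetical `head_scales` helper, not part of the package) restates the formula without torch: a head whose observed max QK logit exceeds the threshold is scaled by `sqrt(threshold / logit)`. Since the logit is bilinear in the Q and K projections, applying this factor to both `wq` and `wk` shrinks the logit by `threshold / logit`, landing it exactly at the threshold.

```python
import math

def head_scales(head_logits, threshold):
    # Heads over the threshold get sqrt(threshold / logit); others
    # keep scale 1.0 (mirrors the torch.ones baseline in compute_scales).
    return [
        math.sqrt(threshold / v) if v > threshold else 1.0
        for v in head_logits
    ]

scales = head_scales([50.0, 200.0], threshold=100.0)
```

After scaling both projections of the second head, its logit becomes `200 * scale**2 == 100`, i.e. exactly the threshold.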
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 67ccafc522c41f14eaf682f265f2bc7d3f56b114..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ef9fba09368a2296ebad017f6576f119ebe2b9513be0d51b66b403fe942bb6d5 -size 2000456 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..cad267f9451b926dc53837595c5ec843476dc560 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09573102bbde35675944ee02dacd2bbad50fc6f151816a6814ef5651adf40e69 +size 2000456 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/adamw.py b/build/torch29-cxx11-cu130-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not params: + 
return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/async_utils.py b/build/torch29-cxx11-cu130-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch29-cxx11-cu130-x86_64-linux/core.py b/build/torch29-cxx11-cu130-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). 
+ p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
- placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) + Example — 8 GPUs, mesh shape (2, 2, 2), + placements ``[Shard(0), Replicate, _StridedShard(0)]``:: - sorted_indices, sorted_placements = zip(*placements_with_index) + Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)] + Permutation: [1, 2, 0] - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) + Step 2 — Permute mesh dims by [1, 2, 0]: + Original: Permuted: + [[[0,1],[2,3]], [[[0,2],[1,3]], + [[4,5],[6,7]]] [[4,6],[5,7]]] - # 3. Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes: + sub-mesh 0 = [[0,2],[1,3]] (replica group 0) + sub-mesh 1 = [[4,6],[5,7]] (replica group 1) + shard_placements = (_StridedShard(0), Shard(0)) - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + Step 4 — Rank 0 → ProcessGroup([0,1,4,5]) + Rank 2 → ProcessGroup([2,3,6,7]) + + Returns: + ``(shard_mesh, process_group, shard_placements)`` + """ + my_rank = dist.get_rank() + assert mesh.mesh.device.type == 'cpu' + + # -- Fast path: 1D all-shard mesh → reuse existing PG. ---------------- + # This avoids a non-collective dist.new_group() call, which would + # deadlock when only a subset of ranks call this function (e.g. expert + # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately). 
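The placement ordering described in Step 1 can be demonstrated with lightweight stand-in classes (hypothetical minimal versions of torch's placement types, not the real ones):

```python
class Replicate:
    pass

class Shard:
    def __init__(self, dim):
        self.dim = dim

class StridedShard(Shard):
    # Stand-in for torch's _StridedShard: finer-grained shard on a dim.
    def __init__(self, dim, split_factor):
        super().__init__(dim)
        self.split_factor = split_factor

def sort_placements(placements):
    # Key mirrors _sort_key in construct_shard_mesh: Replicate first
    # (dim -1), then shards by tensor dim; on the same dim a strided
    # shard (negative split term) sorts before the plain Shard, and the
    # original index breaks remaining ties.
    def key(item):
        i, p = item
        if isinstance(p, Replicate):
            return (-1, 0.0, i)
        split = -1 / p.split_factor if isinstance(p, StridedShard) else 0.0
        return (p.dim, split, i)
    return [p for _, p in sorted(enumerate(placements), key=key)]

# Reproduces the docstring example:
# [Shard(0), Replicate, _StridedShard(0)] -> [Replicate, _StridedShard(0), Shard(0)]
ordered = sort_placements([Shard(0), Replicate(), StridedShard(0, 2)])
```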
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
-
-    # To use tensor as dict key, convert it to tuple
-    def tensor_to_tuple(t):
-        if isinstance(t, torch.Tensor):
-            t = t.tolist()
-        if isinstance(t, list):
-            return tuple(tensor_to_tuple(x) for x in t)
-        return t
-
-    my_shard_mesh_as_tuple = None
-    for shard_mesh in shard_meshes:
-        assert isinstance(shard_mesh, torch.Tensor)
-        shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
-
-        if (my_rank == shard_mesh).any().item():
-            assert my_shard_mesh_as_tuple is None
-            my_shard_mesh_as_tuple = shard_mesh_as_tuple
-
-        # update global cache
-        if shard_mesh_as_tuple not in _ranks_to_dist_cache:
-            shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
-            _ranks_to_dist_cache[shard_mesh_as_tuple] = (
-                DeviceMesh(device_type="cuda", mesh=shard_mesh),
-                shard_process_group,
+    # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
+    # All ranks must call dist.new_group in the same order, even though each
+    # rank only joins one group.
+    def _cache_key(t: torch.Tensor) -> tuple:
+        return (*t.shape, *t.flatten().tolist())
+
+    my_key = None
+    for sm in shard_meshes:
+        key = _cache_key(sm)
+        if (my_rank == sm).any().item():
+            assert my_key is None, "Rank appears in multiple shard groups"
+            my_key = key
+        if key not in _ranks_to_dist_cache:
+            pg = dist.new_group(sm.flatten().tolist())
+            _ranks_to_dist_cache[key] = (
+                DeviceMesh(device_type="cuda", mesh=sm),
+                pg,
             )
-    my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
-        my_shard_mesh_as_tuple]
-
-    return my_shard_mesh, my_shard_process_group, shard_placements
+    return (*_ranks_to_dist_cache[my_key], shard_placements)
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py b/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py
index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644
--- a/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py
+++ b/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py
@@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out):
     with torch.cuda.device(d_in.device.index):
         mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
                          d_out.stride(0), d_out.stride(1))
-
-
-def matmul_transpose(d_in):
-    M, _ = d_in.shape
-    d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
-    matmul_transpose_assign(d_in, d_out)
-    return d_out
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json
index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644
--- a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json
+++ b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json
@@ -1 +1,3 @@
-{"python-depends":[]}
\ No newline at end of file
+{
+  "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/muon.py b/build/torch29-cxx11-cu130-x86_64-linux/muon.py
index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644
--- a/build/torch29-cxx11-cu130-x86_64-linux/muon.py
+++ b/build/torch29-cxx11-cu130-x86_64-linux/muon.py
@@ -1,536 +1,121 @@
 import logging
-import math
 import types
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Any, cast
+from typing import Any
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.tensor import DTensor, Replicate
-from torch.distributed.tensor.placement_types import Placement
-
-from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
-from .matmul_transpose_triton import matmul_transpose_assign
+from torch.distributed.tensor import DTensor, Replicate, Shard
+from torch.profiler import record_function
+
+from .adamw import step_adamw
+from .async_utils import run_pipeline
+from .core import (_muon_state, adjust_lr_for_muon,
+                   get_default_muon_param_groups, update_g, update_p)
+from .distributed.utils import (_is_shard, construct_shard_mesh,
+                                get_slices_of_dtensor)
+from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
+                            _zeropower_via_newtonschulz5)
+from .pipeline import muon_chunk_pipeline
+from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
 
 logger = logging.getLogger(__name__)
 
-COMM_DTYPE = torch.bfloat16
-DEFAULT_CHUNK_SIZE_RATIO = 4
-
-
-# This code snippet is a modified version adapted from the following GitHub repositories:
-# https://github.com/KellerJordan/Muon/blob/master/muon.py
-# Muon's Newton–Schulz iteration causes high variance in singular values
-# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
-@torch.no_grad()
-# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
-def _zeropower_via_newtonschulz5(G, steps):
-    """
-    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-    zero even beyond the point where the iteration no longer converges all the way to one everywhere
-    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-    performance at all relative to UV^T, where USV^T = G is the SVD.
-    """
-    assert len(G.shape) == 2
-    assert G.dtype == COMM_DTYPE
-    X = G  # no manual typecast
-
-    if G.size(0) > G.size(1):
-        X = X.T
-    # Ensure spectral norm is at most 1
-    X = X / (X.norm() + 1e-7)
-    buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-    buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
-    # Perform the NS iterations
-    for a, b, c in [
-        (4.0848, -6.8946, 2.9270),
-        (3.9505, -6.3029, 2.6377),
-        (3.7418, -5.5913, 2.3037),
-        (2.8769, -3.1427, 1.2046),
-        (2.8366, -3.0525, 1.2012),
-    ]:
-        matmul_transpose_assign(X, buf1)
-        matmul_transpose_assign(buf1, buf2)
-        buf1.mul_(b).add_(buf2, alpha=c)
-        X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
-
-    if G.size(0) > G.size(1):
-        X = X.T
-    return X
-
-
-@dataclass
-class _muon_state:
-    # TODO: use Optional
-    worker_rank: int
-    process_group: ProcessGroup
-    shard_mesh: DeviceMesh
-    shard_placements: tuple[Placement, ...]
-    name: str
-    qk_clip_state: torch.Tensor | None = None
-    gathered_grad: torch.Tensor | None = None
-    scattered_u: DTensor | None = None
-    computed_u: torch.Tensor | None = None
-    gather_event: torch.cuda.Event | None = None
-    compute_event: torch.cuda.Event | None = None
-    scatter_event: torch.cuda.Event | None = None
-
-
-def numel_for_rank(
-    param: DTensor,
-    local_rank: int,
-    state: _muon_state,
-) -> int:
-    slices = get_slices_of_dtensor(
-        param,
-        local_rank,
-        state.shard_mesh,
-        state.shard_placements,
-    )
-
-    numel = 1
-    for s, dim in zip(slices, param.shape):
-        start, stop, step = s.indices(dim)
-        length = max(0, (stop - start + (step - 1)) // step)
-        numel *= length
-
-    return numel
-
-
-@torch.no_grad()
-def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
-    """
-    Pre-allocate gathered_grad buffer on compute_stream
-    before launching all2all gather
-    """
-    with torch.cuda.stream(compute_stream):
-        for p in params:
-            state = param_to_state[id(p)]
-            if rank == state.worker_rank:
-                state.gathered_grad = torch.empty(p.shape,
-                                                  dtype=COMM_DTYPE,
-                                                  device="cuda")
-            else:
-                state.gathered_grad = None
-
-    alloc_event = torch.cuda.Event()
-    alloc_event.record(compute_stream)
-    return alloc_event
-
-
-@torch.no_grad()
-def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
-                    alloc_event):
-    """
-    All2all gathers shards so each owner rank reconstructs its full gradient
-    """
-    with torch.cuda.stream(comm_stream):
-        process_group = param_to_state[id(params[0])].process_group
-        num_ranks = dist.get_world_size(group=process_group)
-
-        # Construct sending buffers
-        per_dst = [[] for _ in range(num_ranks)]
-        send_counts = [0] * num_ranks
-
-        for p in params:
-            state = param_to_state[id(p)]
-            dst = state.worker_rank
-            assert dst < num_ranks
-            shard_elems = numel_for_rank(p, rank, state)
-            g = p.grad
-            g = g.to_local().to(COMM_DTYPE).contiguous()
-            assert g.numel() == shard_elems
-            per_dst[dst].append(g.view(-1))
-            send_counts[dst] += shard_elems
-
-        assert any(
-            len(v) > 0 for v in per_dst
-        ), "At least one destination rank must receive a sharded tensor"
-        # list[list[Tensor]] -> list[Tensor]
-        per_dst = [t for dst in per_dst for t in dst]
-
-        send_buf = torch.cat(per_dst, dim=0)
-
-        owned_params = [
-            p for p in params if param_to_state[id(p)].worker_rank == rank
-        ]
-
-        # Compute receive sizes and allocate receiving buffers
-        recv_counts = [0] * num_ranks
-
-        for src in range(num_ranks):
-            total = 0
-            for p in owned_params:
-                state = param_to_state[id(p)]
-                assert state.worker_rank == rank
-                total += numel_for_rank(p, src, state)
-            recv_counts[src] = total
-
-        recv_total = sum(recv_counts)
-        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-        #All2All
-        logger.debug(f"send_buf size: {send_buf.numel()}, "
-                     f"recv_buf size: {recv_buf.numel()}, "
-                     f"recv_counts: {recv_counts}, "
-                     f"send_counts: {send_counts}, "
-                     f"process_group: {str(process_group)}")
-        dist.all_to_all_single(
-            recv_buf,
-            send_buf,
-            output_split_sizes=recv_counts,
-            input_split_sizes=send_counts,
-            group=process_group,
-        )
-
-        # Reconstructs gathered grad from the received buffer
-        #
-        # recv_buf (num ranks = 3)
-        #
-        #     From rank 0        From rank 1        From rank 2
-        # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
-        #
-        # Outer loop:
-        #   rank 0 -> rank 1 -> rank2
-        #
-        # Inner loop:
-        #   p1_n -> p2_n -> p3_n
-
-        comm_stream.wait_event(alloc_event)
-
-        off = 0
-        for src in range(num_ranks):
-            if recv_counts[src] == 0:
-                continue
-
-            block = recv_counts[src]
-            inner_off = 0
-            for p in owned_params:
-                state = param_to_state[id(p)]
-                assert state.worker_rank == rank
-
-                # get the slice of the full dtensor corresponding to rank src.
-                slices = get_slices_of_dtensor(state.gathered_grad, src,
-                                               state.shard_mesh,
-                                               state.shard_placements)
-
-                dst = state.gathered_grad[slices]
-                assert dst._base is state.gathered_grad
-
-                n = dst.numel()
-                assert n > 0
-
-                sg = recv_buf.narrow(0, off + inner_off, n)
-                sg = sg.reshape_as(dst)
-                dst.copy_(sg)
-
-                inner_off += n
-            off += block
-
-        for p in params:
-            state = param_to_state[id(p)]
-            if state.worker_rank == rank:
-                state.gather_event = torch.cuda.Event()
-                state.gather_event.record(comm_stream)
-            else:
-                state.gathered_grad = None
-                state.gather_event = None
-            if none_grad:
-                p.grad = None
-
-
-@torch.no_grad()
-def _compute_u(p, state, steps, rank, compute_stream):
-    """
-    On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
-    """
-    with torch.cuda.stream(compute_stream):
-        if rank == state.worker_rank:
-            if state.gather_event is None:
-                raise RuntimeError("Gather event must be set before compute.")
-            compute_stream.wait_event(state.gather_event)
-            u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
-            state.gathered_grad = None
-            state.computed_u = u
-            state.compute_event = torch.cuda.Event()
-            state.compute_event.record()
-        else:
-            state.computed_u = None
-            state.compute_event = None
-
-
-@torch.no_grad()
-def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
-    """
-    Pre-allocate scattered_u buffer on compute_stream
-    before launching all2all gather
-    """
-    with torch.cuda.stream(compute_stream):
-        for p in params:
-            state = param_to_state[id(p)]
-            state.scattered_u = torch.empty_like(p.to_local(),
-                                                 dtype=COMM_DTYPE)
-
-    alloc_event = torch.cuda.Event()
-    alloc_event.record(compute_stream)
-    return alloc_event
-
-
-def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
-    """
-    All2all scatters full gradients to all ranks
-    """
-    with torch.cuda.stream(comm_stream):
-        process_group = param_to_state[id(params[0])].process_group
-        num_ranks = dist.get_world_size(group=process_group)
-        owned_params = [
-            p for p in params if param_to_state[id(p)].worker_rank == rank
-        ]
-
-        # Construct sending buffer
-        per_dst = [[] for _ in range(num_ranks)]
-        send_counts = [0] * num_ranks
-
-        if owned_params:
-            for p in owned_params:
-                state = param_to_state[id(p)]
-                if state.compute_event is None:
-                    raise RuntimeError(
-                        "Compute event must be set before scatter.")
-                comm_stream.wait_event(state.compute_event)
-                state.gathered_grad = None
-
-                assert state.computed_u is not None
-
-                u_full = state.computed_u.to(COMM_DTYPE).contiguous()
-
-                offset = 0
-                for dst in range(num_ranks):
-                    # get the slice of the full tensor corresponding to rank dst.
-                    slices = get_slices_of_dtensor(u_full, dst,
-                                                   state.shard_mesh,
-                                                   state.shard_placements)
-                    su = u_full[slices].flatten()
-
-                    n = su.numel()
-                    assert n > 0
-
-                    per_dst[dst].append(su)
-                    send_counts[dst] += n
-                    offset += n
-
-                assert offset == u_full.numel()
-
-        lengths = [len(v) for v in per_dst]
-        if all(l > 0 for l in lengths):
-            assert all(
-                l == lengths[0] for l in lengths
-            ), "All destination ranks must have the same number of sharded tensor"
-            # list[list[Tensor]] -> list[Tensor]
-            per_dst = [t for dst in per_dst for t in dst]
-            send_buf = torch.cat(per_dst, dim=0)
-        else:
-            # all_to_all requires participation from all ranks
-            # Even non-owner ranks must join the collective call
-            send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
-
-        # Compute receive sizes and allocate receiving buffers
-        recv_counts = [0] * num_ranks
-
-        for src in range(num_ranks):
-            total = 0
-            for p in params:
-                state = param_to_state[id(p)]
-                if state.worker_rank != src:
-                    continue
-                total += numel_for_rank(p, rank, state)
-            recv_counts[src] = total
-
-        recv_total = sum(recv_counts)
-        assert recv_total > 0
-        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
-
-        #All2All
-        dist.all_to_all_single(
-            recv_buf,
-            send_buf,
-            output_split_sizes=recv_counts,
-            input_split_sizes=send_counts,
-            group=process_group,
-        )
-
-        # Copy to pre-allocated scattered_u buffer from the received buffer
-        #
-        # recv_buf (num ranks = 3, local_rank = 0)
-        #
-        #    From rank 0      From rank 1   From rank 2
-        # | p1_0, p2_0, p3_0 |    p4_0     | p5_0, p6_0 |
-        #
-        # Outer loop:
-        #   rank 0 -> rank 1 -> rank2
-        #
-        # Inner loop:
-        #   src(0) : p1_0 -> p2_0 -> p3_0
-        #   src(1) : p4_0
-        #   src(2) : p5_0 -> p6_0
-
-        comm_stream.wait_event(alloc_event)
-
-        off = 0
-        for src in range(num_ranks):
-            block = recv_counts[src]
-            if block == 0:
-                continue
-
-            inner_off = 0
-            for p in params:
-                state = param_to_state[id(p)]
-                if state.worker_rank != src:
-                    continue
-                n = numel_for_rank(p, rank, state)
-                assert n > 0
-                flat_local = recv_buf.narrow(0, off + inner_off,
-                                             n).view_as(p.to_local())
-                state.scattered_u.copy_(flat_local)
+def _expand_expert_params(names, params, expert_keys):
+    """Expand expert params by splitting on dim 0 (expert dimension).
 
-                state.scatter_event = torch.cuda.Event()
-                state.scatter_event.record(comm_stream)
-                inner_off += n
+    Params whose name matches any key in ``expert_keys`` are treated as
+    expert-parallel tensors. Their outermost dimension is the expert
+    dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D
+    ``nn.Parameter`` views so that in-place updates propagate back to
+    the original storage.
 
-            assert inner_off == block
-            off += block
+    Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` —
+    if they are expert params, their key must be added to ``expert_keys``.
+    The grad must already be set on each expert param (e.g. after momentum).
 
-
-def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
-                  compute_stream):
-    """
-    Update sharded parameter p with the scattered_u.
-    Only worker_rank frees computed_u.
+    For DTensor expert params, placements that shard on dim 0 (expert dim)
+    are consumed by the split. Non-dim-0 shard placements (e.g. TP)
+    are preserved: each 2D slice is wrapped as a DTensor on the corresponding
+    submesh so the parallel pipeline handles the TP communication.
     """
-    with torch.cuda.stream(compute_stream):
-        if state.scatter_event is None:
-            raise RuntimeError("Scatter event must be set before update")
-        compute_stream.wait_event(state.scatter_event)
-        u_dtensor = DTensor.from_local(
-            state.scattered_u,
-            placements=p.placements,
-            device_mesh=p.device_mesh,
-        )
-
-        state.scattered_u = u_dtensor
-
-        if rank == state.worker_rank:
-            # Free computed_u
-            state.computed_u = None
-
-        Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
-        state.scattered_u = None
-        u_dtensor = None
-
-        scales_full = Muon._compute_scales(
-            p,
-            state.qk_clip_state) if state.qk_clip_state is not None else None
-        if scales_full is not None:
-            # Have to slice scales_full among dim 0
-            weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
-                                                  state.shard_placements)
-            ratio = p.shape[0] // scales_full.shape[0]
-            scales_slice = slice(
-                None if weight_slices[0].start is None else
-                weight_slices[0].start // ratio,
-                None if weight_slices[0].stop is None else
-                weight_slices[0].stop // ratio,
-                None,
-            )
-
-            scales_local = scales_full[scales_slice]
-            scales_local = DTensor.from_local(
-                scales_local,
-                placements=p.placements,
-                device_mesh=p.device_mesh,
-            )
-            Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
-
-
-def default_is_muon(name, x):
-    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
-    return x.ndim >= 2 and not any(key in name for key in skip_keys)
-
-
-def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
-    muon_params, muon_names = [], []
-    non_muon_params = []
-
-    for n, p in model.named_parameters():
-        if not p.requires_grad:
+    expanded_names = []
+    expanded_params = []
+
+    for n, p in zip(names, params):
+        is_expert = expert_keys and any(key in n for key in expert_keys)
+        is_dtensor = isinstance(p.data, DTensor)
+
+        if not is_expert:
+            assert p.data.ndim <= 2, (
+                f"Param {n} has ndim={p.data.ndim} but does not match "
+                f"expert_keys={expert_keys}. If this is an expert param, "
+                f"add its key to expert_keys.")
+            expanded_names.append(n)
+            expanded_params.append(p)
             continue
-        if is_muon_func(n, p):
-            muon_params.append(p)
-            muon_names.append(n)
-        else:
-            non_muon_params.append(p)
-
-    return [
-        {
-            "params": muon_params,
-            "names": muon_names,
-            "use_muon": True,
-        },
-        {
-            "params": non_muon_params,
-            "use_muon": False,
-        },
-    ]
-
-
-def parse_qk_layer(name: str) -> tuple[str | None, int]:
-    """
-    Parse a parameter name to check if it is a query/key projection layer
-    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
-
-    Returns:
-        (kind, layer_idx) or (None, -1) if not matched.
-
-    Example:
-        'model.3.attn.wq.weight' -> ('wq', 3)
-        'model.5.attn.wk.weight' -> ('wk', 5)
-        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
-        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
-        'model.4.attn.v_proj.weight' -> (None, -1)
-    """
-    parts = name.split('.')
-    if len(parts) < 3:
-        return None, -1
-
-    kind = parts[-2]
-
-    layer_idx = -1
-    for part in reversed(parts):
-        if part.isdigit():
-            layer_idx = int(part)
-            break
-    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
-        return kind, layer_idx
+        g = p.grad
+        assert g is not None, (
+            f"Expert param {n} must have grad set before expansion")
+
+        tp_mesh = None
+        tp_placements_2d = None
+
+        if is_dtensor:
+            local_data = p.to_local()
+            local_grad = g.to_local() if isinstance(g, DTensor) else g
+
+            # Find non-dim-0 shard placements (e.g. TP sharding).
+            # After splitting on dim 0, Shard(k) becomes Shard(k-1).
+            tp_dim_indices = []
+            tp_placements_2d = []
+            for i, pl in enumerate(p.placements):
+                if _is_shard(pl) and pl.dim != 0:
+                    tp_dim_indices.append(i)
+                    tp_placements_2d.append(Shard(pl.dim - 1))
+
+            if tp_dim_indices:
+                tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i]
+                                     for i in tp_dim_indices)
+                if len(tp_dim_names) == 1:
+                    tp_mesh = p.device_mesh[tp_dim_names[0]]
+                else:
+                    tp_mesh = p.device_mesh[tp_dim_names]
+        else:
+            local_data = p.data
+            local_grad = g
+
+        # Expand: split dim 0, reshape each slice to 2D.
+        num_local_experts = local_data.shape[0]
+        for i in range(num_local_experts):
+            slice_data = local_data[i]
+            slice_grad = local_grad[i]
+
+            if tp_mesh is not None:
+                # Wrap as DTensor on TP submesh so the pipeline handles
+                # TP communication (gather/scatter across TP ranks).
+                dt_data = DTensor.from_local(slice_data,
+                                             device_mesh=tp_mesh,
+                                             placements=tp_placements_2d)
+                dt_grad = DTensor.from_local(slice_grad,
+                                             device_mesh=tp_mesh,
+                                             placements=tp_placements_2d)
+                expert_param = torch.nn.Parameter(dt_data, requires_grad=False)
+                expert_param.grad = dt_grad
+            else:
+                expert_param = torch.nn.Parameter(slice_data,
+                                                  requires_grad=False)
+                expert_param.grad = slice_grad
 
-    return None, -1
+            expanded_names.append(f"{n}[{i}]")
+            expanded_params.append(expert_param)
 
+        p.grad = None  # allow expert grad storage to be freed after pipeline
 
-@dataclass
-class QKClipInfo:
-    """Per-parameter dynamic info computed from config + runtime logits."""
-    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
-    indices: list[int]  # which heads to consider for clipping
-    head_dim: int  # from config
-    threshold: float  # from config
-    logit: torch.Tensor | None
+    return expanded_names, expanded_params
 
 
 class Muon(torch.optim.Optimizer):
@@ -554,7 +139,7 @@
         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
         ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
         weight_decay: The weight decay for Muon and AdamW.
-        {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
+        Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead.
         adamw_lr: The learning rate for the internal AdamW.
         adamw_betas: The betas for the internal AdamW.
         adamw_eps: The epsilon for the internal AdamW.
@@ -564,7 +149,7 @@
             - "q_indices" (list[int]): Indices of query heads to consider.
             - "k_indices" (list[int]): Indices of key heads to consider.
            - "head_dim" (int): Dimensionality of each attention head.
-            - "threshold" (float): Threshold value; heads whose QK logits exceed
+            - "threshold" (float): Threshold value; heads whose QK logits exceed
               this value will be scaled down.
             Default is:
             {
@@ -584,6 +169,13 @@
        use_distributed_muon: Use distributed muon by Liu et al. (2024).
            For testing purpose only.
        small_param_numel_threshold: Threshold for classifying parameters as
            small and falling back to distributed Muon
+        expert_keys: List of strings to identify expert-parallel parameters.
+                     If any key appears in a parameter's name, its outermost
+                     dimension is treated as the expert dimension and expanded
+                     into per-expert 2D params for Muon. For example,
+                     ``expert_keys=["experts"]`` matches any param whose name
+                     contains "experts". 3D+ params not matched by any key
+                     will raise an error.
     """
 
     def __init__(self,
@@ -597,16 +189,12 @@
                  adamw_eps=1e-8,
                  none_grad=True,
                  debug=False,
-                 clip_config={
-                     "q_indices": [],
-                     "k_indices": [],
-                     "head_dim": 128,
-                     "threshold": 100
-                 },
+                 clip_config=None,
                  warmup_step=5,
                  chunk_size=-1,
-                 use_distributed_muon=False,
-                 small_param_numel_threshold=65536):
+                 use_distributed_muon=False,
+                 small_param_numel_threshold=65536,
+                 expert_keys=None):
         defaults = dict(
             lr=lr,
             weight_decay=weight_decay,
@@ -630,16 +218,18 @@
 
         super().__init__(params, defaults)
 
-        self.rank = None
-
-        self.comm_stream = torch.cuda.Stream()
-        self.compute_stream = torch.cuda.Stream()
         self.debug = debug
-        self.clip_config = clip_config
+        self.clip_config = clip_config if clip_config is not None else {
+            "q_indices": [],
+            "k_indices": [],
+            "head_dim": 128,
+            "threshold": 100,
+        }
         self.warmup_step = warmup_step
         self.chunk_size = chunk_size
         self.use_distributed_muon = use_distributed_muon
         self.small_param_numel_threshold = small_param_numel_threshold
+        self.expert_keys = expert_keys
 
     def _calc_flops(self, G, steps):
         assert len(G.shape) == 2
@@ -649,20 +239,6 @@
 
         return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
 
-    def adjust_lr_for_muon(self, lr, param_shape):
-        A, B = param_shape[:2]
-        # We adjust the learning rate and weight decay based on the size of the parameter matrix
-        # as describted in the paper
-        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
-        adjusted_lr = lr * adjusted_ratio
-        return adjusted_lr
-
-    def set_rank_once(self, rank):
-        if self.rank is None:
-            self.rank = rank
-        else:
-            assert self.rank == rank
-
     def get_shard_mesh(self, p):
         """
         Get the shard mesh for a parameter p on the given rank.
@@ -673,9 +249,6 @@
         shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
             p.placements, p.device_mesh)
 
-        # set rank with the local rank in the shard process group
-        self.set_rank_once(dist.get_rank(group=shard_pg))
-
         return shard_mesh, shard_pg, shard_placements
 
     def init_state_and_assign_params(self, names, params, group, qk_logits):
@@ -694,8 +267,8 @@
                 total_flops += flops
 
         if self.debug:
-            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
-                  flush=True)
+            logger.debug("Total TFLOPs for Muon: %.2f TFLOPs",
+                         total_flops / 1e12)
 
         paired = list(zip(names, params))
 
@@ -724,44 +297,54 @@
             worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
             round_robin = (round_robin + 1) % len(shard_mesh_flattened)
 
-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+            qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
+
+            # Precompute per-rank indices and numels for all-to-all.
+            rank_indices: dict[int, tuple] = {}
+            rank_numels: dict[int, int] = {}
+            for r in range(num_ranks):
+                indices = get_slices_of_dtensor(p, r, shard_mesh,
+                                                shard_placements)
+                rank_indices[r] = indices
+                numel = 1
+                for idx, dim_size in zip(indices, p.shape):
+                    if isinstance(idx, slice):
+                        start, stop, step = idx.indices(dim_size)
+                        numel *= max(0, (stop - start + (step - 1)) // step)
+                    else:
+                        numel *= len(idx)
+                rank_numels[r] = numel
 
             param_to_state[id(p)] = _muon_state(
                 worker_rank=worker_rank,
                 process_group=shard_pg,
-                shard_mesh=shard_mesh,
-                shard_placements=shard_placements,
+                rank_indices=rank_indices,
+                rank_numels=rank_numels,
                 name=n,
                 qk_clip_state=qk_clip_state,
            )
 
         return param_to_state, ordered_params
 
-    def base(self, names, params, group, lr, weight_decay, momentum,
-             qk_logits):
-        # generate weight updates in distributed fashion
+    def base(self, names, params, group, lr, weight_decay, qk_logits):
+        # Momentum is already applied by _step_muon before this method.
         for n, p in zip(names, params):
             g = p.grad
             if g is None:
                 continue
-            if g.ndim > 2:
-                g = g.view(g.size(0), -1)
-            assert g is not None
-
-            g = self._update_g(p, g, group, momentum)
 
             u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
                                              steps=group["ns_steps"])
 
-            adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-            Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
+            adjusted_lr = adjust_lr_for_muon(lr, p.shape)
+            update_p(p, u, lr, adjusted_lr, weight_decay)
 
-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+            qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
 
-            scales_full = self._compute_scales(
+            scales_full = compute_scales(
                 p, qk_clip_state) if qk_clip_state is not None else None
             if scales_full is not None:
-                Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
+                qk_clip(p, scales_full, qk_clip_state.head_dim)
 
     def distributed_muon(
         self,
@@ -770,20 +353,15 @@
         group: dict[str, Any],
         lr: float,
         weight_decay: float,
-        momentum: float,
        qk_logits: list[torch.Tensor | DTensor] | None,
     ):
         """
         Implementation of Distributed Muon by Liu et al.
        """
+        # Momentum is already applied by _step_muon before this method.
         for n, p in zip(names, params):
             g = p.grad
             if g is None:
                 continue
-            if g.ndim > 2:
-                g = g.view(g.size(0), -1)
-            assert g is not None
-
-            g = self._update_g(p, g, group, momentum)
 
             # Gather G
             if isinstance(p.data, DTensor):
@@ -796,16 +374,16 @@
             u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
                                                   steps=group["ns_steps"])
 
-            adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
-            Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
+            adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
+            update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
 
-            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+            qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
 
-            scales_full = self._compute_scales(
+            scales_full = compute_scales(
                 p_full, qk_clip_state) if qk_clip_state is not None else None
             if scales_full is not None:
-                Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
+                qk_clip(p_full, scales_full, qk_clip_state.head_dim)
 
             if isinstance(p.data, DTensor):
                 ndims = len(p.device_mesh.mesh.shape)
@@ -822,244 +400,53 @@
 
                 p.copy_(p_sharded)
 
-    def _update_g(self, p, g, group, momentum):
-        # calc update
-        state = self.state[p]
-        buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
-        torch.add(g, buf, alpha=momentum, out=buf)
-        if group["nesterov"]:
-            g.add_(buf, alpha=momentum)
-            return g
-        return buf
-
-    @staticmethod
-    def _update_p(p, u, lr, adjusted_lr, weight_decay):
-        if isinstance(p, torch.nn.Parameter):
-            # apply weight decay
-            p.data.mul_(1 - lr * weight_decay)
-            # apply update
-            p.data.add_(u, alpha=-adjusted_lr)
-        else:
-            p.mul_(1 - lr * weight_decay)
-            p.add_(u, alpha=-adjusted_lr)
-
-    def get_qk_clip_info(self, n, qk_logits):
-        if self.clip_config is None:
-            return None
-
-        head_dim = self.clip_config.get('head_dim')
-        threshold = self.clip_config.get('threshold')
-        kind, layer_idx = parse_qk_layer(n)
-
-        logit, indices = None, []
-        if qk_logits is not None and kind is not None:
-            logit = qk_logits[layer_idx]
-            indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-            indices = self.clip_config.get(indices_key, []) or []
-
-        if isinstance(logit, DTensor):
-            # In TP settings, qk_logits may be DTensor
-            # We convert it to full tensor here for simplicity
-            logit = logit.full_tensor()
-
-        return QKClipInfo(
-            kind=kind,
-            indices=indices,
-            head_dim=head_dim,
-            threshold=threshold,
-            logit=logit,
-        )
-
-    @staticmethod
-    def _compute_scales(p, qk_clip_state):
-        kind = qk_clip_state.kind
-        indices = qk_clip_state.indices
-        head_dim = qk_clip_state.head_dim
-        threshold = qk_clip_state.threshold
-        logit = qk_clip_state.logit
-
-        H_global = p.shape[0] // head_dim
-        scales_full = torch.ones(H_global, device=p.data.device)
-        scaling = 0
-
-        for logit_idx, head_idx in enumerate(indices):
-            v_ele = float(logit[logit_idx])
-            if v_ele > threshold:
-                new_scale = math.sqrt(threshold / v_ele)
-                if new_scale < scales_full[head_idx]:
-                    scales_full[head_idx] = new_scale
-                    logger.info(
-                        f"[{kind}] Head {head_idx} exceeded threshold "
-                        f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
-                    )
-                    scaling += 1
-
-        return scales_full if scaling > 0 else None
-
-    @staticmethod
-    def _qk_clip(p, scales, head_dim):
-        if isinstance(p, torch.nn.Parameter):
-            W = p.data.view(-1, head_dim, p.data.shape[1])
-            W.mul_(scales.view(-1, 1, 1))
-        else:
-            W = p.view(-1, head_dim, p.shape[1])
-            W.mul_(scales.view(-1, 1, 1))
-
-    def parallel(self, names, params, group, lr, weight_decay, momentum,
-                 qk_logits):
+    def parallel(self, names, params, group, lr, weight_decay, qk_logits):
         """
         Perform a parallel optimization step using Muon.
-        """
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-            if g.ndim > 2:
-                g = g.view(g.size(0), -1)
 
+        Parameters are chunked and each chunk is processed by a
+        :func:`muon_chunk_pipeline` generator. :func:`run_pipeline`
+        interleaves multiple chunks so that communication and computation
+        overlap across chunks (the same overlap previously achieved by the
+        warmup + main-loop index scheduling).
+        """
-            # Update g in the local rank
-            g = self._update_g(
-                p,
-                g,
-                group,
-                momentum=momentum,
-            )
-            p.grad = g
+        # Momentum is already applied by _step_muon before this method.
 
         param_to_state, ordered_params = self.init_state_and_assign_params(
             names, params, group, qk_logits)
 
-        assert self.rank is not None
-
-        def enqueue_all2all_gather(start_idx, chunk_size):
-            target_params = ordered_params[start_idx:start_idx + chunk_size]
-            if target_params:
-                alloc_event = _alloc_gathered_grad(target_params,
-                                                   param_to_state, self.rank,
-                                                   self.compute_stream)
-                _all2all_gather(target_params, param_to_state, self.rank,
-                                self.comm_stream, group["none_grad"],
-                                alloc_event)
-
-        def enqueue_computes(start_idx, chunk_size):
-            for p in ordered_params[start_idx:start_idx + chunk_size]:
-                state = param_to_state[id(p)]
-                _compute_u(p, state, group["ns_steps"], self.rank,
-                           self.compute_stream)
-
-        def enqueue_all2all_scatter(start_idx, chunk_size):
-            target_params = ordered_params[start_idx:start_idx + chunk_size]
-            if target_params:
-                alloc_event = _alloc_scattered_u(target_params, param_to_state,
-                                                 self.rank,
-                                                 self.compute_stream)
-                _all2all_scatter(target_params, param_to_state, self.rank,
-                                 self.comm_stream, alloc_event)
-
-        def enqueue_update_param(start_idx, chunk_size):
-            for p in ordered_params[start_idx:start_idx + chunk_size]:
-                state = param_to_state[id(p)]
-                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-                _update_param(p, state, lr, adjusted_lr, weight_decay,
-                              self.rank, self.compute_stream)
+        # Compute local rank for this group's shard process group.
+        shard_pg = param_to_state[id(ordered_params[0])].process_group
+        rank = dist.get_rank(group=shard_pg)
 
         if self.chunk_size == -1:
             shard_ranks = dist.get_world_size(param_to_state[id(
-                params[0])].process_group)
+                ordered_params[0])].process_group)
             chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
         elif self.chunk_size > 0:
             chunk_size = self.chunk_size
         else:
             raise ValueError("chunk_size must be -1 or a positive integer.")
 
-        # Wait grad update
-        self.comm_stream.wait_stream(torch.cuda.current_stream())
-
-        warmup_step = self.warmup_step
-        for i in range(0, warmup_step):
-            enqueue_all2all_gather(i * chunk_size, chunk_size)
-            enqueue_computes(i * chunk_size, chunk_size)
-
-        for i in range(0, len(params) + chunk_size - 1, chunk_size):
-            enqueue_all2all_scatter(i, chunk_size)
-            enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
-            enqueue_update_param(i, chunk_size)
-            enqueue_computes(i + warmup_step * chunk_size, chunk_size)
-
-        # Wait the last update_param to finish
-        torch.cuda.current_stream().wait_stream(self.compute_stream)
-
-    @staticmethod
-    def _fused_adamw(
-        params: list[torch.Tensor],
-        grads: list[torch.Tensor],
-        exp_avgs: list[torch.Tensor],
-        exp_avg_sqs: list[torch.Tensor],
-        max_exp_avg_sqs: list[torch.Tensor],
-        state_steps: list[torch.Tensor],
-        amsgrad: bool,
-        beta1: float,
-        beta2: float,
-        lr: float | torch.Tensor,
-        weight_decay: float,
-        eps: float,
-        maximize: bool,
-    ) -> None:
-        if not params:
-            return
+        def pipelines():
+            for start in range(0, len(ordered_params), chunk_size):
+                chunk = ordered_params[start:start + chunk_size]
+                if chunk:
+                    yield muon_chunk_pipeline(
+                        params=chunk,
+                        param_to_state=param_to_state,
+                        rank=rank,
+                        ns_steps=group["ns_steps"],
+                        lr=lr,
+                        weight_decay=weight_decay,
+                        none_grad=group["none_grad"],
+                    )
 
-        # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
-        # treating it as a scalar.
-        lr_dict: DeviceDict | None = ({
-            lr.device: lr
-        } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
-                                      None)
-        grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
-            [
-                params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
-                state_steps
-            ]  # type: ignore[list-item]
-        )
-        for (device, _), (
-            (
-                device_params_,
-                device_grads_,
-                device_exp_avgs_,
-                device_exp_avg_sqs_,
-                device_max_exp_avg_sqs,
-                device_state_steps_,
-            ),
-            _,
-        ) in grouped_tensors.items():
-            device_params = cast(list[torch.Tensor], device_params_)
-            device_grads = cast(list[torch.Tensor], device_grads_)
-            device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
-            device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
-            device_state_steps = cast(list[torch.Tensor], device_state_steps_)
-
-            if lr_dict is not None and device not in lr_dict:
-                lr_dict[device] = lr.to(
-                    device=device,
-                    non_blocking=True)  # type: ignore[union-attr]
-                lr = lr_dict[device]
-            torch._foreach_add_(device_state_steps, 1)
-            func = torch._fused_adamw_
-            func(
-                device_params,
-                device_grads,
-                device_exp_avgs,
-                device_exp_avg_sqs,
-                device_max_exp_avg_sqs,  # type: ignore[arg-type]
-                device_state_steps,
-                amsgrad=amsgrad,
-                lr=lr,  # type: ignore[arg-type]
-                beta1=beta1,
-                beta2=beta2,
-                weight_decay=weight_decay,
-                eps=eps,
-                maximize=maximize,
-            )
+        with record_function("muon::barrier"):
+            dist.barrier()
+        with record_function("muon::pipeline"):
+            run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
 
     def _step_muon(self, group, qk_logits=None):
         params = group["params"]
@@ -1068,6 +455,18 @@
         momentum = group["momentum"]
         names = group["names"]
 
+        # Apply momentum to all params before routing/expansion.
+        with record_function("muon::momentum"):
+            for n, p in zip(names, params):
+                g = p.grad
+                if g is None:
+                    continue
+                g = update_g(self.state, p, g, group, momentum)
+                p.grad = g
+
+        # Expand expert params by splitting on dim 0.
+        names, params = _expand_expert_params(names, params, self.expert_keys)
+
         param_dtensors = []
         name_dtensors = []
 
@@ -1083,7 +482,6 @@
                 group=group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits)
             return
 
@@ -1119,7 +517,6 @@
         # and run parallel Muon on each group.
         placement_to_params = defaultdict(lambda: ([], []))
-        # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
 
         assert len(dtensors) == len(names)
         for p, n in zip(dtensors, names):
@@ -1141,7 +538,6 @@
                 group=group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits,
             )
 
@@ -1159,7 +555,6 @@
                 group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits,
            )
 
@@ -1170,78 +565,9 @@
                 group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits,
             )
 
-    def _step_adamw_params(self, params, group):
-        params_with_grads = []
-        grads = []
-        moment1 = []
-        moment2 = []
-        max_exp_avg_sqs = []
-        state_steps = []
-        lr = group["lr"]
-        beta1, beta2 = group["adamw_betas"]
-        eps = group["adamw_eps"]
-        weight_decay = group["weight_decay"]
-
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-            state = self.state[p]
-            params_with_grads.append(p)
-            grads.append(g)
-            if "step" not in state:
-                state["step"] = (torch.zeros((),
-                                             dtype=torch.float32,
-                                             device=p.device))
-                state["moment1"] = torch.zeros_like(g)
-                state["moment2"] = torch.zeros_like(g)
-            moment1.append(state["moment1"])
-            moment2.append(state["moment2"])
-            if not isinstance(state["step"], torch.Tensor):
-                step_tensor = torch.tensor(state["step"],
-                                           dtype=torch.float32,
-                                           device=p.device)
-            else:
-                step_tensor = state["step"]
-            state_steps.append(step_tensor)
-
-        self._fused_adamw(
-            params_with_grads,
-            grads,
-            moment1,
-            moment2,
-            max_exp_avg_sqs,
-            state_steps,
-            amsgrad=False,
-            beta1=beta1,
-            beta2=beta2,
-            lr=lr,
-            weight_decay=weight_decay,
-            eps=eps,
-            maximize=False,
-        )
-
-    def _step_adamw(self, group):
-        params = group["params"]
-
-        # group params with it's type and placement
-        placement_to_params: dict[tuple[Placement | type,
-                                        DeviceMesh | None]] = defaultdict(list)
-        for p in params:
-            match p:
-                case DTensor():
-                    placement_to_params[tuple([p.placements,
-                                               p.device_mesh])].append(p)
-                case torch.Tensor():
-                    placement_to_params[tuple([torch.Tensor, None])].append(p)
-
-        for params in placement_to_params.values():
-            self._step_adamw_params(params, group)
-
     @torch.no_grad
     def step(self, closure=None, qk_logits=None):
         """Perform a single optimization step.
@@ -1249,9 +575,9 @@
         Args:
            closure (Callable, optional): A closure that reevaluates the model
                 and returns the loss.
-            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
-                to 1D tensors of shape (num_heads,), representing the maximum
-                QK logits across all tokens, computed as
+            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
+                to 1D tensors of shape (num_heads,), representing the maximum
+                QK logits across all tokens, computed as
                 (1 / sqrt(head_dim)) * (Q @ K^T).
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch29-cxx11-cu130-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-cu130-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch29-cxx11-cu130-x86_64-linux/pipeline.py b/build/torch29-cxx11-cu130-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+ + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def 
_complete_gather( + recv_buf: torch.Tensor, + recv_counts: list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
+    """
+    # Allocate scattered-u buffers
+    scattered_us: dict[int, torch.Tensor] = {}
+    for p in params:
+        scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
+
+    # Build send buffer (from computed_us on owner ranks)
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    if owned_params:
+        for p in owned_params:
+            state = param_to_state[id(p)]
+
+            assert computed_us[id(p)] is not None
+            u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
+
+            total_sent = 0
+            for dst_rank in range(num_ranks):
+                indices = state.rank_indices[dst_rank]
+                su = u_full[indices].flatten()
+
+                n = su.numel()
+                assert n > 0
+
+                per_dst[dst_rank].append(su)
+                send_counts[dst_rank] += n
+                total_sent += n
+
+            assert total_sent == u_full.numel()
+
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensors"
+        per_dst_flat = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst_flat, dim=0)
+    else:
+        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            total += state.rank_numels[rank]
+        recv_counts[src] = total
+
+    recv_total = sum(recv_counts)
+    assert recv_total > 0
+    recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, scattered_us, recv_counts
+
+
+def _complete_scatter(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+) -> 
None: + """Copy recv buffer into scattered_us (in-place).""" + off = 0 + for src in range(len(recv_counts)): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = state.rank_numels[rank] + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + scattered_us[id(p)].copy_(flat_local) + + inner_off += n + + assert inner_off == block + off += block + + +def _update_params( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + scattered_us: dict[int, torch.Tensor], + lr: float, + weight_decay: float, +) -> None: + """Apply weight decay, Muon update, and optional QK clipping.""" + for p in params: + state = param_to_state[id(p)] + u_dtensor = DTensor.from_local( + scattered_us[id(p)], + placements=p.placements, + device_mesh=p.device_mesh, + ) + + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + scales_full = compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = state.rank_indices[rank][0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + +# ====================================================================== +# Main generator – thin orchestrator that wires stages together. 
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/qk_clip.py b/build/torch29-cxx11-cu130-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 926869eca5ee9c6a8f6899f3966ba361bc640faa..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c1574fefc74653a663d8c4c53dda381d92c60cdc29358f15618b1b746dc4ae4e -size 1865112 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..75edac5d5f08066ba4f74df9fb2c6b740d65e613 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40563a27767823176595fede23009b17b26e6b2c6a5847e255448d51da70b854 +size 1865112 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/adamw.py b/build/torch29-cxx11-rocm63-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not 
params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/async_utils.py b/build/torch29-cxx11-rocm63-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/core.py b/build/torch29-cxx11-rocm63-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. 
+ + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
+    """
+    A, B = param_shape[:2]
+    # We adjust the learning rate based on the size of the parameter matrix,
+    # as described in the paper. Weight decay itself is left unscaled.
+    adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+    adjusted_lr = lr * adjusted_ratio
+    return adjusted_lr
+
+
+def default_is_muon(name, x, expert_keys=None):
+    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
+    if any(key in name for key in skip_keys):
+        return False
+    effective_ndim = x.ndim
+    if expert_keys and any(key in name for key in expert_keys):
+        effective_ndim -= 1
+    return effective_ndim >= 2
+
+
+def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
+    if is_muon_func is None:
+        is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
+
+    muon_params, muon_names = [], []
+    non_muon_params = []
+
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+        if is_muon_func(n, p):
+            muon_params.append(p)
+            muon_names.append(n)
+        else:
+            non_muon_params.append(p)
+
+    return [
+        {
+            "params": muon_params,
+            "names": muon_names,
+            "use_muon": True,
+        },
+        {
+            "params": non_muon_params,
+            "use_muon": False,
+        },
+    ]
diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py
index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644
--- a/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py
+++ b/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py
@@ -7,22 +7,40 @@
 from torch.distributed.tensor.placement_types import (Placement, Shard,
                                                       _StridedShard)
 
 
+def _is_shard(placement: Placement) -> bool:
+    """Check if a placement is a shard type (Shard or _StridedShard).
+
+    In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so
+    ``placement.is_shard()`` returns False for _StridedShard. This helper
+    handles both old and new hierarchies.
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
-    placements_with_index: list[tuple[int,
-                                      Placement]] = list(enumerate(placements))
-    placements_with_index = sorted(placements_with_index,
-                                   key=placement_sort_key)
+    Example — 8 GPUs, mesh shape (2, 2, 2),
+    placements ``[Shard(0), Replicate, _StridedShard(0)]``::
 
-    sorted_indices, sorted_placements = zip(*placements_with_index)
+        Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)]
+                 Permutation: [1, 2, 0]
 
-    # 2. Permute mesh according to sorted placements.
-    sorted_mesh = mesh.permute(sorted_indices)
+        Step 2 — Permute mesh dims by [1, 2, 0]:
+            Original:            Permuted:
+            [[[0,1],[2,3]],      [[[0,4],[1,5]],
+             [[4,5],[6,7]]]       [[2,6],[3,7]]]
 
-    # 3. Collect list of shard meshes by removing replicate dims
-    # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
-    # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
-    num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
+        Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
+            sub-mesh 0 = [[0,4],[1,5]]   (replica group 0)
+            sub-mesh 1 = [[2,6],[3,7]]   (replica group 1)
+            shard_placements = (_StridedShard(0), Shard(0))
 
-    # merge replicate dims
-    # shard_meshes became a list of shard meshes with a length of replicate degree
-    if num_replicates > 0:
-        sorted_mesh = sorted_mesh.flatten(
-            0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
+        Step 4 — Rank 0 → ProcessGroup([0,1,4,5])
+                 Rank 2 → ProcessGroup([2,3,6,7])
+
+    Returns:
+        ``(shard_mesh, process_group, shard_placements)``
+    """
+    my_rank = dist.get_rank()
+    assert mesh.mesh.device.type == 'cpu'
+
+    # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
+    # This avoids a non-collective dist.new_group() call, which would
+    # deadlock when only a subset of ranks call this function (e.g. expert
+    # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py b/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json b/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json +++ b/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/muon.py b/build/torch29-cxx11-rocm63-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/muon.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline 
+from .core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
             (6 is probably always enough)
         weight_decay: The weight decay for Muon and AdamW.
-            {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
+            Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead.
         adamw_lr: The learning rate for the internal AdamW.
         adamw_betas: The betas for the internal AdamW.
         adamw_eps: The epsilon for the internal AdamW.
@@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer):
             - "q_indices" (list[int]): Indices of query heads to consider.
             - "k_indices" (list[int]): Indices of key heads to consider.
             - "head_dim" (int): Dimensionality of each attention head.
-            - "threshold" (float): Threshold value; heads whose QK logits exceed 
+            - "threshold" (float): Threshold value; heads whose QK logits exceed
                 this value will be scaled down.
             Default is:
                 {
@@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer):
         use_distributed_muon: Use distributed muon by Liu et al. (2024).
             For testing purposes only.
         small_param_numel_threshold: Threshold for classifying parameters as
             small and falling back to distributed Muon.
+        expert_keys: List of strings to identify expert-parallel parameters.
+            If any key appears in a parameter's name, its outermost
+            dimension is treated as the expert dimension and expanded
+            into per-expert 2D params for Muon. For example,
+            ``expert_keys=["experts"]`` matches any param whose name
+            contains "experts". 3D+ params not matched by any key
+            will raise an error.
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-rocm63-x86_64-linux/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/pipeline.py b/build/torch29-cxx11-rocm63-x86_64-linux/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. 
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_gather``).
+        gathered_grads: ``{id(p): empty_tensor}`` for owned params,
+            ``None`` for non-owned.
+        recv_counts: Per-source-rank element counts.
+    """
+    # Allocate gathered-grad buffers
+    gathered_grads: dict[int, torch.Tensor | None] = {}
+    for p in params:
+        state = param_to_state[id(p)]
+        if rank == state.worker_rank:
+            gathered_grads[id(p)] = torch.empty(p.shape,
+                                                dtype=COMM_DTYPE,
+                                                device="cuda")
+        else:
+            gathered_grads[id(p)] = None
+
+    # Build send buffer
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    for p in params:
+        state = param_to_state[id(p)]
+        dst = state.worker_rank
+        assert dst < num_ranks
+        shard_elems = state.rank_numels[rank]
+        g = p.grad
+        g = g.to_local().to(COMM_DTYPE).contiguous()
+        assert g.numel() == shard_elems
+        per_dst[dst].append(g.view(-1))
+        send_counts[dst] += shard_elems
+
+    assert any(
+        len(v) > 0 for v in
+        per_dst), "At least one destination rank must receive a sharded tensor"
+    per_dst_flat = [t for dst in per_dst for t in dst]
+    send_buf = torch.cat(per_dst_flat, dim=0)
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+            total += state.rank_numels[src]
+        recv_counts[src] = total
+
+    recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    logger.debug(f"send_buf size: {send_buf.numel()}, "
+                 f"recv_buf size: {recv_buf.numel()}, "
+                 f"recv_counts: {recv_counts}, "
+                 f"send_counts: {send_counts}, "
+                 f"process_group: {str(process_group)}")
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, gathered_grads, recv_counts
+
+
+def _complete_gather(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+) -> None:
+    """Reconstruct gathered grads from the recv buffer (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        if recv_counts[src] == 0:
+            continue
+
+        block = recv_counts[src]
+        inner_off = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+
+            indices = state.rank_indices[src]
+
+            shard_view = gathered_grads[id(p)][indices]
+            n = shard_view.numel()
+            assert n > 0
+
+            sg = recv_buf.narrow(0, off + inner_off, n)
+            sg = sg.reshape(shard_view.shape)
+            gathered_grads[id(p)][indices] = sg
+
+            inner_off += n
+        assert inner_off == block
+        off += block
+
+
+def _compute_ns(
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    ns_steps: int,
+) -> dict[int, torch.Tensor | None]:
+    """Run Newton-Schulz orthogonalization on owned parameters.
+
+    Returns:
+        computed_us: ``{id(p): orthogonalized_update}`` for owned params.
+    """
+    computed_us: dict[int, torch.Tensor | None] = {}
+    for p in owned_params:
+        u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
+        gathered_grads[id(p)] = None  # free gathered grad
+        computed_us[id(p)] = u
+    return computed_us
+
+
+def _launch_scatter(
+    params: list[DTensor],
+    owned_params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    num_ranks: int,
+    process_group: dist.ProcessGroup,
+    computed_us: dict[int, torch.Tensor | None],
+) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
+    """Allocate scatter buffers, build send/recv, and launch async all-to-all.
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
+        scattered_us: ``{id(p): empty_local_tensor}`` for all params.
+        recv_counts: Per-source-rank element counts.
+    """
+    # Allocate scattered-u buffers
+    scattered_us: dict[int, torch.Tensor] = {}
+    for p in params:
+        scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
+
+    # Build send buffer (from computed_us on owner ranks)
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    if owned_params:
+        for p in owned_params:
+            state = param_to_state[id(p)]
+
+            assert computed_us[id(p)] is not None
+            u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
+
+            total_sent = 0
+            for dst_rank in range(num_ranks):
+                indices = state.rank_indices[dst_rank]
+                su = u_full[indices].flatten()
+
+                n = su.numel()
+                assert n > 0
+
+                per_dst[dst_rank].append(su)
+                send_counts[dst_rank] += n
+                total_sent += n
+
+            assert total_sent == u_full.numel()
+
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensors"
+        per_dst_flat = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst_flat, dim=0)
+    else:
+        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            total += state.rank_numels[rank]
+        recv_counts[src] = total
+
+    recv_total = sum(recv_counts)
+    assert recv_total > 0
+    recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, scattered_us, recv_counts
+
+
+def _complete_scatter(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+) -> None:
+    """Copy recv buffer into scattered_us (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        block = recv_counts[src]
+        if block == 0:
+            continue
+
+        inner_off = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            n = state.rank_numels[rank]
+            assert n > 0
+
+            flat_local = recv_buf.narrow(0, off + inner_off,
+                                         n).view_as(p.to_local())
+            scattered_us[id(p)].copy_(flat_local)
+
+            inner_off += n
+
+        assert inner_off == block
+        off += block
+
+
+def _update_params(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+    lr: float,
+    weight_decay: float,
+) -> None:
+    """Apply weight decay, Muon update, and optional QK clipping."""
+    for p in params:
+        state = param_to_state[id(p)]
+        u_dtensor = DTensor.from_local(
+            scattered_us[id(p)],
+            placements=p.placements,
+            device_mesh=p.device_mesh,
+        )
+
+        adjusted_lr = adjust_lr_for_muon(lr, p.shape)
+        update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
+
+        # QK clipping – applied directly on the local tensor to
+        # avoid DTensor sharding-propagation issues with _StridedShard.
+        scales_full = compute_scales(
+            p,
+            state.qk_clip_state) if state.qk_clip_state is not None else None
+        if scales_full is not None:
+            ratio = p.shape[0] // scales_full.shape[0]
+            idx0 = state.rank_indices[rank][0]
+            if isinstance(idx0, slice):
+                start = idx0.start or 0
+                idx0 = torch.arange(start,
+                                    idx0.stop,
+                                    device=scales_full.device)
+            row_scales = scales_full[idx0 // ratio]
+            p._local_tensor.mul_(row_scales.view(-1, 1))
+
+
+# ======================================================================
+# Main generator – thin orchestrator that wires stages together.
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/qk_clip.py b/build/torch29-cxx11-rocm63-x86_64-linux/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+ + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +def get_qk_clip_info(clip_config, n, qk_logits): + """Extract QK clipping info for a named parameter. + + Args: + clip_config: QK clipping configuration dict (or None). + n: Parameter name string. + qk_logits: Dict mapping layer indices to logit tensors (or None). + + Returns: + QKClipInfo instance with clipping configuration for this parameter. + """ + if clip_config is None: + return None + + head_dim = clip_config.get('head_dim') + threshold = clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + +def compute_scales(p, qk_clip_state): + """Compute per-head scaling factors for QK clipping. 
+ + Returns scales tensor if any head exceeds threshold, else None. + """ + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + +def qk_clip(p, scales, head_dim): + """Apply per-head scaling to a Q/K projection weight matrix.""" + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py index e6f6fcf6280e969b1761926112147d3146e27b59..b34ab4955d83942fd070363fe79547a36deb1742 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_06a260a_dirty -ops = torch.ops._optimizer_06a260a_dirty +from . import _optimizer_7aef62f_dirty +ops = torch.ops._optimizer_7aef62f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
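The arithmetic in `compute_scales` and `qk_clip` reduces to a per-head rescale of the projection rows. The head count, `head_dim`, and logit values below are toy numbers chosen for illustration:

```python
import math
import torch

# Toy setup: 2 heads, head_dim 4, input dim 3.
head_dim, threshold = 4, 1.0
logits = [2.0, 0.5]              # max attention logit observed per head
W = torch.ones(2 * head_dim, 3)  # stand-in Q/K projection weight

# Per-head scale: sqrt(threshold / logit) for heads over the threshold.
scales = torch.ones(len(logits))
for h, v in enumerate(logits):
    if v > threshold:
        scales[h] = math.sqrt(threshold / v)

# Apply per head, as qk_clip does: view rows as (heads, head_dim, in_dim).
W.view(-1, head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
print(W[0, 0].item())         # head 0 scaled to sqrt(0.5), about 0.7071
print(W[head_dim, 0].item())  # head 1 untouched -> 1.0
```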
""" - return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so deleted file mode 100755 index 95d54a0288c1e9cea520f5e3042a163cb9222346..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6ad69fa088ef05b1697f74d59c1a5a12f17dbf2a3cddb8c6b92ed7543b4cbdbc -size 1865232 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..db1924b3f25a792a5aa5de6db2005cac974da79a --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae0556a81551f05fff0b83b1924c55a70e399c29171f9f7ce1bd63ccb24fc417 +size 1865232 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/adamw.py b/build/torch29-cxx11-rocm64-x86_64-linux/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not 
params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. 
+ """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
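The grouping step can be mimicked with plain dictionaries; the placement keys below are illustrative strings standing in for real `(placements, device_mesh)` pairs:

```python
from collections import defaultdict

# Bucket "parameters" by a placement-like key so each bucket can go
# through one fused kernel call, as step_adamw does for DTensors.
params = [
    {"name": "w1", "placement": ("Shard(0)", "mesh0")},
    {"name": "w2", "placement": ("Shard(0)", "mesh0")},
    {"name": "b1", "placement": ("Replicate", "mesh0")},
]
buckets: dict[tuple, list] = defaultdict(list)
for p in params:
    buckets[p["placement"]].append(p["name"])
print(dict(buckets))
```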
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/async_utils.py b/build/torch29-cxx11-rocm64-x86_64-linux/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
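A minimal sketch of this interleaving, with toy two-yield chunk generators and a simplified scheduler standing in for `run_pipeline` and `_Task` (all names and the log are illustrative):

```python
log = []

def chunk(i):
    log.append(f"gather{i}")   # stage 1: launch async gather
    yield
    log.append(f"compute{i}")  # wait, then Newton-Schulz
    log.append(f"scatter{i}")  # stage 2: launch async scatter
    yield
    log.append(f"update{i}")   # wait, then parameter update

def run(pipelines, max_concurrent=2):
    # Simplified scheduler: admit one new chunk per iteration, then
    # advance every in-flight chunk by one yield (chunks always yield
    # at least once, so the admit step never exhausts a chunk).
    have_new, tasks = True, []
    gens = iter(pipelines)
    while have_new or tasks:
        running = []
        if have_new and len(tasks) < max_concurrent:
            try:
                g = next(gens)
                next(g)          # run to first yield (like _Task.__init__)
                running.append(g)
            except StopIteration:
                have_new = False
        for g in tasks:
            try:
                next(g)
                running.append(g)
            except StopIteration:
                pass
        tasks = running

run([chunk(0), chunk(1)])
# 'gather1' lands before 'compute0': the next chunk's communication is
# already in flight while the current chunk computes.
print(log)
```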
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/core.py b/build/torch29-cxx11-rocm64-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. 
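Traced with scalars (values are illustrative), the recurrence implemented by `update_g` is `buf <- g + momentum * buf`, with the Nesterov branch returning `g + momentum * buf`:

```python
# Scalar walk-through of update_g's recurrence:
#   buf <- g + momentum * buf     (torch.add(g, buf, alpha=momentum, out=buf))
#   Nesterov: return g + momentum * buf; otherwise return buf.
momentum, nesterov = 0.9, True
buf = 0.0
updates = []
for g in [1.0, 1.0, 1.0]:
    buf = g + momentum * buf
    updates.append(g + momentum * buf if nesterov else buf)
print(updates)  # approximately [1.9, 2.71, 3.439]
```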
+ + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + p: Parameter tensor. + g: Gradient tensor. + group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
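For a hypothetical 4096 x 1024 weight, the shape-based adjustment works out as follows (the formula matches `adjust_lr_for_muon`; the shapes and base lr are made up):

```python
import math

# adjusted_lr = 0.2 * sqrt(max(A, B)) * lr
def adjust_lr_for_muon(lr: float, param_shape) -> float:
    A, B = param_shape[:2]
    return lr * 0.2 * math.sqrt(max(A, B))

adjusted = adjust_lr_for_muon(0.02, (4096, 1024))
print(adjusted)  # 0.02 * 0.2 * 64, about 0.256
```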
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
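The selection rule in `default_is_muon` can be exercised directly; in this sketch `ndim` replaces the tensor argument for brevity (illustrative names):

```python
def default_is_muon(name: str, ndim: int, expert_keys=None) -> bool:
    # Mirror of the grouping rule above, with ndim passed directly.
    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
    if any(key in name for key in skip_keys):
        return False
    effective_ndim = ndim
    if expert_keys and any(key in name for key in expert_keys):
        effective_ndim -= 1  # expert weights carry a leading expert dim
    return effective_ndim >= 2

print(default_is_muon("model.0.attn.wq.weight", 2))           # True
print(default_is_muon("model.embed_tokens.weight", 2))        # False
print(default_is_muon("model.0.experts.w1", 3, ["experts"]))  # True
print(default_is_muon("model.0.attn.norm.weight", 1))         # False
```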
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
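The per-dim composition (`dim_indices[shard_dim][new_indices]`) applies each level of sharding to the index set produced by the previous level. A two-level toy example (values illustrative):

```python
import torch

# A tensor dim of size 8, sharded twice on the same dim.
# Outer Shard over 2 chunks: mesh coordinate 1 owns global rows 4..7.
outer = torch.arange(4, 8)

# Inner sharding over 2 chunks of those 4 rows: coordinate 0 owns the
# first two local positions (a strided placement would produce a
# non-contiguous position list here instead).
inner = torch.arange(0, 2)

composed = outer[inner]  # global rows this rank holds
print(composed.tolist())  # [4, 5]
```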
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
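The result-building loop above converts contiguous index runs back into slices; the check reduces to a comparison against an `arange` (values illustrative):

```python
import torch

# A contiguous run of indices can be represented as a cheap slice; only a
# genuinely strided index set needs advanced (gather-style) indexing.
indices = torch.tensor([2, 3, 4, 5], dtype=torch.long)
start = int(indices[0])
expected = torch.arange(start, start + len(indices), dtype=torch.long)
result = (slice(start, start + len(indices))
          if torch.equal(indices, expected) else indices)
print(result)  # slice(2, 6, None)

# A strided set like [0, 2, 4, 6] fails the arange check and stays a tensor.
strided = torch.tensor([0, 2, 4, 6], dtype=torch.long)
exp2 = torch.arange(0, len(strided), dtype=torch.long)
keeps_indices = not torch.equal(strided, exp2)
```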
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
-    placements_with_index: list[tuple[int,
-                                      Placement]] = list(enumerate(placements))
+    Example — 8 GPUs, mesh shape (2, 2, 2),
+    placements ``[Shard(0), Replicate, _StridedShard(0)]``::
-    placements_with_index = sorted(placements_with_index,
-                                   key=placement_sort_key)
+        Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)]
+        Permutation: [1, 2, 0]
-    sorted_indices, sorted_placements = zip(*placements_with_index)
+        Step 2 — Permute mesh dims by [1, 2, 0]:
+            Original:           Permuted:
+            [[[0,1],[2,3]],     [[[0,4],[1,5]],
+             [[4,5],[6,7]]]      [[2,6],[3,7]]]
-    # 2. Permute mesh according to sorted placements.
-    sorted_mesh = mesh.permute(sorted_indices)
+        Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes:
+            sub-mesh 0 = [[0,4],[1,5]]   (replica group 0)
+            sub-mesh 1 = [[2,6],[3,7]]   (replica group 1)
+            shard_placements = (_StridedShard(0), Shard(0))
-    # 3. Collect list of shard meshes by removing replicate dims
-    # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
-    # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
-    num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
+        Step 4 — Rank 0 → ProcessGroup([0,1,4,5])
+                 Rank 2 → ProcessGroup([2,3,6,7])
-    # merge replicate dims
-    # shard_meshes became a list of shard meshes with a length of replicate degree
-    if num_replicates > 0:
-        sorted_mesh = sorted_mesh.flatten(
-            0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
+
+    Returns:
+        ``(shard_mesh, process_group, shard_placements)``
+    """
+    my_rank = dist.get_rank()
+    assert mesh.mesh.device.type == 'cpu'
+
+    # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
+    # This avoids a non-collective dist.new_group() call, which would
+    # deadlock when only a subset of ranks call this function (e.g. expert
+    # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
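The permute/unbind steps can be checked numerically with a plain tensor standing in for the rank mesh (the mesh shape and placements match the docstring's worked example):

```python
import torch

# 2x2x2 rank mesh with placements [Shard(0), Replicate, _StridedShard(0)].
mesh = torch.arange(8).reshape(2, 2, 2)

# Sorted placement order [Replicate, _StridedShard(0), Shard(0)]
# corresponds to mesh-dim permutation (1, 2, 0).
permuted = mesh.permute(1, 2, 0)

# Unbind the leading (replicate) dim: one sub-mesh per replica group.
sub_meshes = torch.unbind(permuted, dim=0)

# Each sub-mesh's ranks form one all-to-all process group.
groups = [sorted(sm.flatten().tolist()) for sm in sub_meshes]
print(groups)  # [[0, 1, 4, 5], [2, 3, 6, 7]]
```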
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py b/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py +++ 
b/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/metadata.json b/build/torch29-cxx11-rocm64-x86_64-linux/metadata.json index 76bafa5f33b6818aa6bb4cab04be811b87519b44..c55a35717622f1dd5c8ba376ea3a814cbcc10d78 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/metadata.json +++ b/build/torch29-cxx11-rocm64-x86_64-linux/metadata.json @@ -1 +1,3 @@ -{"python-depends":[]} \ No newline at end of file +{ + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/muon.py b/build/torch29-cxx11-rocm64-x86_64-linux/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/muon.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline 
+from .core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] - name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = 
torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, 
- input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. - slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. 
- """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. 
- slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 
0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. 
""" - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. 
If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
+        with record_function("muon::momentum"):
+            for n, p in zip(names, params):
+                g = p.grad
+                if g is None:
+                    continue
+                g = update_g(self.state, p, g, group, momentum)
+                p.grad = g
+
+        # Expand expert params by splitting on dim 0.
+        names, params = _expand_expert_params(names, params, self.expert_keys)
+
         param_dtensors = []
         name_dtensors = []
@@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer):
                 group=group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits)
             return
@@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer):
         # and run parallel Muon on each group.
         placement_to_params = defaultdict(lambda: ([], []))
-        # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]

         assert len(dtensors) == len(names)
         for p, n in zip(dtensors, names):
@@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer):
                 group=group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits,
             )
@@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer):
                 group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits,
             )
@@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer):
                 group,
                 lr=lr,
                 weight_decay=weight_decay,
-                momentum=momentum,
                 qk_logits=qk_logits,
             )

-    def _step_adamw_params(self, params, group):
-        params_with_grads = []
-        grads = []
-        moment1 = []
-        moment2 = []
-        max_exp_avg_sqs = []
-        state_steps = []
-        lr = group["lr"]
-        beta1, beta2 = group["adamw_betas"]
-        eps = group["adamw_eps"]
-        weight_decay = group["weight_decay"]
-
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-            state = self.state[p]
-            params_with_grads.append(p)
-            grads.append(g)
-            if "step" not in state:
-                state["step"] = (torch.zeros((),
-                                             dtype=torch.float32,
-                                             device=p.device))
-                state["moment1"] = torch.zeros_like(g)
-                state["moment2"] = torch.zeros_like(g)
-            moment1.append(state["moment1"])
-            moment2.append(state["moment2"])
-            if not isinstance(state["step"], torch.Tensor):
-                step_tensor = torch.tensor(state["step"],
-                                           dtype=torch.float32,
-                                           device=p.device)
-            else:
-                step_tensor = state["step"]
-            state_steps.append(step_tensor)
-
-        self._fused_adamw(
-            params_with_grads,
-            grads,
-            moment1,
-            moment2,
-            max_exp_avg_sqs,
-            state_steps,
-            amsgrad=False,
-            beta1=beta1,
-            beta2=beta2,
-            lr=lr,
-            weight_decay=weight_decay,
-            eps=eps,
-            maximize=False,
-        )
-
-    def _step_adamw(self, group):
-        params = group["params"]
-
-        # group params with it's type and placement
-        placement_to_params: dict[tuple[Placement | type,
-                                        DeviceMesh | None]] = defaultdict(list)
-        for p in params:
-            match p:
-                case DTensor():
-                    placement_to_params[tuple([p.placements,
-                                               p.device_mesh])].append(p)
-                case torch.Tensor():
-                    placement_to_params[tuple([torch.Tensor, None])].append(p)
-
-        for params in placement_to_params.values():
-            self._step_adamw_params(params, group)
-
     @torch.no_grad
     def step(self, closure=None, qk_logits=None):
         """Perform a single optimization step.
@@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer):
         Args:
             closure (Callable, optional): A closure that reevaluates the model
                 and returns the loss.
-            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
-                to 1D tensors of shape (num_heads,), representing the maximum
-                QK logits across all tokens, computed as
+            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
+                to 1D tensors of shape (num_heads,), representing the maximum
+                QK logits across all tokens, computed as
                 (1 / sqrt(head_dim)) * (Q @ K^T).
         """
         loss = None
@@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer):
             if group["use_muon"]:
                 self._step_muon(group, qk_logits=qk_logits)
             else:
-                self._step_adamw(group)
+                step_adamw(self.state, group)

         return loss
diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-rocm64-x86_64-linux/newton_schulz.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63
--- /dev/null
+++ b/build/torch29-cxx11-rocm64-x86_64-linux/newton_schulz.py
@@ -0,0 +1,50 @@
+import torch
+
+from .matmul_transpose_triton import matmul_transpose_assign
+
+COMM_DTYPE = torch.bfloat16
+DEFAULT_CHUNK_SIZE_RATIO = 4
+
+
+# This code snippet is a modified version adapted from the following GitHub repositories:
+# https://github.com/KellerJordan/Muon/blob/master/muon.py
+# Muon's Newton-Schulz iteration causes high variance in singular values.
+# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
+# matmul_transpose_assign from: https://github.com/nil0x9/flash-muon
+@torch.no_grad()
+def _zeropower_via_newtonschulz5(G, steps):
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert len(G.shape) == 2
+    assert G.dtype == COMM_DTYPE
+    X = G  # no manual typecast
+
+    if G.size(0) > G.size(1):
+        X = X.T
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm() + 1e-7)
+    buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
+    buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
+    # Perform the NS iterations
+    for a, b, c in [
+        (4.0848, -6.8946, 2.9270),
+        (3.9505, -6.3029, 2.6377),
+        (3.7418, -5.5913, 2.3037),
+        (2.8769, -3.1427, 1.2046),
+        (2.8366, -3.0525, 1.2012),
+    ]:
+        matmul_transpose_assign(X, buf1)
+        matmul_transpose_assign(buf1, buf2)
+        buf1.mul_(b).add_(buf2, alpha=c)
+        X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
+
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X
diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/pipeline.py b/build/torch29-cxx11-rocm64-x86_64-linux/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6
--- /dev/null
+++ b/build/torch29-cxx11-rocm64-x86_64-linux/pipeline.py
@@ -0,0 +1,390 @@
+import logging
+from typing import Generator
+
+import torch
+import torch.distributed as dist
+from torch.distributed.tensor import DTensor
+from torch.profiler import record_function
+
+from .core import _muon_state, adjust_lr_for_muon, update_p
+from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5
+from .qk_clip import compute_scales
+
+logger = logging.getLogger(__name__)
+
+# ======================================================================
+# Stage helpers
+# ======================================================================
+
+
+def _launch_gather(
+    params: list[DTensor],
+    owned_params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    num_ranks: int,
+    process_group: dist.ProcessGroup,
+) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
+    """Allocate gather buffers, build send/recv, and launch async all-to-all.
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_gather``).
+        gathered_grads: ``{id(p): empty_tensor}`` for owned params,
+            ``None`` for non-owned.
+        recv_counts: Per-source-rank element counts.
+    """
+    # Allocate gathered-grad buffers
+    gathered_grads: dict[int, torch.Tensor | None] = {}
+    for p in params:
+        state = param_to_state[id(p)]
+        if rank == state.worker_rank:
+            gathered_grads[id(p)] = torch.empty(p.shape,
+                                                dtype=COMM_DTYPE,
+                                                device="cuda")
+        else:
+            gathered_grads[id(p)] = None
+
+    # Build send buffer
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    for p in params:
+        state = param_to_state[id(p)]
+        dst = state.worker_rank
+        assert dst < num_ranks
+        shard_elems = state.rank_numels[rank]
+        g = p.grad
+        g = g.to_local().to(COMM_DTYPE).contiguous()
+        assert g.numel() == shard_elems
+        per_dst[dst].append(g.view(-1))
+        send_counts[dst] += shard_elems
+
+    assert any(
+        len(v) > 0 for v in
+        per_dst), "At least one destination rank must receive a sharded tensor"
+    per_dst_flat = [t for dst in per_dst for t in dst]
+    send_buf = torch.cat(per_dst_flat, dim=0)
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+            total += state.rank_numels[src]
+        recv_counts[src] = total
+
+    recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    logger.debug(f"send_buf size: {send_buf.numel()}, "
+                 f"recv_buf size: {recv_buf.numel()}, "
+                 f"recv_counts: {recv_counts}, "
+                 f"send_counts: {send_counts}, "
+                 f"process_group: {str(process_group)}")
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, gathered_grads, recv_counts
+
+
+def _complete_gather(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+) -> None:
+    """Reconstruct gathered grads from the recv buffer (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        if recv_counts[src] == 0:
+            continue
+
+        block = recv_counts[src]
+        inner_off = 0
+        for p in owned_params:
+            state = param_to_state[id(p)]
+            assert state.worker_rank == rank
+
+            indices = state.rank_indices[src]
+
+            shard_view = gathered_grads[id(p)][indices]
+            n = shard_view.numel()
+            assert n > 0
+
+            sg = recv_buf.narrow(0, off + inner_off, n)
+            sg = sg.reshape(shard_view.shape)
+            gathered_grads[id(p)][indices] = sg
+
+            inner_off += n
+        assert inner_off == block
+        off += block
+
+
+def _compute_ns(
+    owned_params: list[DTensor],
+    gathered_grads: dict[int, torch.Tensor | None],
+    ns_steps: int,
+) -> dict[int, torch.Tensor | None]:
+    """Run Newton-Schulz orthogonalization on owned parameters.
+
+    Returns:
+        computed_us: ``{id(p): orthogonalized_update}`` for owned params.
+    """
+    computed_us: dict[int, torch.Tensor | None] = {}
+    for p in owned_params:
+        u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
+        gathered_grads[id(p)] = None  # free gathered grad
+        computed_us[id(p)] = u
+    return computed_us
+
+
+def _launch_scatter(
+    params: list[DTensor],
+    owned_params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    num_ranks: int,
+    process_group: dist.ProcessGroup,
+    computed_us: dict[int, torch.Tensor | None],
+) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
+    """Allocate scatter buffers, build send/recv, and launch async all-to-all.
+
+    Returns:
+        work: Async operation handle.
+        recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
+        scattered_us: ``{id(p): empty_local_tensor}`` for all params.
+        recv_counts: Per-source-rank element counts.
+    """
+    # Allocate scattered-u buffers
+    scattered_us: dict[int, torch.Tensor] = {}
+    for p in params:
+        scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
+
+    # Build send buffer (from computed_us on owner ranks)
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    if owned_params:
+        for p in owned_params:
+            state = param_to_state[id(p)]
+
+            assert computed_us[id(p)] is not None
+            u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
+
+            total_sent = 0
+            for dst_rank in range(num_ranks):
+                indices = state.rank_indices[dst_rank]
+                su = u_full[indices].flatten()
+
+                n = su.numel()
+                assert n > 0
+
+                per_dst[dst_rank].append(su)
+                send_counts[dst_rank] += n
+                total_sent += n
+
+            assert total_sent == u_full.numel()
+
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensors"
+        per_dst_flat = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst_flat, dim=0)
+    else:
+        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            total += state.rank_numels[rank]
+        recv_counts[src] = total
+
+    recv_total = sum(recv_counts)
+    assert recv_total > 0
+    recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, scattered_us, recv_counts
+
+
+def _complete_scatter(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+) -> None:
+    """Copy recv buffer into scattered_us (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        block = recv_counts[src]
+        if block == 0:
+            continue
+
+        inner_off = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            n = state.rank_numels[rank]
+            assert n > 0
+
+            flat_local = recv_buf.narrow(0, off + inner_off,
+                                         n).view_as(p.to_local())
+            scattered_us[id(p)].copy_(flat_local)
+
+            inner_off += n

+        assert inner_off == block
+        off += block
+
+
+def _update_params(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+    lr: float,
+    weight_decay: float,
+) -> None:
+    """Apply weight decay, Muon update, and optional QK clipping."""
+    for p in params:
+        state = param_to_state[id(p)]
+        u_dtensor = DTensor.from_local(
+            scattered_us[id(p)],
+            placements=p.placements,
+            device_mesh=p.device_mesh,
+        )
+
+        adjusted_lr = adjust_lr_for_muon(lr, p.shape)
+        update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
+
+        # QK clipping - applied directly on the local tensor to
+        # avoid DTensor sharding-propagation issues with _StridedShard.
+        scales_full = compute_scales(
+            p,
+            state.qk_clip_state) if state.qk_clip_state is not None else None
+        if scales_full is not None:
+            ratio = p.shape[0] // scales_full.shape[0]
+            idx0 = state.rank_indices[rank][0]
+            if isinstance(idx0, slice):
+                start = idx0.start or 0
+                idx0 = torch.arange(start,
+                                    idx0.stop,
+                                    device=scales_full.device)
+            row_scales = scales_full[idx0 // ratio]
+            p._local_tensor.mul_(row_scales.view(-1, 1))
+
+
+# ======================================================================
+# Main generator - thin orchestrator that wires stages together.
+# ======================================================================
+
+
+@torch.no_grad()
+def muon_chunk_pipeline(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    ns_steps: int,
+    lr: float,
+    weight_decay: float,
+    none_grad: bool,
+) -> Generator[None, None, None]:
+    """Process one chunk of parameters through the full Muon pipeline.
+
+    Stages: gather -> compute (Newton-Schulz) -> scatter -> update.
+
+    Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
+    that communication and computation overlap across chunks. Async
+    communication is launched via ``async_op=True`` and completed after
+    the yield with ``work.wait()``.
+
+    Overlap happens because :func:`run_pipeline` admits one new chunk
+    per iteration (staggered admission). While chunk *N* does NS
+    compute on the default CUDA stream, chunk *N+1*'s async all-to-all
+    runs concurrently on the NCCL stream - no separate ``comm_stream``
+    is required.
+
+    Yields exactly **2** times:
+
+    1. After launching async all-to-all gather.
+    2. After launching async all-to-all scatter.
+    """
+    process_group = param_to_state[id(params[0])].process_group
+    num_ranks = dist.get_world_size(group=process_group)
+    owned_params = [
+        p for p in params if param_to_state[id(p)].worker_rank == rank
+    ]
+
+    # Stages 1-2: launch async gather.
+    with record_function("muon::launch_gather"):
+        work, recv_buf, gathered_grads, recv_counts = _launch_gather(
+            params, owned_params, param_to_state, rank, num_ranks,
+            process_group)
+
+    if none_grad:
+        for p in params:
+            p.grad = None
+
+    yield  # --- YIELD 1: other chunks can launch their gather ---
+
+    with record_function("muon::wait_gather"):
+        work.wait()
+        _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
+                         param_to_state, rank)
+        del recv_buf
+
+    # Stage 3: Newton-Schulz orthogonalization.
+    with record_function("muon::newton_schulz"):
+        computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
+        gathered_grads.clear()
+
+    # Stages 4-5: launch async scatter.
+    with record_function("muon::launch_scatter"):
+        work, recv_buf, scattered_us, recv_counts = _launch_scatter(
+            params, owned_params, param_to_state, rank, num_ranks,
+            process_group, computed_us)
+        computed_us.clear()
+
+    yield  # --- YIELD 2: other chunks can launch their scatter ---
+
+    with record_function("muon::wait_scatter"):
+        work.wait()
+        _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
+                          scattered_us)
+        del recv_buf
+
+    # Stage 6: apply parameter updates.
+    with record_function("muon::update_params"):
+        _update_params(params, param_to_state, rank, scattered_us, lr,
+                       weight_decay)
+        scattered_us.clear()
diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/qk_clip.py b/build/torch29-cxx11-rocm64-x86_64-linux/qk_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7
--- /dev/null
+++ b/build/torch29-cxx11-rocm64-x86_64-linux/qk_clip.py
@@ -0,0 +1,129 @@
+import logging
+import math
+from dataclasses import dataclass
+
+import torch
+from torch.distributed.tensor import DTensor
+
+logger = logging.getLogger(__name__)
+
+
+def parse_qk_layer(name: str) -> tuple[str | None, int]:
+    """
+    Parse a parameter name to check if it is a query/key projection layer
+    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+
+    Returns:
+        (kind, layer_idx) or (None, -1) if not matched.
+
+    Example:
+        'model.3.attn.wq.weight' -> ('wq', 3)
+        'model.5.attn.wk.weight' -> ('wk', 5)
+        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
+        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.4.attn.v_proj.weight' -> (None, -1)
+    """
+    parts = name.split('.')
+    if len(parts) < 3:
+        return None, -1
+
+    kind = parts[-2]
+
+    layer_idx = -1
+    for part in reversed(parts):
+        if part.isdigit():
+            layer_idx = int(part)
+            break
+
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+        return kind, layer_idx
+
+    return None, -1
+
+
+@dataclass
+class QKClipInfo:
+    """Per-parameter dynamic info computed from config + runtime logits."""
+    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    indices: list[int]  # which heads to consider for clipping
+    head_dim: int  # from config
+    threshold: float  # from config
+    logit: torch.Tensor | None
+
+
+def get_qk_clip_info(clip_config, n, qk_logits):
+    """Extract QK clipping info for a named parameter.
+
+    Args:
+        clip_config: QK clipping configuration dict (or None).
+        n: Parameter name string.
+        qk_logits: Dict mapping layer indices to logit tensors (or None).
+
+    Returns:
+        QKClipInfo instance with clipping configuration for this parameter.
+    """
+    if clip_config is None:
+        return None
+
+    head_dim = clip_config.get('head_dim')
+    threshold = clip_config.get('threshold')
+    kind, layer_idx = parse_qk_layer(n)
+
+    logit, indices = None, []
+    if qk_logits is not None and kind is not None:
+        logit = qk_logits[layer_idx]
+        indices_key = 'q_indices' if 'q' in kind else 'k_indices'
+        indices = clip_config.get(indices_key, []) or []
+
+    if isinstance(logit, DTensor):
+        # In TP settings, qk_logits may be DTensor
+        # We convert it to full tensor here for simplicity
+        logit = logit.full_tensor()
+
+    return QKClipInfo(
+        kind=kind,
+        indices=indices,
+        head_dim=head_dim,
+        threshold=threshold,
+        logit=logit,
+    )
+
+
+def compute_scales(p, qk_clip_state):
+    """Compute per-head scaling factors for QK clipping.
+
+    Returns scales tensor if any head exceeds threshold, else None.
+    """
+    kind = qk_clip_state.kind
+    indices = qk_clip_state.indices
+    head_dim = qk_clip_state.head_dim
+    threshold = qk_clip_state.threshold
+    logit = qk_clip_state.logit
+
+    H_global = p.shape[0] // head_dim
+    scales_full = torch.ones(H_global, device=p.data.device)
+    scaling = 0
+
+    for logit_idx, head_idx in enumerate(indices):
+        v_ele = float(logit[logit_idx])
+        if v_ele > threshold:
+            new_scale = math.sqrt(threshold / v_ele)
+            if new_scale < scales_full[head_idx]:
+                scales_full[head_idx] = new_scale
+                logger.info(
+                    f"[{kind}] Head {head_idx} exceeded threshold "
+                    f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
+                )
+            scaling += 1
+
+    return scales_full if scaling > 0 else None
+
+
+def qk_clip(p, scales, head_dim):
+    """Apply per-head scaling to a Q/K projection weight matrix."""
+    if isinstance(p, torch.nn.Parameter):
+        W = p.data.view(-1, head_dim, p.data.shape[1])
+        W.mul_(scales.view(-1, 1, 1))
+    else:
+        W = p.view(-1, head_dim, p.shape[1])
+        W.mul_(scales.view(-1, 1, 1))
diff --git a/docs/expert_parallel.md b/docs/expert_parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..c037a998bc44ffea02503dacf41f7124508a1612
--- /dev/null
+++ b/docs/expert_parallel.md
@@ -0,0 +1,264 @@
+# Expert Parallelism in torchtitan
+
+This document summarizes the expert parallelism implementation in torchtitan (0.2.0).
+It provides the background needed for MoE support in the Muon optimizer.
+
+Reference: `torchtitan/distributed/expert_parallel.py`, `torchtitan/distributed/parallel_dims.py`
+
+## Overview
+
+torchtitan provides four parallelism strategies for MoE expert weights:
+
+| Config | TP | EP | ETP | Expert Weight Placements | Token Dispatch |
+|--------|----|----|-----|--------------------------|----------------|
+| TP Only | >1 | 1 | - | `[Shard(1/2)]` on TP mesh | None |
+| EP Only | 1 | >1 | - | `[Shard(0)]` on EP mesh | All-to-all |
+| EP+ETP (etp=tp) | >1 | >1 | =tp | `[Shard(0), Shard(1/2)]` on [EP, TP] mesh | All-to-all on EP |
+| EP+ETP (etp=1) | >1 | >1 | 1 | `[Shard(0)]` on EP mesh | Sequence parallel on TP |
+
+Expert weights shape: `(num_experts, out_dim, in_dim)` (w1, w3) / `(num_experts, in_dim, out_dim)` (w2).
+
+## How EP borrows dp_shard
+
+EP is not a new physical mesh dimension; it is carved out of `dp_shard`:
+
+```
+dp_shard = dp_shard_mod_ep * dp_shard_in_ep
+
+When ETP=TP: ep = dp_shard_in_ep * cp
+When ETP=1:  ep = dp_shard_in_ep * cp * tp
+```
+
+With EP enabled, the base mesh `[pp, dp_replicate, dp_shard, cp, tp]` expands to:
+
+```
+[pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp]
+```
+
+`dp_shard_mod_ep` stays in the mesh even when its size is 1 (for FSDP wrapping consistency).
+
+### Example: 8 GPUs, ep=4, dp_shard=8, tp=1, cp=1
+
+```
+dp_shard_in_ep = ep / cp = 4
+dp_shard_mod_ep = dp_shard * cp / ep = 2
+
+mesh: [dp_shard_mod_ep=2, dp_shard_in_ep=4]
+EP mesh: [dp_shard_in_ep=4] → experts are distributed 4-way
+FSDP mesh: [dp_shard_mod_ep=2] → expert FSDP shards 2-way
+```
+
+## Submesh mapping
+
+```python
+# Data loading (no communication)
+dp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep]
+
+# Non-expert parameter sharding (FSDP)
+dp_shard_cp = [dp_shard_mod_ep, dp_shard_in_ep, cp]
+
+# Expert parameter sharding (EFSDP) - excludes dp_shard_in_ep
+dp_mod_ep = [dp_replicate?, dp_shard_mod_ep]
+
+# Expert parallelism mesh
+ep = [dp_shard_in_ep, cp, (tp if etp==1)]
+
+# Loss all-reduce
+dp_cp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp]
+```
+
+## The four strategies in detail
+
+### 1. TensorParallel (TP Only, EP=1)
+
+Uses TP only, without EP.
+Expert weights are sharded column/row-wise on the TP mesh:
+
+```python
+# expert_parallel.py: TensorParallel
+w1: [Shard(1)] on TP mesh  # column-wise (out_dim)
+w2: [Shard(2)] on TP mesh  # row-wise (out_dim; dim 2 in 3D)
+w3: [Shard(1)] on TP mesh  # column-wise (out_dim)
+```
+
+No token dispatch; behaves exactly like regular TP.
+
+### 2. ExpertParallel (EP Only, TP=1)
+
+Shards along the expert dim (dim 0), with token all-to-all dispatch:
+
+```python
+# expert_parallel.py: ExpertParallel
+w1, w2, w3: [Shard(0)] on EP mesh  # distributed along the expert dim
+```
+
+Forward pass:
+1. The router assigns each token to an expert
+2. Tokens are dispatched to the owning expert's rank via `all_to_all_single`
+3. Each rank computes with its local experts
+4. Results are combined back to the original ranks via `all_to_all_single`
+
+### 3. ExpertTensorParallel (EP+TP, ETP=TP)
+
+Applies EP and TP together as a 2D sharding:
+
+```python
+# expert_parallel.py: ExpertTensorParallel (extends ExpertParallel)
+w1: [Shard(0), Shard(1)] on [EP, TP] mesh  # expert + column
+w2: [Shard(0), Shard(2)] on [EP, TP] mesh  # expert + row
+w3: [Shard(0), Shard(1)] on [EP, TP] mesh  # expert + column
+```
+
+Token dispatch:
+1. Inputs are replicated on the TP mesh (gradients are Partial)
+2. All-to-all dispatch on the EP mesh (same as ExpertParallel)
+3. All-to-all happens only on the EP mesh; TP communication is handled by the weight sharding
+
+### 4. ReordererSequenceParallel (EP+TP, ETP=1)
+
+Lends the TP hardware to EP; the TP mesh operates as sequence parallel:
+
+```python
+# expert_parallel.py: ReordererSequenceParallel
+# Expert weights: [Shard(0)] on EP mesh (TP unused)
+# Token split: batch*seq_len is divided across the TP ranks
+
+# EP mesh = [dp_shard_in_ep, cp, tp]  ← tp is folded into EP
+```
+
+TP ranks split the tokens among themselves (sequence parallel); expert weights carry no TP sharding.
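The dp_shard decomposition described earlier can be sketched in a few lines of Python. This is a hypothetical helper for illustration only (names follow the mesh dimensions in this document, not a torchtitan API):

```python
def decompose_dp_shard(dp_shard: int, cp: int, tp: int, ep: int,
                       etp: int) -> tuple[int, int]:
    """Split dp_shard into (dp_shard_mod_ep, dp_shard_in_ep) for a given EP setup."""
    if etp == tp:
        # ETP=TP: ep = dp_shard_in_ep * cp
        dp_shard_in_ep = ep // cp
    else:
        # ETP=1: ep = dp_shard_in_ep * cp * tp (tp is folded into the EP mesh)
        dp_shard_in_ep = ep // (cp * tp)
    assert dp_shard % dp_shard_in_ep == 0
    return dp_shard // dp_shard_in_ep, dp_shard_in_ep

# The 8-GPU example from earlier: ep=4, dp_shard=8, tp=1, cp=1
print(decompose_dp_shard(dp_shard=8, cp=1, tp=1, ep=4, etp=1))  # (2, 4)
```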
+
+## EFSDP (Expert FSDP)
+
+FSDP for expert parameters uses a **different mesh** than for non-expert parameters:
+
+```python
+# parallelize.py: apply_fsdp
+# Non-expert: sharded over the full dp_shard_cp mesh
+fully_shard(transformer_block, mesh=dp_shard_cp_mesh)
+
+# Expert (with EP enabled): sharded only over the dp_mod_ep mesh;
+# dp_shard_in_ep is excluded since EP already uses it
+fully_shard(transformer_block.moe.experts, mesh=dp_mod_ep_mesh)
+```
+
+### Dynamic shard placement
+
+When `dp_mod_ep * ep` exceeds the number of experts (so the expert dim cannot be split further),
+sharding happens on dim 1 instead of dim 0.
+
+**torchtitan code** (`torchtitan/models/llama4/infra/parallelize.py:339-359`):
+
+```python
+# NOTE: EP alreadys shards the routed experts on dim 0 (num_experts).
+# When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
+# causes inefficiency, so we choose to do FSDP sharding on dim-1.
+_experts_shard_placement_fn = None
+if (
+    dp_mod_ep_mesh.size() * ep_degree
+    > transformer_block.moe.experts.num_experts
+):
+    _experts_shard_placement_fn = lambda param: Shard(1)
+
+fully_shard(
+    transformer_block.moe.experts,
+    **fsdp_mod_ep_config,  # mesh=dp_mod_ep_mesh
+    reshard_after_forward=reshard_after_forward,
+    shard_placement_fn=_experts_shard_placement_fn,
+)
+```
+
+`dp_mod_ep_mesh` construction (`parallelize.py:140-159`):
+
+```python
+dp_mod_ep_mesh_dim_names = []
+if parallel_dims.ep_enabled:
+    if parallel_dims.dp_replicate_enabled:
+        dp_mod_ep_mesh_dim_names.append("dp_replicate")
+    dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
+# → dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
+```
+
+### Verified placements in practice
+
+Measured with 8 GPUs and `etp=1`:
+
+#### num_experts=8 (default)
+
+In every config, expert weights are **sharded only on dim 0 (the expert dim)**:
+
+| Config | Expert Placements | Mesh |
+|--------|-------------------|------|
+| ep=8 | `[Shard(0)]` | `[ep=8]` |
+| ep=4, fsdp=2 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=2, ep=4]` |
+| ep=2, fsdp=4 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` |
+| ep=2, hsdp=2+2 | `[Replicate(), _StridedShard(0), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` |
+
+EFSDP uses `_StridedShard(dim=0)` and EP uses `Shard(dim=0)`; there are no non-dim-0 shards.
+
+#### num_experts=2 (num experts < EFSDP shard count)
+
+When `dp_mod_ep * ep > num_experts`, **EFSDP switches to Shard(1)**:
+
+| Config | Condition | Expert Placements | Mesh |
+|--------|-----------|-------------------|------|
+| ep=2, fsdp=4 | 4*2=8 > 2 | `[Shard(1), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` |
+| ep=2, hsdp=2+2 | 2*2=4 > 2 | `[Replicate(), Shard(1), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` |
+
+- EFSDP: `Shard(1)` on `dp_shard_mod_ep` → shards out_dim (w1: 2816/4=704)
+- EP: `Shard(0)` on `ep` → shards the expert dim (2/2=1)
+- Plain `Shard` is used here, not `_StridedShard`
+
+## Gradient Clipping with EP
+
+Gradient norms for EP and non-EP parameters are computed separately and then combined:
+
+```python
+# distributed/utils.py: _clip_grad_norm_with_ep
+ep_norm = get_total_norm(ep_grads, ...)
+non_ep_norm = get_total_norm(non_ep_grads, ...)
+total_norm = (ep_norm**p + non_ep_norm**p) ** (1/p)
+```
+
+EP parameters are identified by whether `device_mesh.mesh_dim_names` contains "ep".
+
+## Handling in the Muon optimizer
+
+Current MoE support in the Muon optimizer:
+
+1. **`_expand_expert_params`**: splits 3D expert weights along the expert dim (dim 0), expanding them into 2D params
+2. **With TP**: wraps the non-dim-0 shard (TP) as a DTensor on the TP submesh
+   - 3D `(Shard(0), Shard(1))` → 2D `(Shard(0),)` on TP submesh
+3. **`construct_shard_mesh` fast path**: prevents a `dist.new_group()` deadlock on 1D submeshes
+
+### Configs supported by Muon
+
+| Config | Supported | Notes |
+|--------|-----------|-------|
+| TP Only (EP=1) | Yes | experts handled as DTensors on the TP submesh |
+| EP Only (TP=1) | Yes | experts handled as plain tensors (base mode) |
+| FSDP + TP | Yes | FSDP on the expert dim, TP on the out/in dims |
+| HSDP + TP | Yes | Replicate + FSDP + TP |
+| EP Only (many experts) | Yes | EFSDP `Shard(0)` → plain tensor |
+| EP + FSDP (few experts) | Untested | EFSDP `Shard(1)` → see below |
+| EP + TP (ETP=TP) | Untested | 2D expert DTensor `[Shard(0), Shard(1/2)]` |
+| EP + TP (ETP=1) | Untested | EP mesh includes TP |
+
+### EFSDP Shard(1) compatibility with Muon
+
+Muon is placement-agnostic. The non-dim-0 shard handling in `_expand_expert_params` applies
+equally to EFSDP `Shard(1)`, not just TP (only the variable names say `tp_*`; the logic is generic):
+
+```
+3D: (Shard(1), Shard(0)) on [dp_shard_mod_ep=4, ep=2]
+    local shape: (1, 704, 2048)
+
+_expand_expert_params:
+  1. find the non-dim-0 shard → Shard(1) on dp_shard_mod_ep
+  2. extract the submesh → dp_shard_mod_ep (1D, size 4)
+  3. split on dim 0 → (704, 2048)
+  4. wrap as DTensor → Shard(0) on dp_shard_mod_ep
+  = identical to an ordinary FSDP-sharded 2D tensor
+
+→ parallel()/distributed_muon() then handle all-gather → Newton-Schulz → scatter.
+  The construct_shard_mesh fast path applies (1D submesh, no deadlock).
+```
diff --git a/docs/implementation.md b/docs/implementation.md
new file mode 100644
index 0000000000000000000000000000000000000000..25e158fdd5072e9d1287f55f0163bd1ed8235d70
--- /dev/null
+++ b/docs/implementation.md
@@ -0,0 +1,277 @@
+# Muon Optimizer: Implementation Guide
+
+This document explains the internal architecture of the Muon optimizer for reviewers and new contributors. It covers the execution paths, the parallel pipeline design, and the distributed sharding utilities.
+
+## Table of Contents
+
+1. [Overview](#overview)
+2. [Entry Point and Parameter Routing](#entry-point-and-parameter-routing)
+3. [Execution Paths](#execution-paths)
+4. [Parallel Pipeline (the core feature)](#parallel-pipeline)
+5. [Distributed Utilities](#distributed-utilities)
+6. [Newton-Schulz Orthogonalization](#newton-schulz-orthogonalization)
+7. [QK Clipping](#qk-clipping)
+8. [AdamW for Non-Muon Parameters](#adamw-for-non-muon-parameters)
+9. [Source File Map](#source-file-map)
+
+---
+
+## Overview
+
+Muon (MomentUm Orthogonalized by Newton-schulz) applies standard SGD-momentum and then replaces each 2D parameter's update with the nearest orthogonal matrix via a Newton-Schulz iteration. The iteration runs stably in bfloat16 on GPU.
+
+The optimizer supports arbitrary N-D sharding configurations: FSDP2, TP, or hybrid setups like `2 TP x 2 DP-Replicate x 2 DP-Shard`. This generality is what drives most of the code complexity.
+
+## Entry Point and Parameter Routing
+
+**File:** `muon.py` — `Muon.step()` / `Muon._step_muon()`
+
+Users must provide parameter groups with `use_muon=True/False` flags (via `get_default_muon_param_groups()`). At each step:
+
+1. **Non-Muon groups** → `step_adamw()` (fused AdamW).
+2. **Muon groups** → `_step_muon()`, which further classifies each parameter:
+
+```
+_step_muon(group)
+    |
+    +-- DTensor, all Replicate placements --> base()             (no sharding)
+    +-- DTensor, numel <= threshold       --> distributed_muon() (small param fallback)
+    +-- DTensor, sharded                  --> parallel()         (pipelined all-to-all)
+    +-- plain Tensor                      --> base()             (single device)
+```
+
+Parameters are classified by their DTensor placements:
+- **Fully replicated** DTensors and plain tensors use `base()` — standard single-device Muon.
+- **Small sharded** DTensors (below `small_param_numel_threshold`, default 65536) use `distributed_muon()` — gathers the full tensor via `full_tensor()`, computes the update, then redistributes.
+- **Large sharded** DTensors use `parallel()` — the pipelined all-to-all approach described below.
+ +## Execution Paths + +### base() — Single Device + +Straightforward per-parameter loop: momentum update → Newton-Schulz orthogonalization → parameter update → optional QK clipping. + +### distributed_muon() — Full Gather + +Each parameter's gradient is gathered to full via `g.full_tensor()`, orthogonalized on every rank, then the updated full parameter is redistributed back to the original sharded placement. Simple but communication-heavy — used only as a fallback for small parameters. + +### parallel() — Pipelined All-to-All + +This is the main advanced feature. Instead of all-gathering the full parameter, it uses **all-to-all** to distribute work: each rank "owns" a subset of parameters and is responsible for their Newton-Schulz computation. + +## Parallel Pipeline + +### Design Motivation + +Newton-Schulz is compute-intensive. The key insight is that each rank only needs to orthogonalize the parameters it "owns" — not all parameters. So the flow is: + +1. **Gather**: Each rank sends its local gradient shard to the owning rank via all-to-all. +2. **Compute**: The owning rank runs Newton-Schulz on the full (gathered) gradient. +3. **Scatter**: The owning rank sends the orthogonalized update back to all ranks via all-to-all. +4. **Update**: Each rank applies weight decay and the update to its local shard. + +To overlap communication and computation, parameters are split into **chunks**, and multiple chunks are processed concurrently. 
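The four steps can be mocked in plain Python to make the data flow concrete. This is a single-process toy, where "ranks" are list indices and orthogonalization is a placeholder sign function; the real implementation moves GPU shards with `dist.all_to_all_single`:

```python
def step_param(shards, orthogonalize):
    """Toy model of one parameter's gather -> compute -> scatter round trip."""
    # 1. Gather: every rank sends its shard to the owning rank (all-to-all).
    full = [row for shard in shards for row in shard]
    # 2. Compute: only the owner runs the expensive math on the full tensor.
    update = orthogonalize(full)
    # 3. Scatter: the owner sends each rank back its piece of the update.
    out, i = [], 0
    for shard in shards:
        out.append(update[i:i + len(shard)])
        i += len(shard)
    # 4. Update: each rank would now apply weight decay + its local piece.
    return out


shards = [[1.0, -2.0], [3.0, -4.0]]  # 2 "ranks", 2 rows each
sign = lambda xs: [1.0 if x > 0 else -1.0 for x in xs]
print(step_param(shards, sign))      # [[1.0, -1.0], [1.0, -1.0]]
```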
+ +### Architecture + +``` +muon.py: parallel() + | + +-- init_state_and_assign_params() -- assigns ownership, precomputes indices + | + +-- pipelines() generator -- yields muon_chunk_pipeline() per chunk + | + +-- run_pipeline(pipelines, max_concurrent=warmup_step+1) + | + +-- interleaves chunks at yield boundaries +``` + +### The Chunk Pipeline Generator + +**File:** `pipeline.py` — `muon_chunk_pipeline()` + +Each chunk is a generator that yields **2 times**, creating stages separated by async communication: + +``` + YIELD 1 YIELD 2 + | | +[Build bufs + async gather a2a] --> [wait + NS compute + async scatter a2a] --> [wait + Update params] +``` + +- **Async communication**: `dist.all_to_all_single(..., async_op=True)` launches non-blocking communication. The generator yields immediately after, allowing other chunks to run. `work.wait()` completes the operation after the yield. +- **Chunk-level overlap**: `run_pipeline()` interleaves multiple chunks at yield boundaries, so while chunk N waits for its communication, chunk N+1 can launch its own. + +### The Pipeline Scheduler + +**File:** `async_utils.py` — `run_pipeline()` + +A simple round-robin scheduler: + +```python +while have_new or previous_tasks: + # Admit one new pipeline if below concurrency limit + if have_new and len(previous_tasks) < max_concurrent: + task = next(pipelines) # runs to first yield + # Advance all existing tasks by one yield + for task in previous_tasks: + task.step() # runs to next yield +``` + +`max_concurrent = warmup_step + 1` controls how many chunks can be in-flight simultaneously. Higher values increase memory usage but improve communication/computation overlap. + +### Ownership Assignment + +**File:** `muon.py` — `init_state_and_assign_params()` + +Parameters are sorted by FLOP cost (descending) and assigned to ranks in round-robin order across the shard mesh. This balances compute load across ranks. 
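The scheduler loop shown above can be made runnable. This is a simplified stand-alone sketch, not the package's `run_pipeline` (which wraps generators in `_Task` objects); stage names in the demo chunks are illustrative:

```python
from collections import deque


def run_pipeline(pipelines, max_concurrent):
    """Bounded round-robin over generator pipelines (simplified sketch)."""
    active = deque()
    pipelines = iter(pipelines)
    exhausted = False
    while active or not exhausted:
        # Admit one new pipeline if below the concurrency limit.
        if not exhausted and len(active) < max_concurrent:
            try:
                gen = next(pipelines)
                next(gen)          # run to first yield (launches async comm)
                active.append(gen)
            except StopIteration:
                exhausted = True
        # Advance every in-flight pipeline by one stage.
        for _ in range(len(active)):
            gen = active.popleft()
            try:
                next(gen)
                active.append(gen)
            except StopIteration:
                pass               # this chunk is finished


log = []


def chunk(i):
    log.append(f"{i}:gather"); yield
    log.append(f"{i}:compute+scatter"); yield
    log.append(f"{i}:update")


run_pipeline((chunk(i) for i in range(3)), max_concurrent=2)
print(log)  # stages of chunks 0..2 interleaved, never >2 chunks in flight
```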
+ +### Precomputed Shard Indices + +Instead of computing per-rank shard indices on every step, they are precomputed once during `init_state_and_assign_params()` and stored in `_muon_state`: + +```python +@dataclass +class _muon_state: + worker_rank: int # which rank owns this param's computation + process_group: ProcessGroup # the all-to-all communication group + rank_indices: dict[int, tuple] # rank -> per-dim indices into full tensor + rank_numels: dict[int, int] # rank -> number of elements in shard + name: str + qk_clip_state: QKClipInfo | None +``` + +`rank_indices[r]` is a tuple of `slice` or `torch.Tensor` per dimension, describing which elements of the full tensor rank `r` owns. `rank_numels[r]` is the total number of elements in that shard. These are used directly in the pipeline's gather and scatter stages. + +### Pipeline Stages in Detail + +#### Stages 1-2: Gather + +1. **Allocate** receive buffers for gathered gradients (only on owning ranks). +2. **Build send buffer**: Each rank flattens its local gradient shard for each destination rank. +3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches gather. +4. **Yield 1**: Other chunks can launch their gather while this one waits. +5. **`work.wait()`**: Complete the gather. +6. **Reconstruct**: The owning rank places received shards into the full gradient using `rank_indices`. + +#### Stage 3: Compute + +The owning rank runs `_zeropower_via_newtonschulz5()` on the full gathered gradient. This is the most compute-intensive stage. Runs inline (no yield) since it is synchronous GPU work. + +#### Stages 4-5: Scatter + +Inverse of gather: +1. **Allocate** receive buffers for the orthogonalized update `U`. +2. **Build send buffer**: The owning rank slices `U` using `rank_indices` for each destination rank. +3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches scatter. +4. **Yield 2**: Other chunks can launch their scatter while this one waits. +5. 
**`work.wait()`**: Complete the scatter. +6. **Copy** received shards into local update buffers. + +#### Stage 6: Update + +Each rank applies weight decay and the Muon update to its local parameter shard. Also applies QK clipping if configured. + +## Distributed Utilities + +**File:** `distributed/utils.py` + +These utilities solve the problem of mapping from a DTensor's arbitrary sharding configuration to the concrete indices each rank owns. + +### `construct_shard_mesh(placements, mesh)` + +Given a DTensor's placements and device mesh, this function: + +1. **Sorts** placements: Replicate dims first, then Shard dims by dimension (with `_StridedShard` after regular `Shard` on the same dim). +2. **Permutes** the mesh accordingly. +3. **Separates** replicate dims from shard dims — each replicate group gets its own shard sub-mesh. +4. **Creates** a ProcessGroup for the current rank's shard mesh. + +Returns `(shard_mesh, process_group, shard_placements)` — used for all-to-all communication. + +**Why this is needed:** A model might use `[Replicate, Shard(0), _StridedShard(0)]` across a 3D mesh. The optimizer needs to identify which ranks participate in the same shard group (share the same data) and create a ProcessGroup for them. + +### `get_slices_of_dtensor(target, local_rank, shard_mesh, shard_placements)` + +Computes the exact indices that a given rank owns in the full tensor. Handles both contiguous (`Shard`) and strided (`_StridedShard`) sharding, including composed multi-level sharding on the same dimension. + +Returns a tuple of `slice` (contiguous) or `torch.LongTensor` (strided) per dimension. + +**Example:** With `[Shard(0), _StridedShard(0)]` on a (16, 2048) tensor across 4 ranks: +- Rank 0 might own rows `[0, 4, 8, 12]` (strided) +- Rank 1 might own rows `[1, 5, 9, 13]` +- etc. + +### PyTorch 2.10 Compatibility + +In PyTorch 2.10, `_StridedShard` no longer inherits from `Shard`. 
The helper `_is_shard()` handles both old and new hierarchies: + +```python +def _is_shard(placement): + return isinstance(placement, (Shard, _StridedShard)) +``` + +## Newton-Schulz Orthogonalization + +**File:** `newton_schulz.py` + +`_zeropower_via_newtonschulz5()` computes the orthogonal approximation of a matrix using 5 quintic Newton-Schulz iterations with pre-optimized coefficients. The result approximates `US'V^T` where `S'` is near-uniform on `[0.5, 1.5]`, which empirically does not hurt model performance vs. exact `UV^T`. + +Each iteration uses `matmul_transpose_assign()` (a Triton kernel for `X @ X^T`) for efficiency. + +**File:** `matmul_transpose_triton.py` + +The `matmul_transpose_assign(d_in, d_out)` kernel computes `d_out = d_in @ d_in^T` in-place. It exploits symmetry by computing only upper-triangle blocks and mirroring. + +## QK Clipping + +**File:** `qk_clip.py` + +Optional dynamic clipping for attention head projections (Q and K weight matrices). When the maximum QK logit for a head exceeds a threshold, the corresponding rows of the weight matrix are scaled down by `sqrt(threshold / logit)`. + +**In the parallel pipeline:** QK clipping is applied per-row using each row's global head index. This correctly handles strided sharding where local rows may be interleaved across multiple heads: + +```python +# pipeline.py: _update_params() +ratio = p.shape[0] // scales_full.shape[0] # rows per head +idx0 = state.rank_indices[rank][0] # which global rows this rank owns +row_scales = scales_full[idx0 // ratio] # map each row to its head's scale +p._local_tensor.mul_(row_scales.view(-1, 1)) +``` + +## AdamW for Non-Muon Parameters + +**File:** `adamw.py` + +Parameters not eligible for Muon (1D parameters, embeddings, LM head) are optimized with fused AdamW via `torch._fused_adamw_`. Parameters are grouped by device/dtype and DTensor placement before the fused call. 
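For reference, the Newton-Schulz iteration described above can be sketched in NumPy. The coefficients below are the widely circulated ones from the public Muon reference implementation; this package's exact constants, dtype handling (bfloat16), and Triton acceleration are not reproduced here:

```python
import numpy as np


def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
    """Approximate the nearest orthogonal matrix via 5 quintic NS steps."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)  # Frobenius norm bounds spectral norm
    transposed = X.shape[0] > X.shape[1]
    if transposed:                     # work in the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T                    # the Triton kernel accelerates this
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if transposed else X


G = np.random.default_rng(0).standard_normal((8, 16))
O = zeropower_via_newtonschulz5(G)
print(np.linalg.svd(O, compute_uv=False))  # singular values pushed toward 1
```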
+ +## Source File Map + +| File | Lines | Purpose | +|------|-------|---------| +| `muon.py` | ~525 | Optimizer class, parameter routing, 3 execution paths | +| `pipeline.py` | ~290 | Generator-based parallel pipeline (gather/compute/scatter/update) | +| `async_utils.py` | ~75 | Pipeline scheduler with bounded concurrency | +| `core.py` | ~110 | `_muon_state` dataclass, momentum/update helpers, param grouping | +| `distributed/utils.py` | ~230 | Shard mesh construction, DTensor index computation | +| `newton_schulz.py` | ~50 | Newton-Schulz iteration | +| `matmul_transpose_triton.py` | ~120 | Triton kernel for symmetric matmul | +| `qk_clip.py` | ~130 | QK logit clipping | +| `adamw.py` | ~160 | Fused AdamW for non-Muon params | + +### Dependency Graph + +``` +matmul_transpose_triton.py (leaf) + | + newton_schulz.py (leaf + triton) + | + core.py ---- qk_clip.py (leaf, distributed/utils) + | | | + | pipeline.py --- async_utils.py + | | + | adamw.py + | | + muon.py (all above) + | + __init__.py +``` diff --git a/docs/pytorch-2.10-tp-fix.md b/docs/pytorch-2.10-tp-fix.md new file mode 100644 index 0000000000000000000000000000000000000000..3533cad5e5c5cfcc63d99c7d94ad2577db182646 --- /dev/null +++ b/docs/pytorch-2.10-tp-fix.md @@ -0,0 +1,151 @@ +# PyTorch 2.10 Tensor Parallelism Fix + +## Summary + +PyTorch 2.10 changed the class hierarchy for `_StridedShard`, breaking our +`distributed/utils.py` code that handles DTensor sharding. This document +records the root cause, every change made so far, and the one remaining issue. + +--- + +## 1. 
Root Cause: `_StridedShard` class hierarchy change + +| Version | MRO | +|---------|-----| +| PyTorch < 2.10 | `_StridedShard -> Shard -> Placement` | +| PyTorch 2.10 | `_StridedShard -> StridedShard -> Placement` | + +**Consequences:** + +- `isinstance(strided_shard, Shard)` returns `False` +- `strided_shard.is_shard()` returns `False` +- Our `construct_shard_mesh()` treated `_StridedShard` as an unsupported + placement and raised `AssertionError`. + +### When does `_StridedShard` appear? + +When two parallelism dimensions shard the same tensor dimension. +For example, `fsdp+tp` or `hsdp+tp` configurations where both TP and +FSDP shard dimension 0 of Q/K/V projection weights: + +``` +TP : Shard(0) → each TP rank gets 2048/4 = 512 rows +FSDP: Shard(0) on top → each FSDP rank further splits those rows +``` + +PyTorch represents the second sharding as `_StridedShard(dim=0, split_factor=N)` +to indicate non-contiguous (interleaved) row ownership. + +--- + +## 2. Completed Fixes + +### 2.1 `_is_shard()` helper (`distributed/utils.py`) + +Added a helper that correctly identifies both `Shard` and `_StridedShard`: + +```python +def _is_shard(placement: Placement) -> bool: + return isinstance(placement, (Shard, _StridedShard)) +``` + +Used in `construct_shard_mesh()` where the old code called `placement.is_shard()`. + +### 2.2 Rewritten `get_slices_of_dtensor()` (`distributed/utils.py`) + +Old code assumed contiguous slicing (`start = rank * shard_size`), which is +wrong for `_StridedShard`. 
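The contiguous assumption breaks because a strided shard owns interleaved rows. A toy helper (hypothetical, not PyTorch's actual `_StridedShard` index math, which additionally involves `split_factor`) shows the fully interleaved case from the implementation guide's `[0, 4, 8, 12]` example:

```python
def strided_rows(n_rows: int, world: int, rank: int) -> list[int]:
    """Rows owned by `rank` under fully interleaved sharding (toy model).

    A contiguous Shard would instead return
    range(rank * block, (rank + 1) * block) with block = n_rows // world.
    """
    return list(range(rank, n_rows, world))


# 16 rows across 4 shard ranks: ownership is non-contiguous.
print(strided_rows(16, 4, 0))  # [0, 4, 8, 12]
print(strided_rows(16, 4, 1))  # [1, 5, 9, 13]
```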
+ +New code uses PyTorch's own offset-computation methods: + +| Placement type | API used | +|----------------|----------| +| `Shard` | `Shard.local_shard_size_and_offset(size, chunks, rank)` → `(size, offset)` | +| `_StridedShard` | `_StridedShard.local_shard_size_and_offset(instance, size, chunks, rank, return_first_offset=False)` → `(size, offsets_list)` | + +Return type changed from `tuple[slice, ...]` to `tuple[slice | torch.Tensor, ...]`: +- `slice` for contiguous ranges (Shard or contiguous StridedShard result) +- `torch.LongTensor` of indices for non-contiguous ranges + +Composed sharding (multiple placements on the same dim) is handled by +indexing: `dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]`. + +### 2.3 Updated `numel_for_rank()` (`core.py`) + +Now handles both `slice` and `torch.Tensor` index types: + +```python +for idx, dim_size in zip(indices, param.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) +``` + +### 2.4 Updated pipeline stages (`pipeline.py`) + +- **`_gather_grads()`**: Uses `gathered_grads[id(p)][indices] = sg` (index + assignment) instead of view-based copy, works for both slice and tensor + indices. +- **`_scatter_us()`**: `u_full[indices].flatten()` works for both index types. +- **`_update_params()` QK clipping**: Applies clipping on `p._local_tensor` + directly instead of through DTensor operations, avoiding sharding + propagation errors with `_StridedShard`. + +--- + +## 3. Test Results After Fixes + +| Test configuration | Status | +|--------------------|--------| +| base | PASS | +| fsdp | PASS | +| hsdp | PASS | +| tp | PASS | +| hsdp+tp (QK clip off) | PASS | +| hsdp+tp (QK clip on) | PASS | +| fsdp+tp (QK clip off) | PASS | +| fsdp+tp (QK clip on) | PASS | + +All 24 tests pass (126s, `--skip-verify`). + +--- + +## 4. 
Fixed: QK Clipping with Strided Sharding
+
+### Problem
+
+With strided (non-contiguous) sharding, local rows are **interleaved across
+multiple heads**. For example with `fsdp+tp` (`dp_shard=2, tp=4`):
+
+- Q/K projection global shape: `(2048, 2048)`, `head_dim=128`, `16 heads`
+- Each rank owns 256 rows, but they span 4 heads with 64 rows per head
+- `view(-1, head_dim, cols)` assumes contiguous head blocks → wrong for
+  interleaved rows → shape mismatch error
+
+### Fix Applied
+
+For strided sharding, apply scales **per-row** based on each row's global
+head index instead of using the head-block view:
+
+```python
+if isinstance(weight_indices[0], slice):
+    # Contiguous case: view-based approach still works
+    ...
+else:
+    # Strided case: per-row scaling
+    head_per_row = weight_indices[0] // ratio
+    row_scales = scales_full[head_per_row]
+    local_p.mul_(row_scales.view(-1, 1))
+```
+
+---
+
+## 5. Files Modified
+
+| File | Changes |
+|------|---------|
+| `torch-ext/optimizer/distributed/utils.py` | Added `_is_shard()`, rewrote `get_slices_of_dtensor()`, fixed `construct_shard_mesh()` |
+| `torch-ext/optimizer/core.py` | Updated `numel_for_rank()` for `slice \| Tensor` indices |
+| `torch-ext/optimizer/pipeline.py` | Updated `_gather_grads()`, `_scatter_us()`, `_update_params()` QK clipping |
diff --git a/test/conftest.py b/test/conftest.py
index 15177262eb39e8f60c95742bb372faf2f3857ae9..194f1d35d4981659e0e0c37dfe222dda1fa6ea73 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -122,3 +122,57 @@ def inputs():
     }
 
     return [model, grads, qk_logits]
+
+
+def _create_moe_model(num_experts=8, top_k=2, n_layers=4):
+    """Create a torchtitan Llama4 MoE model with random gradients."""
+    from torchtitan.models.llama4.model.args import TransformerModelArgs
+    from torchtitan.models.llama4.model.model import Transformer
+    from torchtitan.models.moe import MoEArgs
+
+    torch.manual_seed(SEED)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(SEED)
+
+    moe_args = 
MoEArgs( + num_experts=num_experts, + num_shared_experts=1, + top_k=top_k, + score_func="sigmoid", + ) + model_args = TransformerModelArgs( + dim=2048, + n_layers=n_layers, + n_heads=16, + n_kv_heads=8, + vocab_size=32000, + norm_eps=1e-5, + rope_theta=10000, + max_seq_len=4096, + moe_args=moe_args, + interleave_moe_layer_step=1, + ) + model = Transformer(model_args) + model.init_weights() + logger.info(f"Created torchtitan Llama4 MoE model " + f"(num_experts={num_experts}, n_layers={n_layers}, " + f"{len(list(model.parameters()))} parameters)") + + grads = [ + torch.randn_like(param, device=param.device, dtype=param.dtype) + for param in model.parameters() + ] + + return [model, grads] + + +@pytest.fixture(scope="session") +def moe_inputs(): + """MoE model with 8 experts (standard config).""" + return _create_moe_model(num_experts=8, top_k=2) + + +@pytest.fixture(scope="session") +def moe_inputs_few_experts(): + """MoE model with 2 experts (triggers EFSDP Shard(1) mode).""" + return _create_moe_model(num_experts=2, top_k=1) diff --git a/test/run_test_moe.sh b/test/run_test_moe.sh new file mode 100755 index 0000000000000000000000000000000000000000..e9e5f4fc1e22be35f8071e20c1ce0bae6cb6bd85 --- /dev/null +++ b/test/run_test_moe.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +cd "$(dirname "$0")" +torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon_moe.py "$@" diff --git a/test/test_muon.py b/test/test_muon.py index 3c4085963941120b0c089bfbdfad3a840c00da20..cfa5c1245b9ce5371f1f695851674bed88bf3845 100644 --- a/test/test_muon.py +++ b/test/test_muon.py @@ -28,6 +28,7 @@ def apply_muon_step( use_distributed_muon: bool = False, measure_perf: bool = False, do_profile: bool = False, + test_name: str | None = None, ) -> tuple[torch.nn.Module, tuple[float, float] | None]: """ apply single Muon step with optional QK clipping """ @@ -81,9 +82,9 @@ def apply_muon_step( start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) + 
torch.cuda.reset_peak_memory_stats() start.record() num_iters = 20 - current_mem = torch.cuda.memory_allocated() if do_profile: context = profile( @@ -99,19 +100,13 @@ def apply_muon_step( end.record() end.synchronize() - if prof is not None and dist.get_rank() == 0: + if prof is not None: date = time.strftime("%Y%m%d_%H%M%S", time.localtime()) - profile_name = "trace" - profile_name += f"_{date}" - profile_name += f"_{parallel_dims}" - profile_name += f"_{chunk_size}" - profile_name += f"_{warmup_step}" - profile_name += f"_{qk_logits is not None}" - profile_name += f"_{use_distributed_muon}" + name = test_name or "trace" + rank = dist.get_rank() + prof.export_chrome_trace(f"{name}_{date}_rank{rank}.json") - prof.export_chrome_trace(f"{profile_name}.json") - - peak_memory = torch.cuda.max_memory_allocated() - current_mem + peak_memory = torch.cuda.max_memory_allocated() elapsed_time_ms = start.elapsed_time(end) / num_iters @@ -159,7 +154,7 @@ def sequential_muon_result( OVERLAP_STEPS = [5] -CHUNK_SIZES = [8] +CHUNK_SIZES = [2] SMALL_PARAM_NUMEL_THRESHOLDS = [65536, 1_000_000_000] @@ -222,6 +217,7 @@ def test_parallel_muon( use_distributed_muon=use_distributed_muon, measure_perf=measure_perf, do_profile=do_profile, + test_name=request.node.name, ) if measure_perf: diff --git a/test/test_muon_moe.py b/test/test_muon_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4cb8d7c9618965bdfb1883f6c58e2211100da1 --- /dev/null +++ b/test/test_muon_moe.py @@ -0,0 +1,294 @@ +import copy +import logging +import time +from contextlib import nullcontext + +import pytest +import torch +import torch.distributed as dist +from optimizer.muon import Muon, get_default_muon_param_groups +from torch.distributed.tensor import DTensor, Replicate +from torch.profiler import ProfilerActivity, profile + +from .utils import ParallelDims, assert_params_equal, parallelize_llama4 + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def 
_apply_grads(model, grads): + """Apply gradients to model parameters (with DTensor redistribute).""" + for grad, param in zip(grads, model.parameters()): + grad = grad.to(param.device) + if isinstance(param.data, DTensor): + unsharded_grad = DTensor.from_local( + grad, + device_mesh=param.data.device_mesh, + placements=[Replicate()] * param.data.device_mesh.ndim, + ) + param.grad = unsharded_grad.redistribute( + device_mesh=param.data.device_mesh, + placements=param.data.placements) + else: + param.grad = grad + + +def _restore_grads(model, saved_grads): + """Restore previously saved grads (no redistribute, just reassign).""" + for param, g in zip(model.parameters(), saved_grads): + param.grad = g + + +def apply_muon_step_moe( + model: torch.nn.Module, + parallel_dims: ParallelDims | None, + grads: list[torch.Tensor], + warmup_step: int, + chunk_size: int, + small_param_numel_threshold: int, + use_distributed_muon: bool = False, + measure_perf: bool = False, + do_profile: bool = False, + test_name: str | None = None, +) -> tuple[torch.nn.Module, tuple[float, float] | None]: + """Apply a single Muon step to an MoE model (no QK clipping).""" + + assert len(grads) == len(list(model.parameters())) + _apply_grads(model, grads) + + params = get_default_muon_param_groups(model, expert_keys=["experts"]) + optim = Muon( + params=params, + clip_config=None, + none_grad=False, + warmup_step=warmup_step, + chunk_size=chunk_size, + small_param_numel_threshold=small_param_numel_threshold, + use_distributed_muon=use_distributed_muon, + expert_keys=["experts"], + ) + + # Save sharded grads for re-use before step clears 3D grads. 
+ saved_grads = [p.grad for p in model.parameters()] + + optim.step() + + timing_result: tuple[float, float] | None = None + + if measure_perf: + # extra warm up + _restore_grads(model, saved_grads) + optim.step() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + torch.cuda.reset_peak_memory_stats() + start.record() + num_iters = 20 + + if do_profile: + context = profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True) + else: + context = nullcontext() + + with context as prof: + for _i in range(num_iters): + _restore_grads(model, saved_grads) + optim.step() + + end.record() + end.synchronize() + + if prof is not None: + date = time.strftime("%Y%m%d_%H%M%S", time.localtime()) + name = test_name or "trace_moe" + rank = dist.get_rank() + prof.export_chrome_trace(f"{name}_{date}_rank{rank}.json") + + peak_memory = torch.cuda.max_memory_allocated() + elapsed_time_ms = start.elapsed_time(end) / num_iters + timing_result = (elapsed_time_ms, peak_memory) + + return model, timing_result + + +@pytest.fixture(scope="session") +def sequential_moe_result( + skip_verify, + moe_inputs, +) -> torch.nn.Module | None: + """Run Muon optimizer on sequential MoE model for baseline.""" + if skip_verify: + logger.info("Skipping verification tests as per user request") + return None + + model, grads = moe_inputs + + result, _ = apply_muon_step_moe( + model=copy.deepcopy(model).cuda(), + parallel_dims=None, + grads=grads, + warmup_step=-1, + chunk_size=-1, + small_param_numel_threshold=-1, + ) + result = result.cpu() + + return result + + +OVERLAP_STEPS = [5] +CHUNK_SIZES = [2] +SMALL_PARAM_NUMEL_THRESHOLDS = [65536, 1_000_000_000] + + +@pytest.mark.parametrize("parallel_dims", [ + pytest.param(ParallelDims(8, 1, 1), id="base"), + pytest.param(ParallelDims(1, 8, 1), id="fsdp"), + pytest.param(ParallelDims(2, 4, 1), id="hsdp"), + pytest.param(ParallelDims(2, 2, 2), id="hsdp+tp"), + 
pytest.param(ParallelDims(1, 2, 4), id="fsdp+tp"), + pytest.param(ParallelDims(1, 1, 1, ep_degree=8), id="ep"), + pytest.param(ParallelDims(1, 4, 1, ep_degree=2), id="ep+fsdp"), + pytest.param(ParallelDims(1, 2, 1, ep_degree=4), id="ep4+fsdp"), + pytest.param(ParallelDims(2, 2, 1, ep_degree=2), id="ep+hsdp"), +]) +@pytest.mark.parametrize("use_distributed_muon", [False]) +@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS) +@pytest.mark.parametrize("chunk_size", CHUNK_SIZES) +@pytest.mark.parametrize("small_param_numel_threshold", + SMALL_PARAM_NUMEL_THRESHOLDS) +def test_parallel_muon_moe( + request, + sequential_moe_result: torch.nn.Module | None, + parallel_dims: ParallelDims, + use_distributed_muon: bool, + warmup_step: int, + chunk_size: int, + small_param_numel_threshold: int, + moe_inputs: tuple[torch.nn.Module, list[torch.Tensor]], + measure_perf, + do_profile, +) -> None: + model, grads = moe_inputs + + # Deepcopy the model to avoid in-place modification + model = copy.deepcopy(model).cuda() + + parallelized_model = parallelize_llama4(model, parallel_dims) + + parallelized_model, timing_result = apply_muon_step_moe( + model=parallelized_model, + parallel_dims=parallel_dims, + grads=grads, + warmup_step=warmup_step, + chunk_size=chunk_size, + small_param_numel_threshold=small_param_numel_threshold, + use_distributed_muon=use_distributed_muon, + measure_perf=measure_perf, + do_profile=do_profile, + test_name=request.node.name, + ) + + if measure_perf: + assert timing_result is not None + avg_time_ms, peak_memory = timing_result + logger.info(f"\nParallel dims: {parallel_dims}, " + f"\nAvg Time (ms): {avg_time_ms:.2f}, " + f"Peak Memory (MB): {peak_memory / (1024**2):.2f}") + + if sequential_moe_result is None: + logger.info("Skipping correctness check as sequential result is None") + elif measure_perf: + logger.info("Skipping correctness check as timing is enabled") + else: + assert_params_equal(parallelized_model, sequential_moe_result) + + +# 
--------------------------------------------------------------------------- +# Few-experts tests: num_experts=2, triggers EFSDP Shard(1) mode +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="session") +def sequential_moe_result_few_experts( + skip_verify, + moe_inputs_few_experts, +) -> torch.nn.Module | None: + """Run Muon optimizer on sequential MoE model (2 experts) for baseline.""" + if skip_verify: + logger.info("Skipping verification tests as per user request") + return None + + model, grads = moe_inputs_few_experts + + result, _ = apply_muon_step_moe( + model=copy.deepcopy(model).cuda(), + parallel_dims=None, + grads=grads, + warmup_step=-1, + chunk_size=-1, + small_param_numel_threshold=-1, + ) + result = result.cpu() + + return result + + +@pytest.mark.parametrize("parallel_dims", [ + pytest.param(ParallelDims(1, 4, 1, ep_degree=2), id="ep+fsdp"), + pytest.param(ParallelDims(2, 2, 1, ep_degree=2), id="ep+hsdp"), +]) +@pytest.mark.parametrize("use_distributed_muon", [False]) +@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS) +@pytest.mark.parametrize("chunk_size", CHUNK_SIZES) +@pytest.mark.parametrize("small_param_numel_threshold", + SMALL_PARAM_NUMEL_THRESHOLDS) +def test_parallel_muon_moe_few_experts( + request, + sequential_moe_result_few_experts: torch.nn.Module | None, + parallel_dims: ParallelDims, + use_distributed_muon: bool, + warmup_step: int, + chunk_size: int, + small_param_numel_threshold: int, + moe_inputs_few_experts: tuple[torch.nn.Module, list[torch.Tensor]], + measure_perf, + do_profile, +) -> None: + model, grads = moe_inputs_few_experts + + model = copy.deepcopy(model).cuda() + + parallelized_model = parallelize_llama4(model, parallel_dims) + + parallelized_model, timing_result = apply_muon_step_moe( + model=parallelized_model, + parallel_dims=parallel_dims, + grads=grads, + warmup_step=warmup_step, + chunk_size=chunk_size, + 
small_param_numel_threshold=small_param_numel_threshold, + use_distributed_muon=use_distributed_muon, + measure_perf=measure_perf, + do_profile=do_profile, + test_name=request.node.name, + ) + + if measure_perf: + assert timing_result is not None + avg_time_ms, peak_memory = timing_result + logger.info(f"\nParallel dims: {parallel_dims}, " + f"\nAvg Time (ms): {avg_time_ms:.2f}, " + f"Peak Memory (MB): {peak_memory / (1024**2):.2f}") + + if sequential_moe_result_few_experts is None: + logger.info("Skipping correctness check as sequential result is None") + elif measure_perf: + logger.info("Skipping correctness check as timing is enabled") + else: + assert_params_equal(parallelized_model, + sequential_moe_result_few_experts) diff --git a/test/utils.py b/test/utils.py index 494c09de1f3241a5ef5028e47f21d17c7342645a..3572d0e132cc6615651ec8c8289cdeaf5af61691 100644 --- a/test/utils.py +++ b/test/utils.py @@ -16,11 +16,15 @@ class ParallelDims: dp_replicate_degree: int dp_shard_degree: int tp_degree: int + ep_degree: int = 1 def __str__(self) -> str: - return (f"dp_replicate-{self.dp_replicate_degree}_" - f"dp_shard-{self.dp_shard_degree}_" - f"tp-{self.tp_degree}") + s = (f"dp_replicate-{self.dp_replicate_degree}_" + f"dp_shard-{self.dp_shard_degree}_" + f"tp-{self.tp_degree}") + if self.ep_degree > 1: + s += f"_ep-{self.ep_degree}" + return s def _construct_device_mesh(parallel_dims: ParallelDims) -> DeviceMesh: @@ -35,7 +39,7 @@ def _construct_device_mesh(parallel_dims: ParallelDims) -> DeviceMesh: world_size = dist.get_world_size() expected_devices = (parallel_dims.dp_replicate_degree * parallel_dims.dp_shard_degree * - parallel_dims.tp_degree) + parallel_dims.ep_degree * parallel_dims.tp_degree) if world_size < expected_devices: raise ValueError( f"Not enough devices: found {world_size}, " @@ -43,9 +47,9 @@ def _construct_device_mesh(parallel_dims: ParallelDims) -> DeviceMesh: degrees = [ parallel_dims.dp_replicate_degree, parallel_dims.dp_shard_degree, - 
parallel_dims.tp_degree + parallel_dims.ep_degree, parallel_dims.tp_degree ] - dim_names = ["dp_replicate", "dp_shard", "tp"] + dim_names = ["dp_replicate", "dp_shard", "ep", "tp"] mesh_shape = [] mesh_dim_names = [] @@ -154,6 +158,43 @@ def _apply_fsdp( model.reshard() +def parallelize_llama4(model: torch.nn.Module, + parallel_dims: ParallelDims) -> torch.nn.Module: + """Parallelize the torchtitan Llama4 MoE model using torchtitan's + ``parallelize_llama`` directly. + """ + from torchtitan.config import JobConfig + from torchtitan.distributed import ParallelDims as TTParallelDims + from torchtitan.models.llama4.infra.parallelize import parallelize_llama + + world_size = dist.get_world_size() + + # Map our simple ParallelDims to torchtitan's ParallelDims. + # In torchtitan, EP borrows from dp_shard. + tt_dp_shard = parallel_dims.dp_shard_degree * parallel_dims.ep_degree + + tt_dims = TTParallelDims( + dp_replicate=parallel_dims.dp_replicate_degree, + dp_shard=tt_dp_shard, + cp=1, + tp=parallel_dims.tp_degree, + pp=1, + ep=parallel_dims.ep_degree, + etp=1, + world_size=world_size, + ) + + # Minimal JobConfig with test-appropriate settings. + job_config = JobConfig() + job_config.training.mixed_precision_param = "float32" + job_config.activation_checkpoint.mode = "none" + job_config.compile.enable = False + job_config.parallelism.disable_loss_parallel = True + + parallelize_llama(model, tt_dims, job_config) + return model + + def parallelize_motif(model: torch.nn.Module, parallel_dims: ParallelDims) -> torch.nn.Module: """Parallelize the Motif model according to the given parallel dimensions. 
diff --git a/torch-ext/optimizer/adamw.py b/torch-ext/optimizer/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..a6125200cc3da0996f0f3344131a7c6de4ac5863 --- /dev/null +++ b/torch-ext/optimizer/adamw.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import cast + +import torch +from torch.distributed.tensor import DTensor + + +def fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: dict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + 
device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = optimizer_state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + +def step_adamw(optimizer_state, group): + """Dispatch AdamW step, grouping parameters by type and placement. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + group: Parameter group dict. 
+ """ + params = group["params"] + + # group params with its type and placement + placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for group_params in placement_to_params.values(): + step_adamw_params(optimizer_state, group_params, group) diff --git a/torch-ext/optimizer/async_utils.py b/torch-ext/optimizer/async_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a45c530ac9cad88e3555ec1047a6aa59f225347e --- /dev/null +++ b/torch-ext/optimizer/async_utils.py @@ -0,0 +1,77 @@ +import logging +from typing import Generator + +logger = logging.getLogger(__name__) + + +class _Task: + """Internal: wraps a generator, advances one yield at a time.""" + + def __init__(self, generator: Generator[None, None, None], index: int): + self._generator = generator + self._index = index + self._steps_completed = 0 + self.step() # run to first yield + + def step(self) -> bool: + try: + next(self._generator) + self._steps_completed += 1 + logger.debug("pipeline[%d] completed stage %d", self._index, + self._steps_completed) + return True + except StopIteration: + logger.debug("pipeline[%d] finished after %d stages", self._index, + self._steps_completed) + return False + + def close(self): + self._generator.close() + + +def run_pipeline( + pipelines: Generator[Generator[None, None, None], None, None], + max_concurrent: int, +) -> None: + """Run generator-based pipelines with bounded concurrency. + + Each pipeline is a generator that yields at stage boundaries. + The runtime interleaves pipelines so communication and computation + overlap across chunks. 
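The interleaving can be sketched with toy generators — a minimal model of the scheduling idea, not the real `run_pipeline`/`_Task` code; `chunk`, `drive`, and the stage names are all illustrative:

```python
from typing import Generator, Iterator, List

log: List[str] = []

def chunk(i: int) -> Generator[None, None, None]:
    # Toy stand-in for a chunk pipeline: each yield marks an async
    # communication boundary (gather, then scatter).
    log.append(f"gather{i}")
    yield
    log.append(f"compute{i}")
    yield
    log.append(f"update{i}")

def drive(pipelines: Iterator[Generator], max_concurrent: int) -> None:
    # Miniature of the scheduling loop: admit at most one new pipeline
    # per sweep, then advance every in-flight pipeline by one stage.
    tasks: List[Generator] = []
    have_new = True
    while have_new or tasks:
        running: List[Generator] = []
        if have_new and len(tasks) < max_concurrent:
            try:
                gen = next(pipelines)
                next(gen)          # run to first yield (launch "gather")
                running.append(gen)
            except StopIteration:
                have_new = False
        for gen in tasks:
            try:
                next(gen)
                running.append(gen)
            except StopIteration:
                pass               # pipeline finished
        tasks = running

drive((chunk(i) for i in range(3)), max_concurrent=2)
# chunk 1's "gather" is issued before chunk 0's "compute" — that gap is
# where NCCL traffic can overlap Newton-Schulz compute on the GPU.
print(log)
```

Note how `gather1` lands before `compute0` in the log: the staggered admission is what creates the communication/computation overlap described above.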
+ """ + if max_concurrent <= 0: + raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}") + + have_new = True + task_index = 0 + previous_tasks: list[_Task] = [] + + try: + while have_new or previous_tasks: + running_tasks: list[_Task] = [] + + # Admit one new pipeline per iteration (staggered admission). + # Admitting one at a time ensures that while chunk N does NS + # compute on the default stream, chunk N+1's NCCL all-to-all + # runs concurrently on the NCCL stream — creating real + # communication/computation overlap on the GPU. + if have_new and len(previous_tasks) < max_concurrent: + try: + gen = next(pipelines) + task = _Task(gen, task_index) + task_index += 1 + running_tasks.append(task) + except StopIteration: + have_new = False + + # Advance every previously-yielded task by one step. + for task in previous_tasks: + if task.step(): + running_tasks.append(task) + + previous_tasks = running_tasks + except BaseException: + # Clean up all in-flight generators to release GPU resources. + for task in previous_tasks: + task.close() + raise diff --git a/torch-ext/optimizer/core.py b/torch-ext/optimizer/core.py new file mode 100644 index 0000000000000000000000000000000000000000..8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409 --- /dev/null +++ b/torch-ext/optimizer/core.py @@ -0,0 +1,116 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor + + +@dataclass +class _muon_state: + worker_rank: int + process_group: ProcessGroup + rank_indices: dict[int, tuple] # local_rank -> per-dim indices + rank_numels: dict[int, int] # local_rank -> numel + name: str + qk_clip_state: torch.Tensor | None = None + + +def update_g(optimizer_state, p, g, group, momentum): + """Apply momentum update to gradient. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + p: Parameter tensor. + g: Gradient tensor. 
+ group: Parameter group dict. + momentum: Momentum coefficient. + + Returns: + Momentum-updated gradient tensor. + """ + state = optimizer_state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + +def update_p(p, u, lr, adjusted_lr, weight_decay): + """Apply weight decay and orthogonalized update to parameter. + + Args: + p: Parameter (torch.nn.Parameter or DTensor). + u: Orthogonalized update tensor. + lr: Base learning rate. + adjusted_lr: Size-adjusted learning rate. + weight_decay: Weight decay coefficient. + """ + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + +def adjust_lr_for_muon(lr, param_shape): + """Scale learning rate based on parameter matrix dimensions. + + Args: + lr: Base learning rate. + param_shape: Shape of the parameter tensor. + + Returns: + Adjusted learning rate. 
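As a standalone check of the scaling rule (a self-contained mirror of `adjust_lr_for_muon` below; the shapes are made-up examples):

```python
import math

def adjust_lr_for_muon(lr: float, param_shape: tuple) -> float:
    # lr scales with 0.2 * sqrt(max(A, B)), where (A, B) are the first
    # two dims of the parameter matrix.
    A, B = param_shape[:2]
    return lr * 0.2 * math.sqrt(max(A, B))

# 4096x1024 projection: 0.2 * sqrt(4096) = 12.8, so 1e-3 -> 1.28e-2
print(adjust_lr_for_muon(1e-3, (4096, 1024)))
# square 1024x1024 matrix: 0.2 * sqrt(1024) = 6.4x
print(adjust_lr_for_muon(1e-3, (1024, 1024)))
```

Larger matrices thus get proportionally larger steps, compensating for the fixed RMS of the orthogonalized update.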
+ """ + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def default_is_muon(name, x, expert_keys=None): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + if any(key in name for key in skip_keys): + return False + effective_ndim = x.ndim + if expert_keys and any(key in name for key in expert_keys): + effective_ndim -= 1 + return effective_ndim >= 2 + + +def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): + if is_muon_func is None: + is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) + + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] diff --git a/torch-ext/optimizer/distributed/utils.py b/torch-ext/optimizer/distributed/utils.py index 6d5843506c13d9d31603b2b4e30c1c91d0baab28..75e2e1e8d66975fc9aea75d994de288216a5e9a4 100644 --- a/torch-ext/optimizer/distributed/utils.py +++ b/torch-ext/optimizer/distributed/utils.py @@ -7,22 +7,40 @@ from torch.distributed.tensor.placement_types import (Placement, Shard, _StridedShard) +def _is_shard(placement: Placement) -> bool: + """Check if a placement is a shard type (Shard or _StridedShard). + + In PyTorch 2.10+, _StridedShard no longer inherits from Shard, so + ``placement.is_shard()`` returns False for _StridedShard. This helper + handles both old and new hierarchies. 
+ """ + return isinstance(placement, (Shard, _StridedShard)) + + def get_slices_of_dtensor( target: DTensor | torch.Tensor, local_rank: int, shard_mesh: DeviceMesh, shard_placements: tuple[Placement], -) -> tuple[slice]: +) -> tuple[slice | torch.Tensor, ...]: """ - Get the slice of local tensor for a given rank from a tensor. + Get per-dimension indices for a given rank's shard of the target tensor. + + Uses ``Shard.local_shard_size_and_offset`` and + ``_StridedShard.local_shard_size_and_offset`` for correct handling of + both contiguous and strided (non-contiguous) sharding. + Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + target (DTensor | torch.Tensor): The target tensor (for its shape). + local_rank (int): The local rank within the shard group. + shard_mesh (DeviceMesh): The shard mesh (only shard dimensions). shard_placements (tuple[Placement]): The shard placements. - """ - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + Returns: + A tuple of indices (one per tensor dim). Each element is either: + - A ``slice`` (for contiguous or unsharded dims) + - A 1-D ``torch.LongTensor`` of indices (for strided sharding) + """ # find the global rank of the local rank in the shard mesh rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] @@ -34,34 +52,75 @@ def get_slices_of_dtensor( assert len(rank_coords) == len(shard_placements) + # Track per-shard-dim indices. + # None means "not yet sharded on this dim". + dim_indices: dict[int, torch.Tensor] = {} + # Caution: Assuming replicate-to-shard of the shard mesh goes with # left-to-right sharding. This is ensured by the sorting logic of # construct_shard_mesh function. 
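The two-level composition performed in the loop below (`dim_indices[shard_dim][new_indices]`) can be sketched with plain lists; the sizes and rank coordinates are invented:

```python
# Full tensor dim of size 8, sharded twice on the same dim.
size = 8

# Outer level: Shard over 2 mesh ranks; rank_coord = 1 owns the
# contiguous second half of the dim.
outer = list(range(size // 2, size))   # [4, 5, 6, 7]

# Inner level: shard that piece again over 2 ranks; rank_coord = 0
# selects positions within the *current* indices, not the full dim.
inner = [0, 1]

composed = [outer[i] for i in inner]
print(composed)  # this rank's final indices into the full dim
```

Each sharding level indexes into the result of the previous one, which is why the mesh dims must be visited in outer-to-inner order.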
- for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) + for mesh_dim_idx, (rank_coord, placement) in enumerate( + zip(rank_coords, shard_placements)): + assert _is_shard(placement) - num_ranks = shard_mesh.mesh.shape[i] + num_chunks = shard_mesh.mesh.shape[mesh_dim_idx] + shard_dim = placement.dim - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) + # Current effective size on this dim (may already be sub-sharded) + if shard_dim in dim_indices: + curr_size = len(dim_indices[shard_dim]) + else: + curr_size = target.size()[shard_dim] - if dim_size % num_ranks != 0: + if curr_size % num_chunks != 0: raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}. (shape: {target.shape})") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) + f"Dimension size {curr_size} is not divisible " + f"by number of ranks {num_chunks} for shard " + f"placement on dim {shard_dim}. 
(shape: {target.shape})") + + # Compute indices for this level of sharding + if isinstance(placement, _StridedShard): + _shard_size, offsets = _StridedShard.local_shard_size_and_offset( + placement, + curr_size, + num_chunks, + rank_coord, + return_first_offset=False) + new_indices = torch.tensor(offsets, dtype=torch.long) + else: + shard_size, offset = Shard.local_shard_size_and_offset( + curr_size, num_chunks, rank_coord) + new_indices = torch.arange(offset, + offset + shard_size, + dtype=torch.long) + + # Compose with previous indices on this dim + if shard_dim in dim_indices: + dim_indices[shard_dim] = dim_indices[shard_dim][new_indices] + else: + dim_indices[shard_dim] = new_indices - return tuple(slices) + # Build result tuple + result: list[slice | torch.Tensor] = [] + for d in range(len(target.size())): + if d not in dim_indices: + result.append(slice(None)) + else: + indices = dim_indices[d] + # Convert contiguous indices to slice for efficiency + if len(indices) > 0: + start = indices[0].item() + expected = torch.arange(start, + start + len(indices), + dtype=torch.long) + if torch.equal(indices, expected): + result.append(slice(start, start + len(indices))) + else: + result.append(indices) + else: + result.append(slice(0, 0)) + + return tuple(result) _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, @@ -71,105 +130,105 @@ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, def construct_shard_mesh( placements: tuple[Placement], mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() +) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]: + """Construct shard sub-mesh and ProcessGroup for all-to-all communication. 
- assert mesh.mesh.device.type == 'cpu' + Given a DTensor's placements and device mesh, extracts the "shard group" + — the set of ranks that together hold all shards of the same replica — + and creates a ProcessGroup for all-to-all among them. - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. - # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") + Steps: + 1. Sort placements: Replicate first, then Shard by (dim, granularity). + 2. Permute the mesh tensor to match the sorted order. + 3. Collapse Replicate dims → list of shard sub-meshes (one per replica). + 4. Create/retrieve a cached ProcessGroup for the current rank's sub-mesh. 
- placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) + Example — 8 GPUs, mesh shape (2, 2, 2), + placements ``[Shard(0), Replicate, _StridedShard(0)]``:: - sorted_indices, sorted_placements = zip(*placements_with_index) + Step 1 — Sort: [Replicate, _StridedShard(0), Shard(0)] + Permutation: [1, 2, 0] - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) + Step 2 — Permute mesh dims by [1, 2, 0]: + Original: Permuted: + [[[0,1],[2,3]], [[[0,4],[1,5]], + [[4,5],[6,7]]] [[2,6],[3,7]]] - # 3. Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + Step 3 — Unbind replicate dim (dim 0), giving 2 shard sub-meshes: + sub-mesh 0 = [[0,4],[1,5]] (replica group 0) + sub-mesh 1 = [[2,6],[3,7]] (replica group 1) + shard_placements = (_StridedShard(0), Shard(0)) - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + Step 4 — Rank 0 → ProcessGroup([0,1,4,5]) + Rank 2 → ProcessGroup([2,3,6,7]) + + Returns: + ``(shard_mesh, process_group, shard_placements)`` + """ + my_rank = dist.get_rank() + assert mesh.mesh.device.type == 'cpu' + + # -- Fast path: 1D all-shard mesh → reuse existing PG. ---------------- + # This avoids a non-collective dist.new_group() call, which would + # deadlock when only a subset of ranks call this function (e.g. expert + # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
+ if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]): + key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist()) + if key not in _ranks_to_dist_cache: + _ranks_to_dist_cache[key] = (mesh, mesh.get_group()) + return (*_ranks_to_dist_cache[key], tuple(placements)) + + mesh_tensor = mesh.mesh.clone() + + # -- Step 1: Sort placements (Replicate first, then Shard by dim). ------ + # _StridedShard comes BEFORE regular Shard on the same dim so that + # get_slices_of_dtensor applies the outer sharding first, matching + # DTensor's left-to-right (outer-to-inner) composition order. + def _sort_key(item): + index, placement = item + assert not placement.is_partial(), "Partial placement not supported" + if placement.is_replicate(): + return (-1, 0, index) + assert _is_shard(placement), f"Unsupported: {type(placement)}" + split = (-1 / placement.split_factor if isinstance( + placement, _StridedShard) else 0) + return (placement.dim, split, index) + + indexed = sorted(enumerate(placements), key=_sort_key) + perm, sorted_placements = zip(*indexed) + + # -- Step 2: Permute mesh to match sorted placement order. -------------- + sorted_mesh = mesh_tensor.permute(perm) + + # -- Step 3: Collapse replicate dims → list of shard sub-meshes. -------- + # E.g. mesh (2, 3, 4, 4) with [R, R, S(0), S(1)] → 6 sub-meshes of (4, 4) + num_rep = sum(1 for p in sorted_placements if p.is_replicate()) + if num_rep > 0: + if num_rep > 1: + sorted_mesh = sorted_mesh.flatten(0, num_rep - 1) shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) else: shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different + shard_placements = sorted_placements[num_rep:] assert len(shard_placements) == len(set(shard_placements)) - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. 
- - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, + # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. -- + # All ranks must call dist.new_group in the same order, even though each + # rank only joins one group. + def _cache_key(t: torch.Tensor) -> tuple: + return (*t.shape, *t.flatten().tolist()) + + my_key = None + for sm in shard_meshes: + key = _cache_key(sm) + if (my_rank == sm).any().item(): + assert my_key is None, "Rank appears in multiple shard groups" + my_key = key + if key not in _ranks_to_dist_cache: + pg = dist.new_group(sm.flatten().tolist()) + _ranks_to_dist_cache[key] = ( + DeviceMesh(device_type="cuda", mesh=sm), + pg, ) - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements + return (*_ranks_to_dist_cache[my_key], shard_placements) diff --git a/torch-ext/optimizer/matmul_transpose_triton.py b/torch-ext/optimizer/matmul_transpose_triton.py index 4565b2c4fd506a4218340d380d6c962b16774b1d..95414c6dcd6ec6cd52bf7aebafa260871aff27aa 100644 --- a/torch-ext/optimizer/matmul_transpose_triton.py +++ b/torch-ext/optimizer/matmul_transpose_triton.py @@ -119,10 +119,3 @@ def matmul_transpose_assign(d_in, d_out): with 
torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) - - -def matmul_transpose(d_in): - M, _ = d_in.shape - d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) - matmul_transpose_assign(d_in, d_out) - return d_out diff --git a/torch-ext/optimizer/muon.py b/torch-ext/optimizer/muon.py index dbf25575f185ff379789482068e4ecf55b9455a9..1195ca7bf4c2b594b5459ec114b8a8f2e530ad66 100644 --- a/torch-ext/optimizer/muon.py +++ b/torch-ext/optimizer/muon.py @@ -1,536 +1,121 @@ import logging -import math import types from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.profiler import record_function + +from .adamw import step_adamw +from .async_utils import run_pipeline +from .core import (_muon_state, adjust_lr_for_muon, + get_default_muon_param_groups, update_g, update_p) +from .distributed.utils import (_is_shard, construct_shard_mesh, + get_slices_of_dtensor) +from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, + _zeropower_via_newtonschulz5) +from .pipeline import muon_chunk_pipeline +from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz 
iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) +def _expand_expert_params(names, params, expert_keys): + """Expand expert params by splitting on dim 0 (expert dimension). - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n + Params whose name matches any key in ``expert_keys`` are treated as + expert-parallel tensors. Their outermost dimension is the expert + dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D + ``nn.Parameter`` views so that in-place updates propagate back to + the original storage. - assert inner_off == block - off += block + Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` — + if they are expert params, their key must be added to ``expert_keys``. + The grad must already be set on each expert param (e.g. after momentum). -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. + For DTensor expert params, placements that shard on dim 0 (expert dim) + are consumed by the split. Non-dim-0 shard placements (e.g. 
TP) are + preserved: each 2D slice is wrapped as a DTensor on the corresponding + submesh so the parallel pipeline handles the TP communication. """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: + expanded_names = [] + expanded_params = [] + + for n, p in zip(names, params): + is_expert = expert_keys and any(key in n for key in expert_keys) + is_dtensor = isinstance(p.data, DTensor) + + if 
not is_expert: + assert p.data.ndim <= 2, ( + f"Param {n} has ndim={p.data.ndim} but does not match " + f"expert_keys={expert_keys}. If this is an expert param, " + f"add its key to expert_keys.") + expanded_names.append(n) + expanded_params.append(p) continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx + g = p.grad + assert g is not None, ( + f"Expert param {n} must have grad set before expansion") + + tp_mesh = None + tp_placements_2d = None + + if is_dtensor: + local_data = p.to_local() + local_grad = g.to_local() if isinstance(g, DTensor) else g + + # Find non-dim-0 shard placements (e.g. TP sharding). + # After splitting on dim 0, Shard(k) becomes Shard(k-1). 
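The dim-0 expansion described in the docstring leans on basic indexing returning views. A minimal standalone sketch of that property (plain CPU tensors and hypothetical sizes, not the repository's DTensor path):

```python
import torch

# Sketch of the dim-0 expert expansion: slicing an (E, out, in) tensor on
# dim 0 yields 2D views that share storage with the original, so in-place
# optimizer updates on each expanded nn.Parameter reach the 3D weight.
E, out_dim, in_dim = 4, 8, 16
experts = torch.zeros(E, out_dim, in_dim)

expanded = []
for i in range(E):
    view = experts[i]  # 2D view, no copy
    p = torch.nn.Parameter(view, requires_grad=False)
    expanded.append(p)

expanded[2].data.add_(1.0)  # in-place update on one expanded param
assert experts[2].sum().item() == out_dim * in_dim  # original storage updated
assert experts[1].sum().item() == 0.0  # other experts untouched
```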
+ tp_dim_indices = [] + tp_placements_2d = [] + for i, pl in enumerate(p.placements): + if _is_shard(pl) and pl.dim != 0: + tp_dim_indices.append(i) + tp_placements_2d.append(Shard(pl.dim - 1)) + + if tp_dim_indices: + tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i] + for i in tp_dim_indices) + if len(tp_dim_names) == 1: + tp_mesh = p.device_mesh[tp_dim_names[0]] + else: + tp_mesh = p.device_mesh[tp_dim_names] + else: + local_data = p.data + local_grad = g + + # Expand: split dim 0, reshape each slice to 2D. + num_local_experts = local_data.shape[0] + for i in range(num_local_experts): + slice_data = local_data[i] + slice_grad = local_grad[i] + + if tp_mesh is not None: + # Wrap as DTensor on TP submesh so the pipeline handles + # TP communication (gather/scatter across TP ranks). + dt_data = DTensor.from_local(slice_data, + device_mesh=tp_mesh, + placements=tp_placements_2d) + dt_grad = DTensor.from_local(slice_grad, + device_mesh=tp_mesh, + placements=tp_placements_2d) + expert_param = torch.nn.Parameter(dt_data, requires_grad=False) + expert_param.grad = dt_grad + else: + expert_param = torch.nn.Parameter(slice_data, + requires_grad=False) + expert_param.grad = slice_grad - return None, -1 + expanded_names.append(f"{n}[{i}]") + expanded_params.append(expert_param) + p.grad = None # allow expert grad storage to be freed after pipeline -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None + return expanded_names, expanded_params class Muon(torch.optim.Optimizer): @@ -554,7 +139,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. 
(6 is probably always enough) weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. @@ -564,7 +149,7 @@ class Muon(torch.optim.Optimizer): - "q_indices" (list[int]): Indices of query heads to consider. - "k_indices" (list[int]): Indices of key heads to consider. - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed + - "threshold" (float): Threshold value; heads whose QK logits exceed this value will be scaled down. Default is: { @@ -584,6 +169,13 @@ class Muon(torch.optim.Optimizer): use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + expert_keys: List of strings to identify expert-parallel parameters. + If any key appears in a parameter's name, its outermost + dimension is treated as the expert dimension and expanded + into per-expert 2D params for Muon. For example, + ``expert_keys=["experts"]`` matches any param whose name + contains "experts". 3D+ params not matched by any key + will raise an error. 
""" def __init__(self, @@ -597,16 +189,12 @@ class Muon(torch.optim.Optimizer): adamw_eps=1e-8, none_grad=True, debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, + clip_config=None, warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536): + small_param_numel_threshold=65536, + expert_keys=None): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -630,16 +218,18 @@ class Muon(torch.optim.Optimizer): super().__init__(params, defaults) - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() self.debug = debug - self.clip_config = clip_config + self.clip_config = clip_config if clip_config is not None else { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100, + } self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon self.small_param_numel_threshold = small_param_numel_threshold + self.expert_keys = expert_keys def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -649,20 +239,6 @@ class Muon(torch.optim.Optimizer): return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. 
@@ -673,9 +249,6 @@ class Muon(torch.optim.Optimizer): shard_mesh, shard_pg, shard_placements = construct_shard_mesh( p.placements, p.device_mesh) - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): @@ -694,8 +267,8 @@ class Muon(torch.optim.Optimizer): total_flops += flops if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) + logger.debug("Total TFLOPs for Muon: %.2f TFLOPs", + total_flops / 1e12) paired = list(zip(names, params)) @@ -724,44 +297,54 @@ class Muon(torch.optim.Optimizer): worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks round_robin = (round_robin + 1) % len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + + # Precompute per-rank indices and numels for all-to-all. 
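The per-rank numel precomputation introduced here boils down to Python's `slice.indices()`, which normalizes each slice against its dimension size before a ceil-division element count. A standalone sketch of the same counting logic (plain slices, hypothetical shapes):

```python
# Count elements selected by per-dimension indices, mirroring the
# rank_numels precomputation: slice.indices(dim_size) normalizes
# start/stop/step, and ceil-division counts the selected elements.
def numel_of(indices, shape):
    numel = 1
    for idx, dim_size in zip(indices, shape):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(dim_size)
            numel *= max(0, (stop - start + (step - 1)) // step)
        else:  # an explicit index list
            numel *= len(idx)
    return numel

# Rank 0 owns rows 0:8 of a (16, 32) weight sharded on dim 0.
assert numel_of((slice(0, 8), slice(None)), (16, 32)) == 8 * 32
# slice(None) covers the whole dimension; strided slices are handled too.
assert numel_of((slice(0, 16, 2), slice(None)), (16, 32)) == 8 * 32
assert numel_of(([0, 3, 5], slice(None)), (16, 32)) == 3 * 32
```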
+ rank_indices: dict[int, tuple] = {} + rank_numels: dict[int, int] = {} + for r in range(num_ranks): + indices = get_slices_of_dtensor(p, r, shard_mesh, + shard_placements) + rank_indices[r] = indices + numel = 1 + for idx, dim_size in zip(indices, p.shape): + if isinstance(idx, slice): + start, stop, step = idx.indices(dim_size) + numel *= max(0, (stop - start + (step - 1)) // step) + else: + numel *= len(idx) + rank_numels[r] = numel param_to_state[id(p)] = _muon_state( worker_rank=worker_rank, process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, + rank_indices=rank_indices, + rank_numels=rank_numels, name=n, qk_clip_state=qk_clip_state, ) return param_to_state, ordered_params - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion + def base(self, names, params, group, lr, weight_decay, qk_logits): + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state.head_dim) def distributed_muon( self, @@ -770,20 +353,15 @@ class Muon(torch.optim.Optimizer): group: dict[str, Any], lr: float, weight_decay: float, - momentum: float, qk_logits: 
list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. """ + # Momentum is already applied by _step_muon before this method. for n, p in zip(names, params): g = p.grad if g is None: continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): @@ -796,16 +374,16 @@ class Muon(torch.optim.Optimizer): u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), steps=group["ns_steps"]) - adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) - Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) + update_p(p_full, u_full, lr, adjusted_lr, weight_decay) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) + qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) - scales_full = self._compute_scales( + scales_full = compute_scales( p_full, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): ndims = len(p.device_mesh.mesh.shape) @@ -822,244 +400,53 @@ class Muon(torch.optim.Optimizer): p.copy_(p_sharded) - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = 
self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): + def parallel(self, names, params, group, lr, weight_decay, qk_logits): """ Perform a parallel optimization step using Muon. 
- """ - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) + Parameters are chunked and each chunk is processed by a + :func:`muon_chunk_pipeline` generator. :func:`run_pipeline` + interleaves multiple chunks so that communication and computation + overlap across chunks (the same overlap previously achieved by the + warmup + main-loop index scheduling). + """ - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g + # Momentum is already applied by _step_muon before this method. param_to_state, ordered_params = self.init_state_and_assign_params( names, params, group, qk_logits) - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) + # Compute local rank for this group's shard process group. 
+ shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) if self.chunk_size == -1: shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) + ordered_params[0])].process_group) chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO elif self.chunk_size > 0: chunk_size = self.chunk_size else: raise ValueError("chunk_size must be -1 or a positive integer.") - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return + def pipelines(): + for start in range(0, len(ordered_params), chunk_size): + chunk = ordered_params[start:start + chunk_size] + if chunk: + yield muon_chunk_pipeline( + params=chunk, + param_to_state=param_to_state, + rank=rank, + ns_steps=group["ns_steps"], + lr=lr, + weight_decay=weight_decay, + none_grad=group["none_grad"], + ) - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) + with record_function("muon::barrier"): + dist.barrier() + with record_function("muon::pipeline"): + run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) def _step_muon(self, group, qk_logits=None): params = group["params"] @@ -1068,6 +455,18 @@ class Muon(torch.optim.Optimizer): momentum = group["momentum"] names = group["names"] + # Apply momentum to all params before routing/expansion. 
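The momentum pass noted in the comment above follows the buffer rule of the relocated `_update_g`: `buf <- g + momentum * buf`, and with Nesterov the effective gradient is `g + momentum * buf` rather than `buf`. A scalar sketch of that rule (hypothetical helper, plain floats, not the tensor implementation):

```python
# Scalar sketch of the momentum rule applied under muon::momentum:
# buf <- g + momentum * buf; Nesterov uses g + momentum * buf as the
# effective gradient, otherwise the buffer itself is used.
def update_g_scalar(g, buf, momentum, nesterov):
    buf = g + momentum * buf
    eff = g + momentum * buf if nesterov else buf
    return eff, buf

eff, buf = update_g_scalar(1.0, 0.0, 0.9, nesterov=False)
assert eff == 1.0 and buf == 1.0

eff, buf = update_g_scalar(1.0, buf, 0.9, nesterov=True)
assert abs(buf - 1.9) < 1e-12   # 1 + 0.9 * 1
assert abs(eff - 2.71) < 1e-12  # 1 + 0.9 * 1.9
```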
+ with record_function("muon::momentum"): + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + g = update_g(self.state, p, g, group, momentum) + p.grad = g + + # Expand expert params by splitting on dim 0. + names, params = _expand_expert_params(names, params, self.expert_keys) + param_dtensors = [] name_dtensors = [] @@ -1083,7 +482,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits) return @@ -1119,7 +517,6 @@ class Muon(torch.optim.Optimizer): # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] assert len(dtensors) == len(names) for p, n in zip(dtensors, names): @@ -1141,7 +538,6 @@ class Muon(torch.optim.Optimizer): group=group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1159,7 +555,6 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) @@ -1170,78 +565,9 @@ class Muon(torch.optim.Optimizer): group, lr=lr, weight_decay=weight_decay, - momentum=momentum, qk_logits=qk_logits, ) - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - 
dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -1249,9 +575,9 @@ class Muon(torch.optim.Optimizer): Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as (1 / sqrt(head_dim)) * (Q @ K^T). 
""" loss = None @@ -1263,6 +589,6 @@ class Muon(torch.optim.Optimizer): if group["use_muon"]: self._step_muon(group, qk_logits=qk_logits) else: - self._step_adamw(group) + step_adamw(self.state, group) return loss diff --git a/torch-ext/optimizer/newton_schulz.py b/torch-ext/optimizer/newton_schulz.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fed6e6d186242df1e7e6e89b4416e31eb6bc63 --- /dev/null +++ b/torch-ext/optimizer/newton_schulz.py @@ -0,0 +1,50 @@ +import torch + +from .matmul_transpose_triton import matmul_transpose_assign + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X diff --git a/torch-ext/optimizer/pipeline.py b/torch-ext/optimizer/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9241f6d4457e4a7eacc4129056eadef5aa6961f6 --- /dev/null +++ b/torch-ext/optimizer/pipeline.py @@ -0,0 +1,390 @@ +import logging +from typing import Generator + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +from .core import _muon_state, adjust_lr_for_muon, update_p +from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .qk_clip import compute_scales + +logger = logging.getLogger(__name__) + +# ====================================================================== +# Stage helpers +# ====================================================================== + + +def _launch_gather( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Allocate gather buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. 
+ recv_buf: Flat receive buffer (needed by ``_complete_gather``). + gathered_grads: ``{id(p): empty_tensor}`` for owned params, + ``None`` for non-owned. + recv_counts: Per-source-rank element counts. + """ + # Allocate gathered-grad buffers + gathered_grads: dict[int, torch.Tensor | None] = {} + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + gathered_grads[id(p)] = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + gathered_grads[id(p)] = None + + # Build send buffer + per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = state.rank_numels[rank] + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in + per_dst), "At least one destination rank must receive a sharded tensor" + per_dst_flat = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst_flat, dim=0) + + # Build recv buffer + recv_counts = [0] * num_ranks + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += state.rank_numels[src] + recv_counts[src] = total + + recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda") + + # Launch async all-to-all + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + work = dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + async_op=True, + ) + + return work, recv_buf, gathered_grads, recv_counts + + +def _complete_gather( + recv_buf: torch.Tensor, + recv_counts: 
list[int], + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + param_to_state: dict[int, _muon_state], + rank: int, +) -> None: + """Reconstruct gathered grads from the recv buffer (in-place).""" + off = 0 + for src in range(len(recv_counts)): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + indices = state.rank_indices[src] + + shard_view = gathered_grads[id(p)][indices] + n = shard_view.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape(shard_view.shape) + gathered_grads[id(p)][indices] = sg + + inner_off += n + assert inner_off == block + off += block + + +def _compute_ns( + owned_params: list[DTensor], + gathered_grads: dict[int, torch.Tensor | None], + ns_steps: int, +) -> dict[int, torch.Tensor | None]: + """Run Newton-Schulz orthogonalization on owned parameters. + + Returns: + computed_us: ``{id(p): orthogonalized_update}`` for owned params. + """ + computed_us: dict[int, torch.Tensor | None] = {} + for p in owned_params: + u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + gathered_grads[id(p)] = None # free gathered grad + computed_us[id(p)] = u + return computed_us + + +def _launch_scatter( + params: list[DTensor], + owned_params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + num_ranks: int, + process_group: dist.ProcessGroup, + computed_us: dict[int, torch.Tensor | None], +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]: + """Allocate scatter buffers, build send/recv, and launch async all-to-all. + + Returns: + work: Async operation handle. + recv_buf: Flat receive buffer (needed by ``_complete_scatter``). + scattered_us: ``{id(p): empty_local_tensor}`` for all params. + recv_counts: Per-source-rank element counts. 
+    """
+    # Allocate scattered-u buffers
+    scattered_us: dict[int, torch.Tensor] = {}
+    for p in params:
+        scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
+
+    # Build send buffer (from computed_us on owner ranks; the loop is a
+    # no-op on ranks that own no params)
+    per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
+    send_counts = [0] * num_ranks
+
+    for p in owned_params:
+        state = param_to_state[id(p)]
+
+        assert computed_us[id(p)] is not None
+        u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
+
+        total_sent = 0
+        for dst_rank in range(num_ranks):
+            indices = state.rank_indices[dst_rank]
+            su = u_full[indices].flatten()
+
+            n = su.numel()
+            assert n > 0
+
+            per_dst[dst_rank].append(su)
+            send_counts[dst_rank] += n
+            total_sent += n
+
+        assert total_sent == u_full.numel()
+
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensors"
+        per_dst_flat = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst_flat, dim=0)
+    else:
+        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
+
+    # Build recv buffer
+    recv_counts = [0] * num_ranks
+    for src in range(num_ranks):
+        total = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            total += state.rank_numels[rank]
+        recv_counts[src] = total
+
+    recv_total = sum(recv_counts)
+    assert recv_total > 0
+    recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
+
+    # Launch async all-to-all
+    work = dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_counts,
+        input_split_sizes=send_counts,
+        group=process_group,
+        async_op=True,
+    )
+
+    return work, recv_buf, scattered_us, recv_counts
+
+
+def _complete_scatter(
+    recv_buf: torch.Tensor,
+    recv_counts: list[int],
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+) -> None:
+    """Copy recv buffer into scattered_us (in-place)."""
+    off = 0
+    for src in range(len(recv_counts)):
+        block = recv_counts[src]
+        if block == 0:
+            continue
+
+        inner_off = 0
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank != src:
+                continue
+            n = state.rank_numels[rank]
+            assert n > 0
+
+            flat_local = recv_buf.narrow(0, off + inner_off,
+                                         n).view_as(p.to_local())
+            scattered_us[id(p)].copy_(flat_local)
+
+            inner_off += n
+
+        assert inner_off == block
+        off += block
+
+
+def _update_params(
+    params: list[DTensor],
+    param_to_state: dict[int, _muon_state],
+    rank: int,
+    scattered_us: dict[int, torch.Tensor],
+    lr: float,
+    weight_decay: float,
+) -> None:
+    """Apply weight decay, Muon update, and optional QK clipping."""
+    for p in params:
+        state = param_to_state[id(p)]
+        u_dtensor = DTensor.from_local(
+            scattered_us[id(p)],
+            placements=p.placements,
+            device_mesh=p.device_mesh,
+        )
+
+        adjusted_lr = adjust_lr_for_muon(lr, p.shape)
+        update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
+
+        # QK clipping – applied directly on the local tensor to
+        # avoid DTensor sharding-propagation issues with _StridedShard.
+        scales_full = compute_scales(
+            p,
+            state.qk_clip_state) if state.qk_clip_state is not None else None
+        if scales_full is not None:
+            ratio = p.shape[0] // scales_full.shape[0]
+            idx0 = state.rank_indices[rank][0]
+            if isinstance(idx0, slice):
+                start = idx0.start or 0
+                idx0 = torch.arange(start,
+                                    idx0.stop,
+                                    device=scales_full.device)
+            row_scales = scales_full[idx0 // ratio]
+            p._local_tensor.mul_(row_scales.view(-1, 1))
+
+
+# ======================================================================
+# Main generator – thin orchestrator that wires stages together.
+# ====================================================================== + + +@torch.no_grad() +def muon_chunk_pipeline( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + ns_steps: int, + lr: float, + weight_decay: float, + none_grad: bool, +) -> Generator[None, None, None]: + """Process one chunk of parameters through the full Muon pipeline. + + Stages: gather -> compute (Newton-Schulz) -> scatter -> update. + + Each ``yield`` lets :func:`run_pipeline` interleave other chunks so + that communication and computation overlap across chunks. Async + communication is launched via ``async_op=True`` and completed after + the yield with ``work.wait()``. + + Overlap happens because :func:`run_pipeline` admits one new chunk + per iteration (staggered admission). While chunk *N* does NS + compute on the default CUDA stream, chunk *N+1*'s async all-to-all + runs concurrently on the NCCL stream — no separate ``comm_stream`` + is required. + + Yields exactly **2** times: + + 1. After launching async all-to-all gather. + 2. After launching async all-to-all scatter. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Stages 1-2: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + yield # --- YIELD 1: other chunks can launch their gather --- + + with record_function("muon::wait_gather"): + work.wait() + _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads, + param_to_state, rank) + del recv_buf + + # Stage 3: Newton-Schulz orthogonalization. 
+ with record_function("muon::newton_schulz"): + computed_us = _compute_ns(owned_params, gathered_grads, ns_steps) + gathered_grads.clear() + + # Stages 4-5: launch async scatter. + with record_function("muon::launch_scatter"): + work, recv_buf, scattered_us, recv_counts = _launch_scatter( + params, owned_params, param_to_state, rank, num_ranks, + process_group, computed_us) + computed_us.clear() + + yield # --- YIELD 2: other chunks can launch their scatter --- + + with record_function("muon::wait_scatter"): + work.wait() + _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank, + scattered_us) + del recv_buf + + # Stage 6: apply parameter updates. + with record_function("muon::update_params"): + _update_params(params, param_to_state, rank, scattered_us, lr, + weight_decay) + scattered_us.clear() diff --git a/torch-ext/optimizer/qk_clip.py b/torch-ext/optimizer/qk_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f7199afa361bfb011ebdd4ed84b03709aaee7 --- /dev/null +++ b/torch-ext/optimizer/qk_clip.py @@ -0,0 +1,129 @@ +import logging +import math +from dataclasses import dataclass + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. 
+
+    Example:
+        'model.3.attn.wq.weight' -> ('wq', 3)
+        'model.5.attn.wk.weight' -> ('wk', 5)
+        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
+        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.4.attn.v_proj.weight' -> (None, -1)
+    """
+    parts = name.split('.')
+    if len(parts) < 3:
+        return None, -1
+
+    kind = parts[-2]
+
+    layer_idx = -1
+    for part in reversed(parts):
+        if part.isdigit():
+            layer_idx = int(part)
+            break
+
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+        return kind, layer_idx
+
+    return None, -1
+
+
+@dataclass
+class QKClipInfo:
+    """Per-parameter dynamic info computed from config + runtime logits."""
+    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    indices: list[int]  # which heads to consider for clipping
+    head_dim: int  # from config
+    threshold: float  # from config
+    logit: torch.Tensor | None
+
+
+def get_qk_clip_info(clip_config, n, qk_logits):
+    """Extract QK clipping info for a named parameter.
+
+    Args:
+        clip_config: QK clipping configuration dict (or None).
+        n: Parameter name string.
+        qk_logits: Dict mapping layer indices to logit tensors (or None).
+
+    Returns:
+        QKClipInfo instance with clipping configuration for this parameter.
+    """
+    if clip_config is None:
+        return None
+
+    head_dim = clip_config.get('head_dim')
+    threshold = clip_config.get('threshold')
+    kind, layer_idx = parse_qk_layer(n)
+
+    logit, indices = None, []
+    # Only look up logits when the name matched a q/k projection AND a
+    # layer index was found; qk_logits[-1] would silently read the wrong
+    # entry otherwise.
+    if qk_logits is not None and kind is not None and layer_idx >= 0:
+        logit = qk_logits[layer_idx]
+        indices_key = 'q_indices' if 'q' in kind else 'k_indices'
+        indices = clip_config.get(indices_key, []) or []
+
+    if isinstance(logit, DTensor):
+        # In TP settings, qk_logits may be a DTensor.
+        # We convert it to a full tensor here for simplicity.
+        logit = logit.full_tensor()
+
+    return QKClipInfo(
+        kind=kind,
+        indices=indices,
+        head_dim=head_dim,
+        threshold=threshold,
+        logit=logit,
+    )
+
+
+def compute_scales(p, qk_clip_state):
+    """Compute per-head scaling factors for QK clipping.
+
+    Returns scales tensor if any head exceeds threshold, else None.
+    """
+    kind = qk_clip_state.kind
+    indices = qk_clip_state.indices
+    head_dim = qk_clip_state.head_dim
+    threshold = qk_clip_state.threshold
+    logit = qk_clip_state.logit
+
+    H_global = p.shape[0] // head_dim
+    scales_full = torch.ones(H_global, device=p.data.device)
+    num_clipped = 0
+
+    for logit_idx, head_idx in enumerate(indices):
+        v_ele = float(logit[logit_idx])
+        if v_ele > threshold:
+            new_scale = math.sqrt(threshold / v_ele)
+            if new_scale < scales_full[head_idx]:
+                scales_full[head_idx] = new_scale
+                logger.info(
+                    f"[{kind}] Head {head_idx} exceeded threshold "
+                    f"(value={v_ele:.4f}, threshold={threshold:.4f}) "
+                    f"-> applying scale={new_scale:.4f}")
+            num_clipped += 1
+
+    return scales_full if num_clipped > 0 else None
+
+
+def qk_clip(p, scales, head_dim):
+    """Apply per-head scaling to a Q/K projection weight matrix."""
+    # nn.Parameter and plain tensors take the same path; go through
+    # .data for Parameters to stay outside autograd tracking.
+    data = p.data if isinstance(p, torch.nn.Parameter) else p
+    W = data.view(-1, head_dim, data.shape[1])
+    W.mul_(scales.view(-1, 1, 1))
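
For readers skimming the diff, the per-head clipping rule implemented by `compute_scales` and `qk_clip` above can be sketched standalone on plain tensors. This is a minimal illustration, not the package's API: the helper name `clip_qk_weight`, the shapes, and the returned (rather than in-place) result are all made up for the example.

```python
import torch


def clip_qk_weight(weight: torch.Tensor, logits: torch.Tensor,
                   head_dim: int, threshold: float) -> torch.Tensor:
    """Hypothetical standalone helper: scale each head's rows by
    sqrt(threshold / logit) when its max attention logit exceeds the
    threshold, so the post-clip logit lands back at the threshold."""
    n_heads = weight.shape[0] // head_dim
    scales = torch.ones(n_heads)
    exceeded = logits > threshold
    scales[exceeded] = torch.sqrt(threshold / logits[exceeded])
    # One scale broadcasts over each head's block of `head_dim` rows.
    return (weight.view(n_heads, head_dim, -1) *
            scales.view(-1, 1, 1)).view_as(weight)


# 2 heads of dim 2: head 1's logit (200.0) exceeds threshold 100.0, so its
# two rows shrink by sqrt(100 / 200); head 0's rows are untouched.
w = torch.ones(4, 3)
clipped = clip_qk_weight(w, torch.tensor([50.0, 200.0]),
                         head_dim=2, threshold=100.0)
```

In the diff itself, the same math is split across `compute_scales` (which also consults per-head `indices` from the clip config) and an in-place multiply on the, possibly sharded, local weight tensor.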