# Kernels repository — optimizer/test/test_muon_moe.py
# Last change: "Replace toy PP tests with real-model-based pipeline tests
# [skip-build]" (commit 67f7e11, author: wyldecat)
import copy
import logging
import time
from contextlib import nullcontext
import pytest
import torch
import torch.distributed as dist
from optimizer.muon import Muon, get_default_muon_param_groups
from torch.distributed.tensor import (DTensor, Replicate, Shard,
distribute_tensor)
from torch.profiler import ProfilerActivity, profile
from .utils import ParallelDims, assert_params_equal, parallelize_llama4
# Module-level logger; basicConfig sets a root handler at INFO so test
# progress messages are visible when running under a bare interpreter.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def _apply_grads(model, grads):
"""Apply gradients to model parameters (with DTensor redistribute)."""
for grad, param in zip(grads, model.parameters()):
grad = grad.to(param.device)
if isinstance(param.data, DTensor):
unsharded_grad = DTensor.from_local(
grad,
device_mesh=param.data.device_mesh,
placements=[Replicate()] * param.data.device_mesh.ndim,
)
param.grad = unsharded_grad.redistribute(
device_mesh=param.data.device_mesh,
placements=param.data.placements)
else:
param.grad = grad
def _restore_grads(model, saved_grads):
"""Restore previously saved grads (no redistribute, just reassign)."""
for param, g in zip(model.parameters(), saved_grads):
param.grad = g
def apply_muon_step_moe(
    model: torch.nn.Module,
    parallel_dims: ParallelDims | None,
    grads: list[torch.Tensor],
    warmup_step: int,
    chunk_size: int,
    use_distributed_muon: bool = False,
    measure_perf: bool = False,
    do_profile: bool = False,
    test_name: str | None = None,
) -> tuple[torch.nn.Module, tuple[float, float] | None]:
    """Apply Muon optimizer steps to an MoE model (no QK clipping).

    Always runs two optimizer steps (the second re-uses saved grads to hit
    the expert expand cache). With ``measure_perf`` it additionally runs one
    more warm-up step, then 20 timed steps bracketed by CUDA events, and
    optionally exports a chrome trace.

    Args:
        model: Model to step; parameters may be plain tensors or DTensors.
        parallel_dims: Unused in this function; kept for caller symmetry.
        grads: Full gradient tensors, one per parameter, in
            ``model.parameters()`` order.
        warmup_step: Forwarded to ``Muon`` (callers pass -1 for baseline).
        chunk_size: Forwarded to ``Muon`` (callers pass -1 for baseline).
        use_distributed_muon: Forwarded to ``Muon``.
        measure_perf: When True, time 20 steps and record peak memory.
        do_profile: When True (with measure_perf), wrap the timed loop in
            a torch profiler and export a chrome trace.
        test_name: Prefix for the exported trace file name.

    Returns:
        ``(model, timing_result)`` where ``timing_result`` is
        ``(avg_step_time_ms, peak_memory_bytes)``, or None when
        ``measure_perf`` is False.
    """
    assert len(grads) == len(list(model.parameters()))
    _apply_grads(model, grads)
    params = get_default_muon_param_groups(model, expert_keys=["experts"])
    optim = Muon(
        params=params,
        clip_config=None,
        none_grad=False,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
        expert_keys=["experts"],
    )
    # Save sharded grads for re-use before step clears 3D grads.
    saved_grads = [p.grad for p in model.parameters()]
    optim.step()
    # Second step to exercise expert expand cache hot path.
    _restore_grads(model, saved_grads)
    optim.step()
    timing_result: tuple[float, float] | None = None
    if measure_perf:
        # extra warm up
        _restore_grads(model, saved_grads)
        optim.step()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        # Reset so max_memory_allocated() below reflects only the timed loop.
        torch.cuda.reset_peak_memory_stats()
        start.record()
        num_iters = 20
        if do_profile:
            context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True)
        else:
            # nullcontext() yields None, so ``prof`` is None when not
            # profiling — checked before trace export below.
            context = nullcontext()
        with context as prof:
            for _i in range(num_iters):
                _restore_grads(model, saved_grads)
                optim.step()
        end.record()
        # Block until the end event has executed so elapsed_time is valid.
        end.synchronize()
        if prof is not None:
            date = time.strftime("%Y%m%d_%H%M%S", time.localtime())
            name = test_name or "trace_moe"
            rank = dist.get_rank()
            prof.export_chrome_trace(f"{name}_{date}_rank{rank}.json")
        peak_memory = torch.cuda.max_memory_allocated()
        elapsed_time_ms = start.elapsed_time(end) / num_iters
        timing_result = (elapsed_time_ms, peak_memory)
    return model, timing_result
@pytest.fixture(scope="session")
def sequential_moe_result(
    skip_verify,
    moe_inputs,
) -> torch.nn.Module | None:
    """Session-scoped baseline: Muon step on the unparallelized MoE model.

    Returns the stepped model moved back to CPU, or None when verification
    is skipped.
    """
    if skip_verify:
        logger.info("Skipping verification tests as per user request")
        return None
    baseline_model, baseline_grads = moe_inputs
    # Deep-copy so the shared fixture model is left untouched.
    stepped, _ = apply_muon_step_moe(
        model=copy.deepcopy(baseline_model).cuda(),
        parallel_dims=None,
        grads=baseline_grads,
        warmup_step=-1,
        chunk_size=-1,
    )
    return stepped.cpu()
# Warmup-step values fed to the parametrized parallel Muon tests.
OVERLAP_STEPS = [5]
# Chunk-size values fed to the parametrized parallel Muon tests.
CHUNK_SIZES = [2]
@pytest.mark.parametrize(
    "parallel_dims",
    [
        # --- No EP (non-expert only) ---
        pytest.param(ParallelDims(8, 1, 1), id="dp8"),
        pytest.param(ParallelDims(1, 8, 1), id="fsdp8"),
        pytest.param(ParallelDims(2, 4, 1), id="hsdp2x4"),
        # --- EP configs ---
        # naming: fsdp{dp_shard}_ep{ep} where dp_shard = dp_shard_mod_ep * ep
        # dp_shard_mod_ep (= expert FSDP) = dp_shard_degree in our ParallelDims
        pytest.param(ParallelDims(1, 1, 1, ep_degree=8), id="fsdp8_ep8"),
        pytest.param(ParallelDims(1, 4, 1, ep_degree=2), id="fsdp8_ep2"),
        pytest.param(ParallelDims(1, 2, 1, ep_degree=4), id="fsdp8_ep4"),
        pytest.param(ParallelDims(2, 2, 1, ep_degree=2), id="hsdp_ep2"),
    ])
@pytest.mark.parametrize("use_distributed_muon", [False])
@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
@pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
def test_parallel_muon_moe(
    request,
    sequential_moe_result: torch.nn.Module | None,
    parallel_dims: ParallelDims,
    use_distributed_muon: bool,
    warmup_step: int,
    chunk_size: int,
    moe_inputs: tuple[torch.nn.Module, list[torch.Tensor]],
    measure_perf,
    do_profile,
) -> None:
    """Parallel Muon on the MoE model must match the sequential baseline."""
    model, grads = moe_inputs
    # Deep-copy so the shared fixture model is never mutated in place.
    parallelized_model = parallelize_llama4(
        copy.deepcopy(model).cuda(), parallel_dims)
    parallelized_model, timing_result = apply_muon_step_moe(
        model=parallelized_model,
        parallel_dims=parallel_dims,
        grads=grads,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
        measure_perf=measure_perf,
        do_profile=do_profile,
        test_name=request.node.name,
    )
    if measure_perf:
        assert timing_result is not None
        avg_time_ms, peak_memory = timing_result
        logger.info(f"\nParallel dims: {parallel_dims}, "
                    f"\nAvg Time (ms): {avg_time_ms:.2f}, "
                    f"Peak Memory (MB): {peak_memory / (1024**2):.2f}")
    # Correctness check only runs when a baseline exists and timing is off.
    if sequential_moe_result is None:
        logger.info("Skipping correctness check as sequential result is None")
        return
    if measure_perf:
        logger.info("Skipping correctness check as timing is enabled")
        return
    assert_params_equal(parallelized_model, sequential_moe_result)
# ---------------------------------------------------------------------------
# Few-experts tests: num_experts=2, triggers EFSDP Shard(1) mode
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
def sequential_moe_result_few_experts(
    skip_verify,
    moe_inputs_few_experts,
) -> torch.nn.Module | None:
    """Session-scoped baseline for the 2-expert MoE model.

    Returns the Muon-stepped model moved back to CPU, or None when
    verification is skipped.
    """
    if skip_verify:
        logger.info("Skipping verification tests as per user request")
        return None
    baseline_model, baseline_grads = moe_inputs_few_experts
    # Deep-copy so the shared fixture model is left untouched.
    stepped, _ = apply_muon_step_moe(
        model=copy.deepcopy(baseline_model).cuda(),
        parallel_dims=None,
        grads=baseline_grads,
        warmup_step=-1,
        chunk_size=-1,
    )
    return stepped.cpu()
@pytest.mark.parametrize("parallel_dims", [
    pytest.param(ParallelDims(1, 4, 1, ep_degree=2), id="fsdp8_ep2"),
    pytest.param(ParallelDims(2, 2, 1, ep_degree=2), id="hsdp_ep2"),
])
@pytest.mark.parametrize("use_distributed_muon", [False])
@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
@pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
def test_parallel_muon_moe_few_experts(
    request,
    sequential_moe_result_few_experts: torch.nn.Module | None,
    parallel_dims: ParallelDims,
    use_distributed_muon: bool,
    warmup_step: int,
    chunk_size: int,
    moe_inputs_few_experts: tuple[torch.nn.Module, list[torch.Tensor]],
    measure_perf,
    do_profile,
) -> None:
    """Parallel Muon on the 2-expert MoE model must match the baseline."""
    model, grads = moe_inputs_few_experts
    # Deep-copy so the shared fixture model is never mutated in place.
    parallelized_model = parallelize_llama4(
        copy.deepcopy(model).cuda(), parallel_dims)
    parallelized_model, timing_result = apply_muon_step_moe(
        model=parallelized_model,
        parallel_dims=parallel_dims,
        grads=grads,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
        measure_perf=measure_perf,
        do_profile=do_profile,
        test_name=request.node.name,
    )
    if measure_perf:
        assert timing_result is not None
        avg_time_ms, peak_memory = timing_result
        logger.info(f"\nParallel dims: {parallel_dims}, "
                    f"\nAvg Time (ms): {avg_time_ms:.2f}, "
                    f"Peak Memory (MB): {peak_memory / (1024**2):.2f}")
    # Correctness check only runs when a baseline exists and timing is off.
    if sequential_moe_result_few_experts is None:
        logger.info("Skipping correctness check as sequential result is None")
        return
    if measure_perf:
        logger.info("Skipping correctness check as timing is enabled")
        return
    assert_params_equal(parallelized_model,
                        sequential_moe_result_few_experts)
# ---------------------------------------------------------------------------
# Uneven shard test: mixed expert (3D plain) + non-expert (2D DTensor)
# with dimensions not evenly divisible by shard count.
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("uneven_dim", [
    pytest.param(33, id="33"),
    pytest.param(19, id="19"),
])
def test_parallel_muon_moe_uneven_shard(init_dist, uneven_dim):
    """Test MoE parallel Muon with uneven shard dimensions.

    Mixes non-expert 2D DTensor params (uneven FSDP sharding, parallel
    pipeline path) with expert 3D plain-tensor params (batched NS path).
    Verifies the combination produces bit-identical results vs a sequential
    baseline (atol=0, rtol=0).
    """
    from optimizer.newton_schulz import set_ns_compile
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))
    # Disable NS compile so parallel and sequential paths use identical
    # kernels; restored at the end of the test.
    set_ns_compile(False)
    # Same seed on every rank → every rank generates identical full tensors.
    torch.manual_seed(42)
    other_dim = 64
    num_experts = 4
    muon_params = []
    muon_names = []
    full_params = []   # unsharded copies kept for the sequential baseline
    full_grads = []    # unsharded grad copies for the sequential baseline
    # 2D non-expert params with uneven dims → parallel pipeline
    for i in range(2):
        full = torch.randn(uneven_dim, other_dim, device="cuda")
        full_params.append(full.clone())
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        g = torch.randn(uneven_dim, other_dim, device="cuda")
        full_grads.append(g.clone())
        p.grad = distribute_tensor(g, mesh, [Shard(0)])
        muon_params.append(p)
        muon_names.append(f"layers.{i}.weight")
    # 3D expert params (plain tensors) → batched NS path
    full = torch.randn(num_experts, uneven_dim, other_dim, device="cuda")
    full_params.append(full.clone())
    p = torch.nn.Parameter(full)
    g = torch.randn(num_experts, uneven_dim, other_dim, device="cuda")
    full_grads.append(g.clone())
    p.grad = g
    muon_params.append(p)
    # Name contains "experts" so Muon routes it via expert_keys matching.
    muon_names.append("layers.2.experts.w1.weight")
    # --- Parallel path ---
    param_groups_par = [{
        "params": muon_params,
        "names": muon_names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim_par = Muon(params=param_groups_par,
                     chunk_size=1,
                     warmup_step=0,
                     expert_keys=["experts"])
    optim_par.step()
    # --- Sequential baseline ---
    # Rebuild plain (unsharded) parameters from the saved full tensors.
    seq_params = []
    for fp in full_params:
        p = torch.nn.Parameter(fp.clone())
        seq_params.append(p)
    for p, g in zip(seq_params, full_grads):
        p.grad = g.clone()
    # Same hyperparameters as the parallel group above.
    param_groups_seq = [{
        "params": seq_params,
        "names": list(muon_names),
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim_seq = Muon(params=param_groups_seq, expert_keys=["experts"])
    optim_seq.step()
    # --- Compare ---
    # Exact equality: parallel results must be bitwise identical after
    # gathering sharded params back to full tensors.
    for i in range(len(muon_params)):
        par_data = muon_params[i].data
        if isinstance(par_data, DTensor):
            par_data = par_data.full_tensor()
        torch.testing.assert_close(par_data,
                                   seq_params[i].data,
                                   atol=0,
                                   rtol=0)
    set_ns_compile(True)
    logger.info(
        "test_parallel_muon_moe_uneven_shard (dim=%d) PASSED (rank %d)",
        uneven_dim, rank)
def test_pp_dp_replicate_moe_no_deadlock(init_dist, moe_inputs):
    """PP regression test using real torchtitan Llama4 MoE model.

    PP=2, dp_replicate=2, dp_shard=2 on 8 GPUs. Splits the Llama4 MoE
    model (4 layers, 8 experts) across 2 pipeline stages following the
    torchtitan pattern. Uses torchtitan's ``parallelize_llama`` for
    realistic FSDP application (same function as real training).
    Each stage independently runs Muon optimizer with expert_keys and
    the result is verified against a sequential baseline (atol=0, rtol=0).
    Without use_local_synchronization=True in construct_shard_mesh(),
    different stages would deadlock on dist.new_group().
    """
    from optimizer.distributed.utils import _ranks_to_dist_cache
    from optimizer.newton_schulz import set_ns_compile
    from torchtitan.config import JobConfig
    from torchtitan.distributed import ParallelDims as TTParallelDims
    from torchtitan.models.llama4.infra.parallelize import parallelize_llama
    rank = dist.get_rank()
    assert dist.get_world_size() == 8
    # Disable NS compile for exact parity; re-enabled at the end.
    set_ns_compile(False)
    # Clear the process-group cache so earlier tests cannot leak groups
    # into this topology.
    _ranks_to_dist_cache.clear()
    model_orig, grads_orig = moe_inputs
    # Build name→grad mapping from original model
    grad_dict = {
        name: grad
        for (name, _), grad in zip(model_orig.named_parameters(), grads_orig)
    }
    # torchtitan ParallelDims with PP=2 (same as real training config)
    tt_dims = TTParallelDims(
        dp_replicate=2,
        dp_shard=2,
        cp=1,
        tp=1,
        pp=2,
        ep=1,
        etp=1,
        world_size=8,
    )
    # Accessing world_mesh triggers build_mesh() (lazy init).
    # All ranks participate in init_device_mesh (collective).
    pp_rank = tt_dims.world_mesh.get_local_rank("pp")
    # Minimal job config: fp32, no AC, no compile, no loss parallel.
    job_config = JobConfig()
    job_config.training.mixed_precision_param = "float32"
    job_config.activation_checkpoint.mode = "none"
    job_config.compile.enable = False
    job_config.parallelism.disable_loss_parallel = True
    # -- Helpers ----------------------------------------------------------
    def _split_llama4(model):
        """Split Llama4 MoE model per PP stage (torchtitan pattern).

        Stage 0: tok_embeddings + layers["0"], ["1"]
        Stage 1: layers["2"], ["3"] + norm + output
        ModuleDict preserves keys → param names unchanged.
        torchtitan model natively supports None modules in forward().
        """
        # Closes over pp_rank computed above.
        if pp_rank == 0:
            for key in ["2", "3"]:
                if key in model.layers:
                    del model.layers[key]
            model.norm = None
            model.output = None
        else:
            for key in ["0", "1"]:
                if key in model.layers:
                    del model.layers[key]
            model.tok_embeddings = None
        return model

    def _stage_grads(model):
        """Build grads list aligned with stage model parameters."""
        return [grad_dict[n] for n, _ in model.named_parameters()]
    # -- Parallel path: split → parallelize_llama → Muon step -------------
    par_model = _split_llama4(copy.deepcopy(model_orig).cuda())
    parallelize_llama(par_model, tt_dims, job_config)
    par_model, _ = apply_muon_step_moe(
        model=par_model,
        parallel_dims=None,
        grads=_stage_grads(par_model),
        warmup_step=5,
        chunk_size=2,
    )
    # -- Sequential baseline: split → no parallelization → base Muon ------
    seq_model = _split_llama4(copy.deepcopy(model_orig).cuda())
    seq_model, _ = apply_muon_step_moe(
        model=seq_model,
        parallel_dims=None,
        grads=_stage_grads(seq_model),
        warmup_step=-1,
        chunk_size=-1,
    )
    # Correctness: parallel must match sequential exactly
    assert_params_equal(par_model, seq_model, atol=0, rtol=0)
    set_ns_compile(True)
    logger.info(
        "test_pp_dp_replicate_moe_no_deadlock PASSED (rank %d, pp_rank %d)",
        rank, pp_rank)