| import copy |
| import logging |
| import time |
| from contextlib import nullcontext |
|
|
| import pytest |
| import torch |
| import torch.distributed as dist |
| from optimizer.muon import Muon, get_default_muon_param_groups |
| from torch.distributed.tensor import (DTensor, Replicate, Shard, |
| distribute_tensor) |
| from torch.profiler import ProfilerActivity, profile |
|
|
| from .utils import ParallelDims, assert_params_equal, parallelize_llama4 |
|
|
# Module-level logger for per-rank test progress/result messages.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)  # ensure INFO records are emitted
|
|
|
|
| def _apply_grads(model, grads): |
| """Apply gradients to model parameters (with DTensor redistribute).""" |
| for grad, param in zip(grads, model.parameters()): |
| grad = grad.to(param.device) |
| if isinstance(param.data, DTensor): |
| unsharded_grad = DTensor.from_local( |
| grad, |
| device_mesh=param.data.device_mesh, |
| placements=[Replicate()] * param.data.device_mesh.ndim, |
| ) |
| param.grad = unsharded_grad.redistribute( |
| device_mesh=param.data.device_mesh, |
| placements=param.data.placements) |
| else: |
| param.grad = grad |
|
|
|
|
| def _restore_grads(model, saved_grads): |
| """Restore previously saved grads (no redistribute, just reassign).""" |
| for param, g in zip(model.parameters(), saved_grads): |
| param.grad = g |
|
|
|
|
def apply_muon_step_moe(
    model: torch.nn.Module,
    parallel_dims: ParallelDims | None,
    grads: list[torch.Tensor],
    warmup_step: int,
    chunk_size: int,
    use_distributed_muon: bool = False,
    measure_perf: bool = False,
    do_profile: bool = False,
    test_name: str | None = None,
) -> tuple[torch.nn.Module, tuple[float, float] | None]:
    """Run Muon optimizer steps on an MoE model (no QK clipping).

    Attaches ``grads`` to ``model``, constructs a Muon optimizer with
    expert groups (``expert_keys=["experts"]``), and performs two
    optimizer steps (the second on the re-attached grads). When
    ``measure_perf`` is True, one extra warmup step is run, then 20
    steps are timed with CUDA events (optionally under the torch
    profiler) and peak CUDA memory is recorded.

    Args:
        model: The (possibly parallelized) MoE model to step.
        parallel_dims: Parallelism description; currently unused in this
            body and kept for call-site symmetry.
        grads: One gradient tensor per parameter, in
            ``model.parameters()`` order.
        warmup_step: Forwarded to Muon (-1 in the sequential baseline).
        chunk_size: Forwarded to Muon (-1 in the sequential baseline).
        use_distributed_muon: Forwarded to the Muon constructor.
        measure_perf: Enable the timing/memory measurement pass.
        do_profile: With ``measure_perf``, export a chrome trace of the
            timed iterations.
        test_name: Filename prefix for the exported trace.

    Returns:
        ``(model, timing_result)`` where ``timing_result`` is
        ``(avg_step_time_ms, peak_memory_bytes)``, or ``None`` when
        ``measure_perf`` is False.
    """

    assert len(grads) == len(list(model.parameters()))
    _apply_grads(model, grads)

    params = get_default_muon_param_groups(model, expert_keys=["experts"])
    optim = Muon(
        params=params,
        clip_config=None,
        none_grad=False,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
        expert_keys=["experts"],
    )

    # Keep references to the attached grads so they can be re-assigned
    # before each additional step (none_grad=False leaves them in place).
    saved_grads = [p.grad for p in model.parameters()]

    optim.step()

    # Second step on the same grads (exercises optimizer-state reuse).
    _restore_grads(model, saved_grads)
    optim.step()

    timing_result: tuple[float, float] | None = None

    if measure_perf:
        # One extra untimed step before starting the measurement.
        _restore_grads(model, saved_grads)
        optim.step()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        torch.cuda.reset_peak_memory_stats()
        start.record()
        num_iters = 20

        if do_profile:
            context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True)
        else:
            context = nullcontext()

        with context as prof:
            for _i in range(num_iters):
                _restore_grads(model, saved_grads)
                optim.step()

        end.record()
        end.synchronize()

        # nullcontext() yields None, so this only runs when profiling.
        if prof is not None:
            date = time.strftime("%Y%m%d_%H%M%S", time.localtime())
            name = test_name or "trace_moe"
            rank = dist.get_rank()
            prof.export_chrome_trace(f"{name}_{date}_rank{rank}.json")

        peak_memory = torch.cuda.max_memory_allocated()
        elapsed_time_ms = start.elapsed_time(end) / num_iters
        timing_result = (elapsed_time_ms, peak_memory)

    return model, timing_result
|
|
|
|
@pytest.fixture(scope="session")
def sequential_moe_result(
    skip_verify,
    moe_inputs,
) -> torch.nn.Module | None:
    """Session-scoped sequential-Muon baseline for the MoE model.

    Returns the stepped model moved to CPU, or None when verification
    is skipped.
    """
    if skip_verify:
        logger.info("Skipping verification tests as per user request")
        return None

    base_model, base_grads = moe_inputs

    stepped, _ = apply_muon_step_moe(
        model=copy.deepcopy(base_model).cuda(),
        parallel_dims=None,
        grads=base_grads,
        warmup_step=-1,
        chunk_size=-1,
    )
    return stepped.cpu()
|
|
|
|
# Shared pytest parametrization values for the MoE tests below.
OVERLAP_STEPS = [5]  # values passed as Muon ``warmup_step``
CHUNK_SIZES = [2]  # values passed as Muon ``chunk_size``
|
|
|
|
@pytest.mark.parametrize(
    "parallel_dims",
    [
        # Data-parallel-only meshes (per the param ids: DP / FSDP / HSDP).
        pytest.param(ParallelDims(8, 1, 1), id="dp8"),
        pytest.param(ParallelDims(1, 8, 1), id="fsdp8"),
        pytest.param(ParallelDims(2, 4, 1), id="hsdp2x4"),
        # Meshes with expert parallelism (ep_degree > 1).
        pytest.param(ParallelDims(1, 1, 1, ep_degree=8), id="fsdp8_ep8"),
        pytest.param(ParallelDims(1, 4, 1, ep_degree=2), id="fsdp8_ep2"),
        pytest.param(ParallelDims(1, 2, 1, ep_degree=4), id="fsdp8_ep4"),
        pytest.param(ParallelDims(2, 2, 1, ep_degree=2), id="hsdp_ep2"),
    ])
@pytest.mark.parametrize("use_distributed_muon", [False])
@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
@pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
def test_parallel_muon_moe(
    request,
    sequential_moe_result: torch.nn.Module | None,
    parallel_dims: ParallelDims,
    use_distributed_muon: bool,
    warmup_step: int,
    chunk_size: int,
    moe_inputs: tuple[torch.nn.Module, list[torch.Tensor]],
    measure_perf,
    do_profile,
) -> None:
    """Parallel Muon on an MoE model matches the sequential baseline."""
    model, grads = moe_inputs

    # Work on a fresh CUDA copy so the shared fixture model stays pristine.
    model = copy.deepcopy(model).cuda()

    parallelized_model = parallelize_llama4(model, parallel_dims)

    parallelized_model, timing_result = apply_muon_step_moe(
        model=parallelized_model,
        parallel_dims=parallel_dims,
        grads=grads,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
        measure_perf=measure_perf,
        do_profile=do_profile,
        test_name=request.node.name,
    )

    if measure_perf:
        assert timing_result is not None
        avg_time_ms, peak_memory = timing_result
        logger.info(f"\nParallel dims: {parallel_dims}, "
                    f"\nAvg Time (ms): {avg_time_ms:.2f}, "
                    f"Peak Memory (MB): {peak_memory / (1024**2):.2f}")

    # Correctness is only checked when a baseline exists and timing is off.
    if sequential_moe_result is None:
        logger.info("Skipping correctness check as sequential result is None")
    elif measure_perf:
        logger.info("Skipping correctness check as timing is enabled")
    else:
        assert_params_equal(parallelized_model, sequential_moe_result)
|
|
|
|
| |
| |
| |
|
|
|
|
@pytest.fixture(scope="session")
def sequential_moe_result_few_experts(
    skip_verify,
    moe_inputs_few_experts,
) -> torch.nn.Module | None:
    """Session-scoped sequential-Muon baseline for the 2-expert MoE model.

    Returns the stepped model moved to CPU, or None when verification
    is skipped.
    """
    if skip_verify:
        logger.info("Skipping verification tests as per user request")
        return None

    base_model, base_grads = moe_inputs_few_experts

    stepped, _ = apply_muon_step_moe(
        model=copy.deepcopy(base_model).cuda(),
        parallel_dims=None,
        grads=base_grads,
        warmup_step=-1,
        chunk_size=-1,
    )
    return stepped.cpu()
|
|
|
|
@pytest.mark.parametrize("parallel_dims", [
    pytest.param(ParallelDims(1, 4, 1, ep_degree=2), id="fsdp8_ep2"),
    pytest.param(ParallelDims(2, 2, 1, ep_degree=2), id="hsdp_ep2"),
])
@pytest.mark.parametrize("use_distributed_muon", [False])
@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
@pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
def test_parallel_muon_moe_few_experts(
    request,
    sequential_moe_result_few_experts: torch.nn.Module | None,
    parallel_dims: ParallelDims,
    use_distributed_muon: bool,
    warmup_step: int,
    chunk_size: int,
    moe_inputs_few_experts: tuple[torch.nn.Module, list[torch.Tensor]],
    measure_perf,
    do_profile,
) -> None:
    """Parallel Muon on the few-expert MoE model matches the baseline."""
    model, grads = moe_inputs_few_experts

    # Fresh CUDA copy so the session-scoped fixture model is not mutated.
    model = copy.deepcopy(model).cuda()

    parallelized_model = parallelize_llama4(model, parallel_dims)

    parallelized_model, timing_result = apply_muon_step_moe(
        model=parallelized_model,
        parallel_dims=parallel_dims,
        grads=grads,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
        measure_perf=measure_perf,
        do_profile=do_profile,
        test_name=request.node.name,
    )

    if measure_perf:
        assert timing_result is not None
        avg_time_ms, peak_memory = timing_result
        logger.info(f"\nParallel dims: {parallel_dims}, "
                    f"\nAvg Time (ms): {avg_time_ms:.2f}, "
                    f"Peak Memory (MB): {peak_memory / (1024**2):.2f}")

    # Correctness is only checked when a baseline exists and timing is off.
    if sequential_moe_result_few_experts is None:
        logger.info("Skipping correctness check as sequential result is None")
    elif measure_perf:
        logger.info("Skipping correctness check as timing is enabled")
    else:
        assert_params_equal(parallelized_model,
                            sequential_moe_result_few_experts)
|
|
|
|
| |
| |
| |
| |
|
|
|
|
@pytest.mark.parametrize("uneven_dim", [
    pytest.param(33, id="33"),
    pytest.param(19, id="19"),
])
def test_parallel_muon_moe_uneven_shard(init_dist, uneven_dim):
    """Test MoE parallel Muon with uneven shard dimensions.

    Mixes non-expert 2D DTensor params (uneven FSDP sharding, parallel
    pipeline path) with expert 3D plain-tensor params (batched NS path).
    Verifies the combination produces correct results vs sequential baseline.
    """
    from optimizer.newton_schulz import set_ns_compile

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))

    # Disable NS compilation for this test; the try/finally guarantees it
    # is re-enabled even when an assertion below fails, so the disabled
    # state cannot leak into other tests in the session.
    set_ns_compile(False)
    try:
        torch.manual_seed(42)

        other_dim = 64
        num_experts = 4

        muon_params = []
        muon_names = []
        full_params = []  # unsharded copies for the sequential baseline
        full_grads = []  # unsharded gradient copies

        # Non-expert 2D params, sharded (unevenly) on dim 0 across the mesh.
        for i in range(2):
            full = torch.randn(uneven_dim, other_dim, device="cuda")
            full_params.append(full.clone())
            dt = distribute_tensor(full, mesh, [Shard(0)])
            p = torch.nn.Parameter(dt)
            g = torch.randn(uneven_dim, other_dim, device="cuda")
            full_grads.append(g.clone())
            p.grad = distribute_tensor(g, mesh, [Shard(0)])
            muon_params.append(p)
            muon_names.append(f"layers.{i}.weight")

        # Expert 3D plain-tensor param (batched Newton-Schulz path; name
        # contains "experts" so expert_keys matching applies).
        full = torch.randn(num_experts, uneven_dim, other_dim, device="cuda")
        full_params.append(full.clone())
        p = torch.nn.Parameter(full)
        g = torch.randn(num_experts, uneven_dim, other_dim, device="cuda")
        full_grads.append(g.clone())
        p.grad = g
        muon_params.append(p)
        muon_names.append("layers.2.experts.w1.weight")

        # Parallel optimizer over the mixed sharded/expert params.
        param_groups_par = [{
            "params": muon_params,
            "names": muon_names,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }]
        optim_par = Muon(params=param_groups_par,
                         chunk_size=1,
                         warmup_step=0,
                         expert_keys=["experts"])
        optim_par.step()

        # Sequential baseline on full (unsharded) copies of the params.
        seq_params = []
        for fp in full_params:
            p = torch.nn.Parameter(fp.clone())
            seq_params.append(p)

        for p, g in zip(seq_params, full_grads):
            p.grad = g.clone()

        param_groups_seq = [{
            "params": seq_params,
            "names": list(muon_names),
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }]
        optim_seq = Muon(params=param_groups_seq, expert_keys=["experts"])
        optim_seq.step()

        # Bit-exact comparison between parallel and sequential results.
        for i in range(len(muon_params)):
            par_data = muon_params[i].data
            if isinstance(par_data, DTensor):
                par_data = par_data.full_tensor()
            torch.testing.assert_close(par_data,
                                       seq_params[i].data,
                                       atol=0,
                                       rtol=0)
    finally:
        # Restore NS compilation regardless of test outcome.
        set_ns_compile(True)

    logger.info(
        "test_parallel_muon_moe_uneven_shard (dim=%d) PASSED (rank %d)",
        uneven_dim, rank)
|
|
|
|
def test_pp_dp_replicate_moe_no_deadlock(init_dist, moe_inputs):
    """PP regression test using real torchtitan Llama4 MoE model.

    PP=2, dp_replicate=2, dp_shard=2 on 8 GPUs. Splits the Llama4 MoE
    model (4 layers, 8 experts) across 2 pipeline stages following the
    torchtitan pattern. Uses torchtitan's ``parallelize_llama`` for
    realistic FSDP application (same function as real training).

    Each stage independently runs Muon optimizer with expert_keys and
    the result is verified against a sequential baseline (atol=0, rtol=0).

    Without use_local_synchronization=True in construct_shard_mesh(),
    different stages would deadlock on dist.new_group().
    """
    from optimizer.distributed.utils import _ranks_to_dist_cache
    from optimizer.newton_schulz import set_ns_compile
    from torchtitan.config import JobConfig
    from torchtitan.distributed import ParallelDims as TTParallelDims
    from torchtitan.models.llama4.infra.parallelize import parallelize_llama

    rank = dist.get_rank()
    assert dist.get_world_size() == 8

    # Disable NS compilation; the try/finally guarantees it is re-enabled
    # even when the comparison fails, so the disabled state cannot leak
    # into later tests in the session.
    set_ns_compile(False)
    try:
        _ranks_to_dist_cache.clear()

        model_orig, grads_orig = moe_inputs

        # Map parameter name -> full gradient so stage-local grad lists
        # can be rebuilt after the model is split across pipeline stages.
        grad_dict = {
            name: grad
            for (name, _), grad in zip(model_orig.named_parameters(),
                                       grads_orig)
        }

        # 2 (dp_replicate) x 2 (dp_shard) x 2 (pp) = 8 ranks.
        tt_dims = TTParallelDims(
            dp_replicate=2,
            dp_shard=2,
            cp=1,
            tp=1,
            pp=2,
            ep=1,
            etp=1,
            world_size=8,
        )

        pp_rank = tt_dims.world_mesh.get_local_rank("pp")

        job_config = JobConfig()
        job_config.training.mixed_precision_param = "float32"
        job_config.activation_checkpoint.mode = "none"
        job_config.compile.enable = False
        job_config.parallelism.disable_loss_parallel = True

        def _split_llama4(model):
            """Split Llama4 MoE model per PP stage (torchtitan pattern).

            Stage 0: tok_embeddings + layers["0"], ["1"]
            Stage 1: layers["2"], ["3"] + norm + output
            ModuleDict preserves keys so param names are unchanged.
            torchtitan model natively supports None modules in forward().
            """
            if pp_rank == 0:
                for key in ["2", "3"]:
                    if key in model.layers:
                        del model.layers[key]
                model.norm = None
                model.output = None
            else:
                for key in ["0", "1"]:
                    if key in model.layers:
                        del model.layers[key]
                model.tok_embeddings = None
            return model

        def _stage_grads(model):
            """Build grads list aligned with stage model parameters."""
            return [grad_dict[n] for n, _ in model.named_parameters()]

        # Parallel path: split, FSDP-parallelize, then run Muon steps.
        par_model = _split_llama4(copy.deepcopy(model_orig).cuda())
        parallelize_llama(par_model, tt_dims, job_config)

        par_model, _ = apply_muon_step_moe(
            model=par_model,
            parallel_dims=None,
            grads=_stage_grads(par_model),
            warmup_step=5,
            chunk_size=2,
        )

        # Sequential baseline: same split, no parallelization.
        seq_model = _split_llama4(copy.deepcopy(model_orig).cuda())

        seq_model, _ = apply_muon_step_moe(
            model=seq_model,
            parallel_dims=None,
            grads=_stage_grads(seq_model),
            warmup_step=-1,
            chunk_size=-1,
        )

        # Bit-exact comparison between parallel and sequential results.
        assert_params_equal(par_model, seq_model, atol=0, rtol=0)
    finally:
        # Restore NS compilation regardless of test outcome.
        set_ns_compile(True)

    logger.info(
        "test_pp_dp_replicate_moe_no_deadlock PASSED (rank %d, pp_rank %d)",
        rank, pp_rank)
|
|