| import copy |
| import logging |
| import time |
| from contextlib import nullcontext |
|
|
| import pytest |
| import torch |
| import torch.distributed as dist |
| from optimizer.muon import Muon, get_default_muon_param_groups |
| from torch.distributed.tensor import DTensor, Replicate |
| from torch.profiler import ProfilerActivity, profile |
|
|
| from .utils import (ParallelDims, assert_params_equal, parallelize_motif, |
| parallelize_qk_logits) |
|
|
# Configure the root logger at INFO so the logger.info() perf/skip messages
# in this module are emitted (basicConfig is a no-op if the root logger
# already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
def _attach_grads(model: torch.nn.Module, grads: list[torch.Tensor]) -> None:
    """Set ``param.grad`` for every parameter of ``model`` from ``grads``.

    ``grads`` must be in ``model.parameters()`` order. For DTensor parameters
    each gradient is wrapped as a fully replicated DTensor on the parameter's
    mesh and then redistributed to the parameter's own placements, so the
    resulting ``.grad`` is laid out like the parameter.
    """
    params = list(model.parameters())
    assert len(grads) == len(params)
    for grad, param in zip(grads, params):
        grad = grad.to(param.device)
        if isinstance(param.data, DTensor):
            mesh = param.data.device_mesh
            # Treat the local tensor as a full replica on every mesh dim...
            replicated = DTensor.from_local(
                grad,
                device_mesh=mesh,
                placements=[Replicate()] * mesh.ndim,
            )
            # ...then shard it exactly like the parameter.
            param.grad = replicated.redistribute(
                device_mesh=mesh,
                placements=param.data.placements,
            )
        else:
            param.grad = grad


def _qk_clip_config(model: torch.nn.Module) -> dict:
    """Build the QK-clip config used by these tests from ``model.config``."""
    num_heads = model.config.num_attention_heads
    return {
        "q_indices": list(range(num_heads)),
        "k_indices": list(range(num_heads)),
        "head_dim": model.config.hidden_size // num_heads,
        "threshold": 0.5,
    }


def _benchmark_steps(
    optim: Muon,
    qk_logits: dict[int, torch.Tensor] | None,
    do_profile: bool,
    trace_tag: str,
    num_iters: int = 20,
) -> tuple[float, float]:
    """Time ``num_iters`` extra optimizer steps and measure their peak memory.

    Returns ``(avg_step_time_ms, peak_memory_bytes)``. The memory figure is
    the peak CUDA allocation during the timed steps minus what was already
    allocated right before them.
    """
    # Untimed warm-up step.
    optim.step(qk_logits=qk_logits)

    # max_memory_allocated() is a high-water mark since program start (or
    # the last reset); without resetting it the "peak" would include the
    # earlier steps and any earlier test in the same session.
    torch.cuda.reset_peak_memory_stats()
    baseline_mem = torch.cuda.memory_allocated()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    if do_profile:
        context = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
        )
    else:
        context = nullcontext()

    with context as prof:
        # Record inside the context so profiler start-up/tear-down is not
        # counted as optimizer step time.
        start.record()
        for _ in range(num_iters):
            optim.step(qk_logits=qk_logits)
        end.record()
    end.synchronize()

    # Export once: on rank 0, or unconditionally when not running distributed.
    is_rank0 = not dist.is_initialized() or dist.get_rank() == 0
    if prof is not None and is_rank0:
        date = time.strftime("%Y%m%d_%H%M%S", time.localtime())
        prof.export_chrome_trace(f"trace_{date}{trace_tag}.json")

    peak_memory = torch.cuda.max_memory_allocated() - baseline_mem
    elapsed_time_ms = start.elapsed_time(end) / num_iters
    return elapsed_time_ms, peak_memory


def apply_muon_step(
    model: torch.nn.Module,
    parallel_dims: ParallelDims | None,
    grads: list[torch.Tensor],
    warmup_step: int,
    chunk_size: int,
    qk_logits: dict[int, torch.Tensor] | None = None,
    use_distributed_muon: bool = False,
    measure_perf: bool = False,
    do_profile: bool = False,
) -> tuple[torch.nn.Module, tuple[float, float] | None]:
    """Apply a single Muon step (with optional QK clipping) to ``model``.

    Args:
        model: Model whose parameters are updated in place.
        parallel_dims: Parallel layout; only used to name profiler traces.
        grads: One gradient per parameter, in ``model.parameters()`` order.
        warmup_step: Forwarded to ``Muon(warmup_step=...)``.
        chunk_size: Forwarded to ``Muon(chunk_size=...)``.
        qk_logits: When not ``None``, a QK-clip config is built from
            ``model.config`` and these logits are passed to ``optim.step``.
        use_distributed_muon: Forwarded to ``Muon(use_distributed_muon=...)``.
        measure_perf: Also run one warm-up step plus 20 timed steps (which
            further modify ``model``).
        do_profile: With ``measure_perf``, profile the timed steps and export
            a Chrome trace from rank 0.

    Returns:
        ``(model, timing_result)`` where ``timing_result`` is ``None`` unless
        ``measure_perf``; otherwise ``(avg_step_time_ms, peak_memory_bytes)``.
    """
    _attach_grads(model, grads)

    optim = Muon(
        params=get_default_muon_param_groups(model),
        clip_config=_qk_clip_config(model) if qk_logits is not None else None,
        none_grad=False,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
    )
    optim.step(qk_logits=qk_logits)

    if not measure_perf:
        return model, None

    trace_tag = (f"_{parallel_dims}_{chunk_size}_{warmup_step}"
                 f"_{qk_logits is not None}_{use_distributed_muon}")
    return model, _benchmark_steps(optim, qk_logits, do_profile, trace_tag)
|
|
|
|
@pytest.fixture(scope="session")
def sequential_muon_result(
    skip_verify,
    inputs
) -> dict[bool, torch.nn.Module] | None:
    """Run one Muon step on an unparallelized model as the reference result.

    Returns a dict keyed by whether QK clipping was applied (``False`` /
    ``True``), each value being the updated model moved to CPU, or ``None``
    when ``skip_verify`` is set.
    """
    if skip_verify:
        logger.info("Skipping verification tests as per user request")
        return None

    model, grads, qk_logits = inputs

    results: dict[bool, torch.nn.Module] = {}
    for apply_qk_clip in (False, True):
        # Copy the session-scoped inputs (as test_parallel_muon does) so the
        # shared fixture objects are never handed to the optimizer directly.
        step_qk_logits = copy.deepcopy(qk_logits) if apply_qk_clip else None
        results[apply_qk_clip] = apply_muon_step(
            model=copy.deepcopy(model).cuda(),
            parallel_dims=None,
            grads=grads,
            warmup_step=-1,
            chunk_size=-1,
            qk_logits=step_qk_logits,
        )[0].cpu()

    return results
|
|
|
|
# warmup_step values swept by test_parallel_muon.
OVERLAP_STEPS: list[int] = [5]
# chunk_size values swept by test_parallel_muon.
CHUNK_SIZES: list[int] = [8]
|
|
|
|
@pytest.mark.parametrize("parallel_dims", [
    pytest.param(ParallelDims(8, 1, 1), id="base"),
    pytest.param(ParallelDims(1, 8, 1), id="fsdp"),
    pytest.param(ParallelDims(2, 4, 1), id="hsdp"),
    pytest.param(ParallelDims(1, 1, 8), id="tp"),
    pytest.param(ParallelDims(2, 2, 2), id="hsdp+tp"),
    pytest.param(ParallelDims(1, 2, 4), id="fsdp+tp"),
])
@pytest.mark.parametrize("apply_qk_clip", [False, True])
@pytest.mark.parametrize("use_distributed_muon", [False])
@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
@pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
def test_parallel_muon(
    request,
    sequential_muon_result: dict[bool, torch.nn.Module],
    parallel_dims: ParallelDims,
    apply_qk_clip: bool,
    use_distributed_muon: bool,
    warmup_step: int,
    chunk_size: int,
    inputs: tuple[torch.nn.Module, list[torch.Tensor],
                  dict[int, torch.Tensor]],
    measure_perf,
    do_profile,
) -> None:
    """Compare one Muon step on a parallelized model to the sequential result.

    When ``measure_perf`` is set, timing and peak memory are logged instead
    and the parameter comparison is skipped.
    """
    # Per the skip messages, chunk_size / warmup_step do not matter for
    # distributed Muon, so it only runs with the first value of each sweep.
    if use_distributed_muon:
        if chunk_size != CHUNK_SIZES[0]:
            pytest.skip("Distributed Muon does not effected by chunk size")
        if warmup_step != OVERLAP_STEPS[0]:
            pytest.skip("Distributed Muon does not effected by warmup step")

    base_model, grads, base_qk_logits = inputs

    # The model (and qk_logits, below) are deep-copied before being
    # parallelized, so the fixture objects themselves are not parallelized.
    parallelized_model = parallelize_motif(
        copy.deepcopy(base_model).cuda(), parallel_dims)

    step_qk_logits = None
    if apply_qk_clip and base_qk_logits is not None:
        step_qk_logits = parallelize_qk_logits(
            copy.deepcopy(base_qk_logits), parallel_dims)

    parallelized_model, timing_result = apply_muon_step(
        model=parallelized_model,
        parallel_dims=parallel_dims,
        grads=grads,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        qk_logits=step_qk_logits,
        use_distributed_muon=use_distributed_muon,
        measure_perf=measure_perf,
        do_profile=do_profile,
    )

    if measure_perf:
        assert timing_result is not None
        avg_time_ms, peak_memory = timing_result
        peak_memory_mb = peak_memory / (1024**2)
        logger.info(
            f"\nParallel dims: {parallel_dims}, "
            f"\nUse distributed Muon: {use_distributed_muon}, "
            f"\nApply QK clip: {apply_qk_clip} => "
            f"\nChunk Size, Warmup Step, Avg Time (ms), Peak Memory (MB):"
            f"\n{chunk_size}, {warmup_step}, {avg_time_ms:.2f}, {peak_memory_mb:.2f},"
        )

    if sequential_muon_result is None:
        logger.info("Skipping correctness check as sequential result is None")
        return
    if measure_perf:
        logger.info("Skipping correctness check as timing is enabled")
        return

    assert_params_equal(parallelized_model,
                        sequential_muon_result[apply_qk_clip])
|
|