import logging

import torch
import torch.distributed as dist
from muon import Muon
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate
from transformers import AutoModelForCausalLM
|
|
| logger = logging.getLogger(__name__) |
| logging.basicConfig(level=logging.INFO) |
|
|
| def load_model(fsdp: bool) -> torch.nn.Module: |
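    """Load Motif-2.6B, attach deterministic random gradients to every
    parameter, and optionally shard the model with FSDP (fully_shard)."""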
| model_name = "Motif-Technologies/Motif-2.6B" |
| model = AutoModelForCausalLM.from_pretrained( |
| model_name, |
| trust_remote_code=True, |
    ).bfloat16().cuda()

    # Fixed seed so the FSDP run and the plain run generate identical
    # pseudo-gradients for every parameter.
    torch.manual_seed(0)
| random_grads = [] |
| for param in model.parameters(): |
| random_grad = torch.randn_like(param, |
| device=param.device, |
| dtype=param.dtype) |
        random_grads.append(random_grad)

    if fsdp:
        # Shard each transformer block and then the root module with FSDP2's
        # fully_shard, turning the parameters into DTensors.
| for layer in model.model.layers: |
| fully_shard(layer) |
| layer.reshard() |
| fully_shard(model) |
        model.reshard()

    # Attach the pre-generated gradients. For sharded (DTensor) parameters,
    # replicate the full gradient on the device mesh and redistribute it to
    # match the parameter's placements; otherwise assign the gradient directly.
    for i, param in enumerate(model.parameters()):
| if isinstance(param.data, DTensor): |
| unsharded_grad = DTensor.from_local( |
| random_grads[i], |
| device_mesh=param.data.device_mesh, |
| placements=[Replicate()] * param.data.device_mesh.ndim, |
| ) |
| sharded_grad = unsharded_grad.redistribute( |
| device_mesh=param.data.device_mesh, |
| placements=param.data.placements) |
| param.grad = sharded_grad |
| else: |
            param.grad = random_grads[i]

    return model
|
|
| def run_muon(fsdp: bool) -> torch.nn.Module: |
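    """Build the model with pre-attached gradients and apply a single Muon
    step, either FSDP-sharded (parallel) or unsharded (sequential)."""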
| model = load_model(fsdp=fsdp) |
| optim = Muon(model) |
    optim.step()

    return model
|
|
def compare_results(parallel_muon_result: torch.nn.Module,
                    sequential_muon_result: torch.nn.Module) -> None:
    """Compare the FSDP and the non-FSDP results parameter by parameter."""
    for (name_p, p), (name_s, s) in zip(
            parallel_muon_result.named_parameters(),
            sequential_muon_result.named_parameters()):
        # Parameters of the FSDP model are DTensors; gather the full tensor
        # before comparing against the unsharded reference parameter.
        p = p.data.full_tensor()
        s = s.data

        diff = torch.abs(p - s)
        if diff.max() > 0:
            logger.error(f"Models differ at parameter {name_p} "
                         f"(max abs diff {diff.max().item()} at flat index "
                         f"{torch.argmax(diff).item()})")
            return
    logger.info("Models match!")
|
|
| def test_muon(): |
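    """Check that one Muon step gives identical parameters with and without
    FSDP sharding."""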
| parallel_muon_result = run_muon(fsdp=True) |
    sequential_muon_result = run_muon(fsdp=False)

    compare_results(parallel_muon_result, sequential_muon_result)
|
|
| if __name__ == "__main__": |
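    # Expected to be launched with one process per GPU (e.g. via torchrun).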
| dist.init_process_group(backend="nccl") |
| torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) |
    test_muon()
    dist.destroy_process_group()