| # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. | |
| import torch | |
| import torch.distributed as dist | |
| def _configure_model(model, shard_fn, param_dtype, device): | |
| """ | |
| TODO | |
| """ | |
| model.eval().requires_grad_(False) | |
| if dist.is_initialized(): | |
| dist.barrier() | |
| if dist.is_initialized(): | |
| model = shard_fn(model) | |
| else: | |
| model.to(param_dtype) | |
| model.to(device) | |
| return model | |
def init_distributed(world_size, local_rank, rank):
    """Select this process's CUDA device and join the process group.

    ``torch.cuda.set_device(local_rank)`` is always called. A process
    group is created only when ``world_size`` is greater than 1, using
    the ``nccl`` backend and the ``env://`` init method.

    Args:
        world_size: Total number of participating processes.
        local_rank: CUDA device index to select for this process.
        rank: Rank of this process, forwarded to
            ``dist.init_process_group``.
    """
    torch.cuda.set_device(local_rank)

    # A single-process run needs no process group.
    if world_size <= 1:
        return

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        rank=rank,
        world_size=world_size,
    )