File size: 759 Bytes
0a7036f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist
def _configure_model(model, shard_fn, param_dtype, device):
    """
    Prepare a model for inference and place it for the current run.

    The model is switched to eval mode with gradients disabled.

    - If a ``torch.distributed`` process group is initialized, all ranks
      synchronize on a barrier and the model is passed to ``shard_fn``.
    - Otherwise the model is cast to ``param_dtype`` and moved to ``device``
      in place.

    Args:
        model (torch.nn.Module): Model to configure.
        shard_fn (Callable): Called as ``shard_fn(model)`` in distributed
            runs. Its return value replaces ``model``.
        param_dtype (torch.dtype): dtype applied in non-distributed runs.
        device (torch.device | str | int): Target device in non-distributed
            runs.

    Returns:
        In distributed runs, the object returned by ``shard_fn(model)``.
        Otherwise ``model`` itself.
    """
    model.eval().requires_grad_(False)

    # `dist.is_initialized` is only defined on builds where torch.distributed
    # is available, so check availability first; evaluate the condition once.
    distributed = dist.is_available() and dist.is_initialized()
    if distributed:
        # Synchronize all ranks before any of them starts sharding.
        dist.barrier()
        model = shard_fn(model)
    else:
        # nn.Module.to modifies the module in place.
        model.to(param_dtype)
        model.to(device)
    return model
def init_distributed(world_size, local_rank, rank):
    """
    Bind this process to its CUDA device and, for multi-process runs, join
    the NCCL process group.

    Args:
        world_size (int): Total number of processes. A process group is only
            created when this is greater than 1.
        local_rank (int): CUDA device index passed to
            ``torch.cuda.set_device``.
        rank (int): Global rank of this process within the group.

    The group is created with ``init_method="env://"``, so rendezvous
    settings come from environment variables (e.g. ``MASTER_ADDR`` /
    ``MASTER_PORT``). These are presumably provided by the launcher; confirm
    against the caller.

    Calling this again after the default process group exists is a no-op
    for the group. ``torch.cuda.set_device`` is still called.
    """
    torch.cuda.set_device(local_rank)
    # init_process_group raises if the default group was already created;
    # skip it so repeated calls do not crash.
    if world_size > 1 and not dist.is_initialized():
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size,
        )
|