# Hugging Face page header (not part of the module): bazaar-research,
# commit "Upload folder using huggingface_hub" (0a7036f, verified).
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist
def _configure_model(model, shard_fn, param_dtype, device):
"""
TODO
"""
model.eval().requires_grad_(False)
if dist.is_initialized():
dist.barrier()
if dist.is_initialized():
model = shard_fn(model)
else:
model.to(param_dtype)
model.to(device)
return model
def init_distributed(world_size, local_rank, rank):
    """Select this process's CUDA device and join the NCCL process group.

    Args:
        world_size: Total number of participating processes. A process group
            is only created when this is greater than 1.
        local_rank: Device index passed to ``torch.cuda.set_device``.
        rank: Rank of this process, passed to ``init_process_group``.
    """
    # Pin the device before the process group is created.
    torch.cuda.set_device(local_rank)

    # Skip initialization when a default group already exists so that calling
    # this helper more than once does not raise.
    if world_size > 1 and not dist.is_initialized():
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size,
        )