# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist


def _configure_model(model, shard_fn, param_dtype, device):
    """Prepare ``model`` for inference and place it for the current process.

    The model is put in eval mode and gradient tracking is disabled on all
    of its parameters. The next step depends on whether a default process
    group is initialized:

    - If it is, all ranks synchronize on a barrier and the model is passed
      through ``shard_fn``.
    - Otherwise, the model is cast to ``param_dtype`` and then moved to
      ``device``.

    Args:
        model: The module to configure. ``.eval()``, ``.requires_grad_()``
            and ``.to()`` are called on it, so presumably a
            ``torch.nn.Module``.
        shard_fn: Callable that takes the model and returns the distributed
            version of it. It is only called when a process group is
            initialized. NOTE(review): neither ``param_dtype`` nor ``device``
            is applied on this path, so presumably ``shard_fn`` handles
            dtype/device placement itself. Confirm against callers.
        param_dtype: dtype passed to ``model.to`` on the non-distributed path.
        device: Device passed to ``model.to`` on the non-distributed path.

    Returns:
        The object returned by ``shard_fn`` when running distributed.
        Otherwise, the same ``model`` object, modified in place.
    """
    # Inference only: switch to eval mode and stop autograd from tracking
    # any parameter.
    model.eval().requires_grad_(False)

    # is_available() guards torch builds compiled without distributed
    # support, where is_initialized() may not be defined.
    if dist.is_available() and dist.is_initialized():
        # Make every rank reach this point before any of them starts sharding.
        dist.barrier()
        model = shard_fn(model)
    else:
        # Cast before moving, so the device transfer copies tensors that are
        # already in param_dtype.
        model.to(param_dtype)
        model.to(device)

    return model


def init_distributed(world_size, local_rank, rank):
    """Bind this process to its GPU and, for multi-process runs, join NCCL.

    Args:
        world_size: Total number of participating processes. When it is not
            greater than 1, no process group is created, so single-process
            runs need no rendezvous.
        local_rank: Index of the CUDA device this process should use.
        rank: Global rank of this process, passed to ``init_process_group``.

    NOTE(review): with ``init_method="env://"``, PyTorch reads the rendezvous
    address from the ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables.
    Presumably the launcher (e.g. torchrun) sets them. Confirm.
    """
    torch.cuda.set_device(local_rank)
    if world_size > 1:
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size,
        )