import os

import torch.distributed as dist


def setup_distributed(rank: int, world_size: int) -> None:
    """Initialise the default process group for multi-process training.

    Does nothing when ``world_size`` is 1 or less. Otherwise the
    ``MASTER_ADDR`` and ``MASTER_PORT`` environment variables are filled in
    with ``'localhost'`` and ``'12355'`` unless they already hold a value,
    and the process group is created with the ``nccl`` backend.

    Args:
        rank: Rank of the calling process, forwarded to
            ``dist.init_process_group``.
        world_size: Total number of processes, forwarded to
            ``dist.init_process_group``.
    """
    if world_size <= 1:
        # Single-process run: no process group is created.
        return

    # Values already present in the environment win over these fallbacks.
    rendezvous_fallbacks = (
        ('MASTER_ADDR', 'localhost'),
        ('MASTER_PORT', '12355'),
    )
    for env_name, fallback in rendezvous_fallbacks:
        os.environ.setdefault(env_name, fallback)

    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)


def cleanup_distributed() -> None:
    """Destroy the default process group if one is currently initialised.

    Safe to call when the distributed package is unavailable or when no
    process group was ever created; in both cases nothing happens.
    """
    if not dist.is_available():
        return
    if dist.is_initialized():
        dist.destroy_process_group()