File size: 633 Bytes
6a51385 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | import os
import torch.distributed as dist
def setup_distributed(rank: int, world_size: int, *, backend: str = 'nccl') -> None:
    """Initialize the default ``torch.distributed`` process group.

    Does nothing when ``world_size <= 1`` (single-process run). Otherwise it
    fills in ``MASTER_ADDR`` / ``MASTER_PORT`` with local single-node
    defaults when the environment does not already provide them. It then
    calls ``dist.init_process_group`` for this process.

    Args:
        rank: Rank of this process; must lie in ``[0, world_size)``.
        world_size: Total number of participating processes.
        backend: Communication backend passed to ``init_process_group``.
            Defaults to ``'nccl'`` (the previously hard-coded value). Pass
            e.g. ``'gloo'`` for CPU-only runs.

    Raises:
        ValueError: If ``world_size > 1`` and ``rank`` is outside
            ``[0, world_size)``.
    """
    if world_size <= 1:
        # Single process: no process group is needed.
        return

    if not 0 <= rank < world_size:
        # Fail fast with a clear message instead of handing an impossible
        # rank to the rendezvous.
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")

    # Keep launcher-provided values; only fall back to local defaults.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # NOTE(review): with NCCL each process usually needs its own CUDA device
    # (e.g. torch.cuda.set_device(local_rank)). This function does not select
    # one — confirm callers do.
    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size,
    )
def cleanup_distributed() -> None:
    """Destroy the default process group if one is currently active.

    Safe to call unconditionally. It returns without doing anything when
    ``torch.distributed`` is unavailable in this build, or when no process
    group has been initialized.
    """
    if not dist.is_available():
        return
    # Only query initialization state once we know the package is usable.
    if not dist.is_initialized():
        return
    dist.destroy_process_group()