File size: 633 Bytes
6a51385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import os
import torch.distributed as dist


def setup_distributed(rank: int, world_size: int) -> None:
    """Initialize the default process group for multi-process training.

    Does nothing when ``world_size`` is 1 or less (single-process run).
    Otherwise ``MASTER_ADDR`` / ``MASTER_PORT`` are given defaults if the
    environment does not already define them, and the default process group
    is created with the NCCL backend when CUDA and NCCL are usable, falling
    back to Gloo on hosts without them.

    Args:
        rank: Rank of the calling process, forwarded to
            ``dist.init_process_group``.
        world_size: Total number of participating processes.

    Raises:
        RuntimeError: If ``world_size > 1`` but this PyTorch build has no
            distributed support.
    """
    if world_size <= 1:
        return

    # Same availability check cleanup_distributed() performs; without it a
    # build lacking distributed support fails with an opaque AttributeError.
    if not dist.is_available():
        raise RuntimeError(
            'torch.distributed is not available in this PyTorch build'
        )

    # Local import: only needed to probe CUDA, and the module is already
    # loaded as a side effect of importing torch.distributed.
    import torch

    # Respect values supplied by a launcher; only fill in what is missing.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # NCCL requires CUDA devices; on CPU-only hosts use Gloo instead of
    # failing during initialization.
    use_nccl = torch.cuda.is_available() and dist.is_nccl_available()
    dist.init_process_group(
        backend='nccl' if use_nccl else 'gloo',
        rank=rank,
        world_size=world_size,
    )


def cleanup_distributed() -> None:
    """Destroy the process group if one is currently initialized.

    A no-op when this PyTorch build has no distributed support or when no
    process group has been initialized.
    """
    # Availability is checked first so is_initialized() is only queried on
    # builds that actually provide distributed support.
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    dist.destroy_process_group()