File size: 1,298 Bytes
538668e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import os
import torch
import datetime
import numpy as np
import torch.distributed as dist
def setup_distributed():
    """Initialize torch.distributed from environment variables.

    Supports two launch modes: torchrun-style (``RANK``/``WORLD_SIZE``/
    ``LOCAL_RANK``) and SLURM (``SLURM_PROCID``/``SLURM_NTASKS``, with the
    local GPU derived from rank modulo visible device count).

    Returns:
        tuple: ``(is_distributed, rank, world_size, gpu)``. When no
        distributed environment is detected, ``(False, 0, 1, 0)``.
    """
    env = os.environ
    if "RANK" in env and "WORLD_SIZE" in env:
        # torchrun / torch.distributed.launch sets these directly.
        rank = int(env["RANK"])
        world_size = int(env["WORLD_SIZE"])
        gpu = int(env["LOCAL_RANK"])
    elif "SLURM_PROCID" in env:
        # SLURM: map the global proc id onto this node's visible GPUs.
        rank = int(env["SLURM_PROCID"])
        gpu = rank % torch.cuda.device_count()
        world_size = int(env["SLURM_NTASKS"])
    else:
        print('Not using distributed mode')
        return False, 0, 1, 0

    torch.cuda.set_device(gpu)
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank,
        # Generous timeout: large jobs can take a while to rendezvous.
        timeout=datetime.timedelta(minutes=30)
    )
    # Synchronize all ranks before returning control to the caller.
    dist.barrier()
    return True, rank, world_size, gpu
def cleanup_distributed():
    """Destroy the default process group if one is active.

    Safe to call unconditionally: it is a no-op when distributed mode was
    never initialized or when this torch build lacks distributed support.
    """
    # is_available() must be checked first: on torch builds compiled without
    # distributed support, the backend functions are not usable.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def set_seed(seed, rank=0):
    """Seed torch and numpy RNGs for reproducible runs.

    The rank is added to the base seed so that each distributed process
    draws a distinct-but-deterministic random stream.

    Args:
        seed: Base random seed.
        rank: Process rank offset (default 0 for single-process runs).
    """
    effective_seed = seed + rank
    np.random.seed(effective_seed)
    torch.manual_seed(effective_seed)
    if torch.cuda.is_available():
        # Seed the current device and all devices for multi-GPU safety.
        for seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seeder(effective_seed)
|