# SLIM-Brain/utils/ddp.py
import datetime
import os
import random

import numpy as np
import torch
import torch.distributed as dist

def setup_distributed():
    """Initialize distributed training.

    Returns a tuple ``(is_distributed, rank, world_size, gpu)``. Falls back
    to single-process mode when no launcher environment is detected.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched via torchrun / torch.distributed.launch, which export
        # RANK, WORLD_SIZE, and LOCAL_RANK for every process.
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # Launched via SLURM: derive the local GPU index from the global rank.
        # Note that init_method='env://' below still expects MASTER_ADDR and
        # MASTER_PORT to be exported by the job script.
        rank = int(os.environ['SLURM_PROCID'])
        gpu = rank % torch.cuda.device_count()
        world_size = int(os.environ['SLURM_NTASKS'])
    else:
        print('Not using distributed mode')
        return False, 0, 1, 0

    torch.cuda.set_device(gpu)
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank,
        timeout=datetime.timedelta(minutes=30),
    )
    # Synchronize all ranks before returning so no process races ahead.
    dist.barrier()
    return True, rank, world_size, gpu
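
# A minimal sketch (not part of the original file) of how the returned values
# are typically consumed; `MyModel` is a hypothetical placeholder module.
#
#   is_dist, rank, world_size, gpu = setup_distributed()
#   model = MyModel().cuda(gpu)
#   if is_dist:
#       model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])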

def cleanup_distributed():
    """Tear down the distributed process group if it was initialized."""
    if dist.is_initialized():
        dist.destroy_process_group()

def set_seed(seed, rank=0):
    """Seed all RNGs for reproducibility.

    The rank is added to the base seed so each process draws a different
    (but still deterministic) random stream.
    """
    seed = seed + rank
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # seeds every visible GPU device
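

# Minimal end-to-end usage sketch (an assumption about how this module is
# driven, not code from the original repo). Assumes launch via
#   torchrun --nproc_per_node=<num_gpus> ddp.py
# so that RANK / WORLD_SIZE / LOCAL_RANK are set in the environment.
if __name__ == '__main__':
    is_distributed, rank, world_size, gpu = setup_distributed()
    set_seed(42, rank)
    if is_distributed:
        # Toy collective: every rank contributes 1, so the sum is world_size.
        x = torch.ones(1, device=f'cuda:{gpu}')
        dist.all_reduce(x)
        print(f'rank {rank}/{world_size}: all_reduce -> {x.item()}')
    cleanup_distributed()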