import os
import datetime

import numpy as np
import torch
import torch.distributed as dist


def setup_distributed():
    """Initialize distributed training. Returns (distributed, rank, world_size, gpu)."""
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched with torchrun / torch.distributed.launch, which set these variables.
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # Launched with srun; derive the local GPU index from the global rank.
        rank = int(os.environ['SLURM_PROCID'])
        gpu = rank % torch.cuda.device_count()
        world_size = int(os.environ['SLURM_NTASKS'])
    else:
        print('Not using distributed mode')
        return False, 0, 1, 0

    torch.cuda.set_device(gpu)
    # init_method='env://' reads MASTER_ADDR and MASTER_PORT from the
    # environment; under SLURM these must be exported before launch.
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank,
        timeout=datetime.timedelta(minutes=30),
    )
    dist.barrier()
    return True, rank, world_size, gpu


def cleanup_distributed():
    """Cleanup distributed training"""
    if dist.is_initialized():
        dist.destroy_process_group()


def set_seed(seed, rank=0):
    """Set random seeds for reproducibility, offset by rank so workers differ."""
    seed = seed + rank
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
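

# --- Usage sketch (illustrative; not part of the utilities above) ---
# A minimal, hypothetical entrypoint showing how the three helpers compose.
# Assumes a launch such as `torchrun --nproc_per_node=4 train.py`; the
# Linear model and the omitted training loop are placeholders, not a recipe.

def main():
    distributed, rank, world_size, gpu = setup_distributed()
    set_seed(42, rank)  # per-rank offset so workers draw different random streams

    model = torch.nn.Linear(10, 10)  # placeholder model
    if distributed:
        model = model.cuda(gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # ... training loop here ...

    cleanup_distributed()


if __name__ == '__main__':
    main()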