import datetime
import os
import random

import numpy as np
import torch
import torch.distributed as dist
def setup_distributed():
    """Initialize torch.distributed from environment variables.

    Supports two launch styles:
      * torchrun / torch.distributed.launch: reads RANK, WORLD_SIZE, LOCAL_RANK.
      * SLURM: reads SLURM_PROCID, SLURM_NTASKS and maps rank onto local GPUs.

    Falls back to single-process mode when neither set of variables is
    present, or when the SLURM path finds no CUDA device.

    Returns:
        tuple[bool, int, int, int]: (is_distributed, rank, world_size, gpu).
            In the fallback case this is (False, 0, 1, 0).
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NTASKS'])
        # Guard: on a CPU-only node `rank % device_count` would raise
        # ZeroDivisionError; NCCL below requires CUDA anyway, so fall back.
        if torch.cuda.device_count() == 0:
            print('Not using distributed mode')
            return False, 0, 1, 0
        gpu = rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        return False, 0, 1, 0
    torch.cuda.set_device(gpu)
    dist.init_process_group(
        backend='nccl',
        init_method='env://',  # rendezvous via MASTER_ADDR/MASTER_PORT
        world_size=world_size,
        rank=rank,
        # Generous timeout so slow data/ckpt loading on one rank does not
        # kill the whole job.
        timeout=datetime.timedelta(minutes=30),
    )
    # Synchronize all ranks before returning so none races ahead into setup.
    dist.barrier()
    return True, rank, world_size, gpu
def cleanup_distributed():
    """Tear down the default process group, if one was ever created.

    Safe to call unconditionally: does nothing when torch.distributed
    was never initialized.
    """
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def set_seed(seed, rank=0):
    """Seed all RNGs (Python stdlib, NumPy, PyTorch CPU/CUDA).

    Args:
        seed (int): Base seed.
        rank (int): Process rank, added to the seed so each distributed
            rank draws a different-but-deterministic random stream.
    """
    seed = seed + rank
    # Seed the stdlib RNG too — previously omitted, so any code using
    # `random.*` (e.g. data-augmentation choices) was nondeterministic.
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # covers every visible GPU