import os
import subprocess
|
|
import torch
import torch.distributed as dist
|
|
|
|
def setup_distributed(backend="nccl", port=None):
    """Initialize the default process group for distributed (DDP) training.

    Lifted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py
    Originally licensed MIT, Copyright (c) 2020 Wei Li
    """
    num_gpus = torch.cuda.device_count()
|
|
    if "SLURM_JOB_ID" in os.environ:
        # Launched via SLURM: derive rank and world size from the SLURM
        # environment and resolve the master address from the first node
        # in the allocation.
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        node_list = os.environ["SLURM_NODELIST"]
        addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")

        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = "10685"
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = addr
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(rank % num_gpus)
        os.environ["RANK"] = str(rank)
    else:
        # Launched via torchrun (or torch.distributed.launch), which already
        # exports RANK and WORLD_SIZE for every process.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
|
|
    # Bind this process to one GPU per local rank.
    torch.cuda.set_device(rank % num_gpus)
|
|
    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )
    return rank, world_size
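

if __name__ == "__main__":
    # Usage sketch (not part of the original utility): a minimal example of
    # how setup_distributed() might be wired into a DDP training script,
    # launched with `torchrun --nproc_per_node=<num_gpus> this_file.py` or
    # under SLURM via `srun`. The tiny Linear model is only a placeholder.
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    rank, world_size = setup_distributed(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])

    model = nn.Linear(16, 16).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])

    x = torch.randn(4, 16, device=f"cuda:{local_rank}")
    loss = model(x).sum()
    loss.backward()  # gradients are all-reduced across world_size processes

    if rank == 0:
        print(f"initialized {world_size} processes, rank 0 loss = {loss.item():.4f}")
    dist.destroy_process_group()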
|
|