""" Helpers for distributed training. """ import io import os import socket import blobfile as bf # from mpi4py import MPI import torch import torch.distributed as dist # Change this to reflect your cluster layout. # The GPU for a given rank is (rank % GPUS_PER_NODE). GPUS_PER_NODE = 8 SETUP_RETRY_COUNT = 3 def setup_dist(): """ Setup a distributed process group. """ if dist.is_initialized(): return #os.environ["CUDA_VISIBLE_DEVICES"] = "1" os.environ["CUDA_VISIBLE_DEVICES"] = "0" backend = "gloo" if not torch.cuda.is_available() else "nccl" if backend == "gloo": hostname = "localhost" else: hostname = socket.gethostbyname(socket.getfqdn()) os.environ["MASTER_ADDR"] = "127.0.1.1" # comm.bcast(hostname, root=0) os.environ["RANK"] = "0" # str(comm.rank) os.environ["WORLD_SIZE"] = "1" # str(comm.size) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) s.listen(1) port = s.getsockname()[1] s.close() os.environ["MASTER_PORT"] = str(port) dist.init_process_group(backend=backend, init_method="env://") def dev(): """ Get the device to use for torch.distributed. """ if torch.cuda.is_available(): return torch.device(f"cuda") return torch.device("cpu") def load_state_dict(path, **kwargs): """ Load a PyTorch file without redundant fetches across MPI ranks. """ mpigetrank = 0 if mpigetrank == 0: with bf.BlobFile(path, "rb") as f: data = f.read() else: data = None return torch.load(io.BytesIO(data), **kwargs) def sync_params(params): """ Synchronize a sequence of Tensors across ranks from rank 0. """ for p in params: with torch.no_grad(): dist.broadcast(p, 0) def _find_free_port(): try: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] finally: s.close()