| | import os |
| | import logging |
| | from datetime import timedelta |
| | import torch |
| | import torch.distributed as dist |
| | import torch.multiprocessing as mp |
| |
|
| | from pointcept.utils import comm |
| |
|
__all__ = [
    "DEFAULT_TIMEOUT",
    "launch",
]

# Default timeout handed to the distributed process group: one hour.
DEFAULT_TIMEOUT = timedelta(hours=1)
| |
|
| |
|
| | def _find_free_port(): |
| | import socket |
| |
|
| | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| | |
| | sock.bind(("", 0)) |
| | port = sock.getsockname()[1] |
| | sock.close() |
| | |
| | return port |
| |
|
| |
|
def launch(
    main_func,
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    cfg=(),
    timeout=DEFAULT_TIMEOUT,
):
    """
    Launch multi-gpu or distributed training.

    This function must be called on all machines involved in the training.
    When the total world size (``num_machines * num_gpus_per_machine``) is
    greater than one, it spawns ``num_gpus_per_machine`` non-daemon child
    processes on this machine; each one joins the process group and then
    calls ``main_func(*cfg)``. Otherwise ``main_func(*cfg)`` is called
    directly in the current process.

    Args:
        main_func: a function that will be called as ``main_func(*cfg)``.
        num_gpus_per_machine (int): number of GPUs (worker processes) per
            machine.
        num_machines (int): the total number of machines.
        machine_rank (int): the rank of this machine.
        dist_url (str or None): url to connect to for distributed jobs,
            including protocol, e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on
            localhost (single-machine jobs only). ``None`` is forwarded
            unchanged as ``init_method`` to ``init_process_group`` (PyTorch
            then falls back to ``env://``).
        cfg (tuple): positional arguments passed to ``main_func``.
        timeout (timedelta): timeout of the distributed workers.
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        if dist_url == "auto":
            assert (
                num_machines == 1
            ), "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        # dist_url defaults to None; guard before calling str methods on it.
        if (
            num_machines > 1
            and dist_url is not None
            and dist_url.startswith("file://")
        ):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        # mp.spawn passes the local process index as the first argument to
        # _distributed_worker, followed by the entries of ``args``.
        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                cfg,
                timeout,
            ),
            daemon=False,
        )
    else:
        main_func(*cfg)
| |
|
| |
|
def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    cfg,
    timeout=DEFAULT_TIMEOUT,
):
    """
    Per-process entry point spawned by ``launch`` via ``mp.spawn``.

    Validates the CUDA setup, binds this process to GPU ``local_rank``, joins
    the global NCCL process group, creates the per-machine process group
    stored in ``comm._LOCAL_PROCESS_GROUP``, synchronizes all ranks and
    finally calls ``main_func(*cfg)``.

    Args:
        local_rank (int): index of this process on its machine (supplied
            first by ``mp.spawn``); also used as the CUDA device index.
        main_func: function called as ``main_func(*cfg)`` once setup is done.
        world_size (int): total number of processes across all machines.
        num_gpus_per_machine (int): number of processes/GPUs per machine.
        machine_rank (int): rank of this machine.
        dist_url (str or None): ``init_method`` for the process group.
        cfg (tuple): positional arguments forwarded to ``main_func``.
        timeout (timedelta): timeout for process-group operations.

    Raises:
        AssertionError: if CUDA is unavailable, fewer CUDA devices are
            visible than ``num_gpus_per_machine``, or the local process group
            has already been set in this process.
        Exception: any error from ``dist.init_process_group`` is re-raised
            after ``dist_url`` has been logged.
    """
    assert (
        torch.cuda.is_available()
    ), "cuda is not available. Please check your installation."
    # Validate the device count *before* joining the process group: failing
    # after rendezvous would leave the other ranks blocked until ``timeout``.
    assert num_gpus_per_machine <= torch.cuda.device_count(), (
        f"num_gpus_per_machine={num_gpus_per_machine} exceeds the "
        f"{torch.cuda.device_count()} visible CUDA device(s)."
    )
    # Bind this process to its own GPU before any NCCL setup so that
    # communicators/barriers do not fall back to the default device.
    torch.cuda.set_device(local_rank)

    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        # Bare ``raise`` re-raises the active exception with its traceback.
        raise

    # Set up the local process group (ranks living on the same machine).
    # Every rank must call ``new_group`` for every group, in the same order,
    # even for groups it does not belong to.
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
        )
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    # Wait for all ranks to finish setup before running main_func;
    # presumably a barrier over the global group (see pointcept.utils.comm).
    comm.synchronize()

    main_func(*cfg)
| |
|