| |
| import logging |
| from datetime import timedelta |
| import torch |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
|
|
| from detectron2.utils import comm |
|
|
| __all__ = ["DEFAULT_TIMEOUT", "launch"] |
|
|
| DEFAULT_TIMEOUT = timedelta(minutes=30) |
|
|
|
|
| def _find_free_port(): |
| import socket |
|
|
| sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| |
| sock.bind(("", 0)) |
| port = sock.getsockname()[1] |
| sock.close() |
| |
| return port |
|
|
|
|
def launch(
    main_func,
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    args=(),
    timeout=DEFAULT_TIMEOUT,
):
    """
    Launch multi-process or distributed training.
    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.

    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of processes per machine. When
            using GPUs, this should be the number of GPUs.
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost
        timeout (timedelta): timeout of the distributed workers
        args (tuple): arguments passed to main_func
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        # Guard against dist_url=None (the default) before inspecting the
        # scheme: calling .startswith on None would raise an unhelpful
        # AttributeError here instead of failing later with a clear message
        # from init_process_group.
        if num_machines > 1 and dist_url is not None and dist_url.startswith("file://"):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        # Spawn one child process per GPU; each child runs _distributed_worker
        # with its local rank prepended to this args tuple.
        mp.start_processes(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                args,
                timeout,
            ),
            daemon=False,
        )
    else:
        # Single-process case: run directly in the caller's process without
        # initializing torch.distributed.
        main_func(*args)
|
|
|
|
def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    args,
    timeout=DEFAULT_TIMEOUT,
):
    """
    Entry point of each worker process spawned by :func:`launch`.

    Initializes the global ``torch.distributed`` process group, creates the
    per-machine local process group, pins this worker to its GPU (if any),
    synchronizes all workers, then runs ``main_func(*args)``.

    Args:
        local_rank (int): rank of this process within its machine (passed
            first by ``mp.start_processes``).
        main_func: function executed after distributed setup, as ``main_func(*args)``.
        world_size (int): total number of processes across all machines.
        num_gpus_per_machine (int): number of processes per machine.
        machine_rank (int): rank of this machine among all machines.
        dist_url (str): ``init_method`` URL for ``dist.init_process_group``.
        args (tuple): arguments forwarded to ``main_func``.
        timeout (timedelta): timeout for distributed collectives.
    """
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        assert (
            num_gpus_per_machine <= torch.cuda.device_count()
        ), "Cannot use more workers per machine than visible GPUs."
    # Global rank is contiguous: machine 0 owns ranks [0, num_gpus), etc.
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL" if has_gpu else "GLOO",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception:
        logger = logging.getLogger(__name__)
        # Lazy %-args avoid formatting when the level is disabled; the bare
        # `raise` re-raises with the original traceback intact.
        logger.error("Process group URL: %s", dist_url)
        raise

    # Create the per-machine local process group used by comm helpers.
    comm.create_local_process_group(num_gpus_per_machine)
    if has_gpu:
        torch.cuda.set_device(local_rank)

    # Barrier so no worker runs user code while others are still initializing.
    comm.synchronize()

    main_func(*args)
|
|