| | |
| | """ |
| | This file contains primitives for multi-gpu communication. |
| | This is useful when doing distributed training. |
| | """ |
| |
|
| | import functools |
| | import numpy as np |
| | import torch |
| | import torch.distributed as dist |
| |
|
# Process group containing only the ranks on the current machine. It stays None
# until create_local_process_group() assigns it.
_LOCAL_PROCESS_GROUP = None

# Assertion message used by the local-group helpers when the group above is unset.
_MISSING_LOCAL_PG_ERROR = (
    "Local process group is not yet created! "
    "Please use detectron2's `launch()` to start processes and initialize pytorch process group. "
    "If you need to start processes in other ways, please call "
    "comm.create_local_process_group(num_workers_per_machine) "
    "after calling torch.distributed.init_process_group()."
)
| |
|
| |
|
def get_world_size() -> int:
    """
    Returns:
        The number of processes in the default process group, or 1 when
        torch.distributed is unavailable or has not been initialized.
    """
    distributed_active = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if distributed_active else 1
| |
|
| |
|
def get_rank() -> int:
    """
    Returns:
        The rank of the current process in the default process group, or 0 when
        torch.distributed is unavailable or has not been initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_rank()
| |
|
| |
|
@functools.lru_cache()
def create_local_process_group(num_workers_per_machine: int) -> None:
    """
    Build one process group per machine and remember the one this process belongs to.

    Detectron2's launch() in engine/launch.py calls this function. Workers that are
    started without launch() must call it themselves, otherwise helpers such as
    `get_local_rank()` will not work.

    This function contains a barrier. All processes must call it together.

    Args:
        num_workers_per_machine: the number of worker processes per machine. Typically
            the number of GPUs.
    """
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None
    total_workers = get_world_size()
    assert total_workers % num_workers_per_machine == 0
    machine_count = total_workers // num_workers_per_machine
    own_machine = get_rank() // num_workers_per_machine
    for machine_idx in range(machine_count):
        first_rank = machine_idx * num_workers_per_machine
        member_ranks = list(range(first_rank, first_rank + num_workers_per_machine))
        # Every process creates every machine's group, but only keeps its own.
        group = dist.new_group(member_ranks)
        if machine_idx == own_machine:
            _LOCAL_PROCESS_GROUP = group
| |
|
| |
|
def get_local_process_group():
    """
    Returns:
        The torch process group made up of the processes that share a machine with
        the current process. Useful for intra-machine communication, e.g. a
        per-machine SyncBN.

    Fails with an assertion if the local process group has not been created yet.
    """
    local_group = _LOCAL_PROCESS_GROUP
    assert local_group is not None, _MISSING_LOCAL_PG_ERROR
    return local_group
| |
|
| |
|
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group,
        or 0 when torch.distributed is unavailable or has not been initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    # In a distributed run the local group must already have been created.
    assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
| |
|
| |
|
def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group, i.e. the number of processes
        per machine, or 1 when torch.distributed is unavailable or has not been
        initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    # In a distributed run the local group must already have been created.
    assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
| |
|
| |
|
def is_main_process() -> bool:
    """
    Returns:
        True if the current process has global rank 0, False otherwise.
    """
    main_rank = 0
    return get_rank() == main_rank
| |
|
| |
|
def synchronize():
    """
    Barrier across all processes of a distributed run.

    Does nothing when torch.distributed is unavailable, not initialized, or the
    world size is 1.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() == 1:
        return
    uses_nccl = dist.get_backend() == dist.Backend.NCCL
    if not uses_nccl:
        dist.barrier()
        return
    # With NCCL, pin the barrier to this process's current CUDA device.
    dist.barrier(device_ids=[torch.cuda.current_device()])
| |
|
| |
|
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a gloo-backed process group that contains all the ranks.

    When the default backend is nccl, a new gloo group is created; otherwise the
    default WORLD group is returned as-is. The result is cached.
    """
    if dist.get_backend() != "nccl":
        return dist.group.WORLD
    return dist.new_group(backend="gloo")
| |
|
| |
|
def all_gather(data, group=None):
    """
    Collect arbitrary picklable data (not necessarily tensors) from every rank.

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    # Single-process run: nothing to communicate.
    if get_world_size() == 1:
        return [data]
    pg = _get_global_gloo_group() if group is None else group
    group_size = dist.get_world_size(pg)
    if group_size == 1:
        return [data]

    gathered = [None] * group_size
    dist.all_gather_object(gathered, data, group=pg)
    return gathered
| |
|
| |
|
def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank, expressed as a rank in the global (default)
            process group, which is how `torch.distributed.gather_object`
            interprets it regardless of `group`.
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    # Single-process run: nothing to communicate.
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return [data]

    # `dst` is a global rank, so compare it with this process's global rank.
    # Comparing with the rank inside `group` picks the wrong process (or none)
    # whenever `group` is a sub-group whose local ranks differ from global ranks.
    if dist.get_rank() == dst:
        output = [None] * world_size
        dist.gather_object(data, output, dst=dst, group=group)
        return output

    # Non-destination ranks contribute their data but receive nothing.
    dist.gather_object(data, None, dst=dst, group=group)
    return []
| |
|
| |
|
def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
        If workers need a shared RNG, they can use this shared seed to
        create one.

    All workers must call this function, otherwise it will deadlock.
    """
    # Each worker draws its own candidate; everyone adopts the first gathered entry.
    local_seed = np.random.randint(2**31)
    return all_gather(local_seed)[0]
| |
|
| |
|
def reduce_dict(input_dict, average=True):
    """
    Reduce the values of a dictionary across all processes onto rank 0.

    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum

    Returns:
        a dict with the same keys as input_dict, after reduction. With fewer than
        two processes, input_dict itself is returned unchanged.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sort the keys so the values are stacked in a deterministic order.
        names = sorted(input_dict)
        stacked = torch.stack([input_dict[name] for name in names], dim=0)
        dist.reduce(stacked, dst=0)
        # Only rank 0 is the reduce destination, so only rank 0 converts the sum to a mean.
        if dist.get_rank() == 0 and average:
            stacked /= world_size
        return dict(zip(names, stacked))
| |
|