import datetime
import functools
import os
import subprocess
from collections.abc import Iterable, Mapping
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
import torch.multiprocessing as mp
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup

from mmengine.device import is_mlu_available, is_npu_available

_LOCAL_PROCESS_GROUP = None


def is_distributed() -> bool:
    """Return True if distributed environment has been initialized."""
    return torch_dist.is_available() and torch_dist.is_initialized()


def get_local_group() -> Optional[ProcessGroup]:
    """Return local process group."""
    if not is_distributed():
        return None

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to setup local process group.')

    return _LOCAL_PROCESS_GROUP


def get_default_group() -> Optional[ProcessGroup]:
    """Return default process group."""
    return torch_dist.distributed_c10d._get_default_group()


def infer_launcher():
    """Infer the launcher from environment variables.

    Returns:
        str: The launcher type, one of 'pytorch', 'slurm', 'mpi' or 'none'.
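
    Examples:
        >>> # Illustrative: the result depends on the launcher-specific
        >>> # environment variables set for the current process.
        >>> infer_launcher()  # doctest: +SKIP
        'pytorch'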
    """
    if 'WORLD_SIZE' in os.environ:
        return 'pytorch'
    elif 'SLURM_NTASKS' in os.environ:
        return 'slurm'
    elif 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
        return 'mpi'
    else:
        return 'none'


def init_dist(launcher,
              backend='nccl',
              init_backend='torch',
              **kwargs) -> None:
    """Initialize distributed environment.

    Args:
        launcher (str): Way to launch multiple processes. Supported launchers
            are 'pytorch', 'mpi' and 'slurm'.
        backend (str): Communication backend. Supported backends are 'nccl',
            'gloo' and 'mpi'. Defaults to 'nccl'.
        init_backend (str): The backend used to initialize the distributed
            environment. Supported values are 'torch', 'deepspeed' and
            'colossalai'. Defaults to 'torch'.
        **kwargs: keyword arguments are passed to ``init_process_group``.
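
    Examples:
        >>> # Illustrative: typically called once at startup inside a job
        >>> # started by a launcher such as torchrun, which sets the
        >>> # required environment variables (RANK, WORLD_SIZE, ...).
        >>> init_dist('pytorch', backend='nccl')  # doctest: +SKIP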
    """
    timeout = kwargs.get('timeout', None)
    if timeout is not None:
        # If a timeout (in seconds) is specified, it must be converted
        # to a timedelta object before being passed to the backend.
        try:
            kwargs['timeout'] = datetime.timedelta(seconds=timeout)
        except TypeError as exception:
            raise TypeError(
                f'Timeout for distributed training must be provided as '
                f"timeout in seconds, but we've received the type "
                f'{type(timeout)}. Please specify the timeout like this: '
                f"dist_cfg=dict(backend='nccl', timeout=1800)") from exception
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, init_backend=init_backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, init_backend=init_backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
    """Initialize distributed environment with PyTorch launcher.

    Args:
        backend (str): Backend of torch.distributed. Supported backends are
            'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: keyword arguments are passed to ``init_process_group``.
    """
    rank = int(os.environ['RANK'])
    if is_mlu_available():
        import torch_mlu  # noqa: F401
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.mlu.set_device(local_rank)
        torch_dist.init_process_group(
            backend='cncl',
            rank=rank,
            world_size=int(os.environ['WORLD_SIZE']),
            **kwargs)
    elif is_npu_available():
        import torch_npu  # noqa: F401
        torch.npu.set_device(rank)
        torch_dist.init_process_group(
            backend='hccl',
            rank=rank,
            world_size=int(os.environ['WORLD_SIZE']),
            **kwargs)
    else:
        # LOCAL_RANK is set by the PyTorch launcher (torchrun or
        # `torch.distributed.launch`)
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)

        if init_backend == 'torch':
            torch_dist.init_process_group(backend=backend, **kwargs)
        elif init_backend == 'deepspeed':
            import deepspeed
            deepspeed.init_distributed(dist_backend=backend, **kwargs)
        elif init_backend == 'colossalai':
            import colossalai
            colossalai.launch_from_torch(backend=backend, **kwargs)
        else:
            raise ValueError(
                'supported "init_backend" is "torch", "deepspeed" or '
                f'"colossalai", but got {init_backend}')


def _init_dist_mpi(backend, **kwargs) -> None:
    """Initialize distributed environment with MPI launcher.

    Args:
        backend (str): Backend of torch.distributed. Supported backends are
            'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: keyword arguments are passed to ``init_process_group``.
    """
    if backend == 'smddp':
        try:
            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                'Please use an Amazon SageMaker DLC to access smdistributed: '
                'https://github.com/aws/deep-learning-containers/blob/master'
                '/available_images.md#sagemaker-framework-containers'
                '-sm-support-only') from e
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    if 'MASTER_PORT' not in os.environ:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    if 'MASTER_ADDR' not in os.environ:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    torch_dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend,
                     port=None,
                     init_backend='torch',
                     **kwargs) -> None:
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be taken
    from the system environment variable ``MASTER_PORT``. If ``MASTER_PORT``
    is not set, the default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
        init_backend (str): The backend used to initialize the distributed
            environment. Supported values are 'torch', 'deepspeed' and
            'colossalai'. Defaults to 'torch'.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']

    # Prefer the local rank assigned by slurm; fall back to deriving it
    # from the process id and the number of visible GPUs.
    local_rank_env = os.environ.get('SLURM_LOCALID', None)
    if local_rank_env is not None:
        local_rank = int(local_rank_env)
    else:
        num_gpus = torch.cuda.device_count()
        local_rank = proc_id % num_gpus
    torch.cuda.set_device(local_rank)
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')

    # specify the master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use the MASTER_PORT already set in the environment
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'

    # use MASTER_ADDR in the environment variable if it already exists
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(local_rank)
    os.environ['RANK'] = str(proc_id)

    if init_backend == 'torch':
        torch_dist.init_process_group(backend=backend, **kwargs)
    elif init_backend == 'deepspeed':
        import deepspeed
        deepspeed.init_distributed(dist_backend=backend, **kwargs)
    elif init_backend == 'colossalai':
        import colossalai
        colossalai.launch_from_slurm(
            backend=backend,
            host=os.environ['MASTER_ADDR'],
            port=os.environ['MASTER_PORT'],
            **kwargs,
        )
    else:
        raise ValueError('supported "init_backend" is "torch", "deepspeed" '
                         f'or "colossalai", but got {init_backend}')


def init_local_group(node_rank: int, num_gpus_per_node: int):
    """Setup the local process group.

    Setup a process group which only includes processes that are on the same
    machine as the current process.

    The code is modified from
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py

    Args:
        node_rank (int): Rank of the machine used for training.
        num_gpus_per_node (int): Number of gpus used for training in a single
            machine.
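
    Examples:
        >>> # Illustrative: must be called on every rank after the global
        >>> # group exists; assumes global ranks are laid out node-major.
        >>> init_local_group(node_rank=0, num_gpus_per_node=8)  # doctest: +SKIP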
    """
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None

    ranks = list(
        range(node_rank * num_gpus_per_node,
              (node_rank + 1) * num_gpus_per_node))
    _LOCAL_PROCESS_GROUP = torch_dist.new_group(ranks)


def get_backend(group: Optional[ProcessGroup] = None) -> Optional[str]:
    """Return the backend of the given process group.

    Note:
        Calling ``get_backend`` in non-distributed environment will return
        None.

    Args:
        group (ProcessGroup, optional): The process group to work on. The
            default is the general main process group. If another specific
            group is specified, the calling process must be part of
            :attr:`group`. Defaults to None.

    Returns:
        str or None: Return the backend of the given process group as a lower
        case string if in distributed environment, otherwise None.
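
    Examples:
        >>> # Illustrative: the value reflects whichever backend was passed
        >>> # to ``init_dist``; None is returned before initialization.
        >>> get_backend()  # doctest: +SKIP
        'nccl'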
    """
    if is_distributed():
        # Handle low versions of torch like 1.5.0 which do not support
        # passing in None for the group argument.
        if group is None:
            group = get_default_group()
        return torch_dist.get_backend(group)
    else:
        return None


def get_world_size(group: Optional[ProcessGroup] = None) -> int:
    """Return the number of processes in the given process group.

    Note:
        Calling ``get_world_size`` in non-distributed environment will return
        1.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: Return the number of processes of the given process group if in
        distributed environment, otherwise 1.
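
    Examples:
        >>> # Illustrative: returns 1 unless a distributed environment has
        >>> # been initialized.
        >>> get_world_size()
        1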
    """
    if is_distributed():
        # Handle low versions of torch like 1.5.0 which do not support
        # passing in None for the group argument.
        if group is None:
            group = get_default_group()
        return torch_dist.get_world_size(group)
    else:
        return 1


def get_rank(group: Optional[ProcessGroup] = None) -> int:
    """Return the rank of the given process group.

    Rank is a unique identifier assigned to each process within a distributed
    process group. They are always consecutive integers ranging from 0 to
    ``world_size - 1``.

    Note:
        Calling ``get_rank`` in non-distributed environment will return 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: Return the rank of the process group if in distributed
        environment, otherwise 0.
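
    Examples:
        >>> # Illustrative: returns 0 unless a distributed environment has
        >>> # been initialized.
        >>> get_rank()
        0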
    """
    if is_distributed():
        # Handle low versions of torch like 1.5.0 which do not support
        # passing in None for the group argument.
        if group is None:
            group = get_default_group()
        return torch_dist.get_rank(group)
    else:
        return 0


def get_local_size() -> int:
    """Return the number of processes in the current node.

    Returns:
        int: Return the number of processes in the current node if in
        distributed environment, otherwise 1.
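
    Examples:
        >>> # Illustrative: requires that ``init_local_group`` has been
        >>> # called; the value 8 is a hypothetical node size.
        >>> get_local_size()  # doctest: +SKIP
        8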
    """
    if not is_distributed():
        return 1

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to setup local process group.')

    return torch_dist.get_world_size(_LOCAL_PROCESS_GROUP)


def get_local_rank() -> int:
    """Return the rank of current process in the current node.

    Returns:
        int: Return the rank of current process in the current node if in
        distributed environment, otherwise 0.
    """
    if not is_distributed():
        return 0

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to setup local process group.')

    return torch_dist.get_rank(_LOCAL_PROCESS_GROUP)


def get_dist_info(group: Optional[ProcessGroup] = None) -> Tuple[int, int]:
    """Get distributed information of the given process group.

    Note:
        Calling ``get_dist_info`` in non-distributed environment will return
        (0, 1).

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        tuple[int, int]: Return a tuple containing the ``rank`` and
        ``world_size``.
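
    Examples:
        >>> # Illustrative: (0, 1) unless a distributed environment has
        >>> # been initialized.
        >>> get_dist_info()
        (0, 1)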
    """
    world_size = get_world_size(group)
    rank = get_rank(group)
    return rank, world_size


def is_main_process(group: Optional[ProcessGroup] = None) -> bool:
    """Whether the current rank of the given process group is equal to 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        bool: Return True if the current rank of the given process group is
        equal to 0, otherwise False.
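
    Examples:
        >>> # Illustrative: True on rank 0 of the group, and always True
        >>> # in a non-distributed environment.
        >>> is_main_process()
        True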
    """
    return get_rank(group) == 0


def master_only(func: Callable) -> Callable:
    """Decorate those methods which should be executed in master process.

    Args:
        func (callable): Function to be decorated.

    Returns:
        callable: Return the decorated function.
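
    Examples:
        >>> # Illustrative: ``save_checkpoint`` is a hypothetical helper;
        >>> # the wrapped function runs on rank 0 only and returns None on
        >>> # all other ranks.
        >>> @master_only
        ... def save_checkpoint(path):
        ...     print(f'saving to {path}')
        >>> save_checkpoint('/tmp/ckpt.pth')
        saving to /tmp/ckpt.pth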
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_main_process():
            return func(*args, **kwargs)

    return wrapper


def barrier(group: Optional[ProcessGroup] = None) -> None:
    """Synchronize all processes from the given process group.

    This collective blocks processes until the whole group enters this
    function.

    Note:
        Calling ``barrier`` in non-distributed environment will do nothing.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.
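
    Examples:
        >>> # Illustrative: a no-op here; in a distributed run every rank
        >>> # blocks until all ranks reach this call.
        >>> barrier()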
    """
    if is_distributed():
        # Handle low versions of torch like 1.5.0 which do not support
        # passing in None for the group argument.
        if group is None:
            group = get_default_group()
        torch_dist.barrier(group)


def get_data_device(data: Union[Tensor, Mapping, Iterable]) -> torch.device:
    """Return the device of ``data``.

    If ``data`` is a sequence of Tensor, all items in ``data`` should have
    the same device type.

    If ``data`` is a dict whose values are Tensor, all values should have
    the same device type.

    Args:
        data (Tensor or Sequence or dict): Inputs whose device is to be
            inferred.

    Returns:
        torch.device: The device of ``data``.

    Examples:
        >>> import torch
        >>> from mmengine.dist import get_data_device
        >>> # data is a Tensor
        >>> data = torch.tensor([0, 1])
        >>> get_data_device(data)
        device(type='cpu')
        >>> # data is a list of Tensor
        >>> data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        >>> get_data_device(data)
        device(type='cpu')
        >>> # data is a dict
        >>> data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
        >>> get_data_device(data)
        device(type='cpu')
    """
    if isinstance(data, Tensor):
        return data.device
    elif isinstance(data, Mapping):
        pre = None
        for v in data.values():
            cur = get_data_device(v)
            if pre is None:
                pre = cur
            else:
                if cur != pre:
                    raise ValueError(
                        'device type in data should be consistent, but got '
                        f'{cur} and {pre}')
        if pre is None:
            raise ValueError('data should not be empty.')
        return pre
    elif isinstance(data, Iterable) and not isinstance(data, str):
        pre = None
        for item in data:
            cur = get_data_device(item)
            if pre is None:
                pre = cur
            else:
                if cur != pre:
                    raise ValueError(
                        'device type in data should be consistent, but got '
                        f'{cur} and {pre}')
        if pre is None:
            raise ValueError('data should not be empty.')
        return pre
    else:
        raise TypeError('data should be a Tensor, sequence of tensor or '
                        f'dict, but got {data}')


def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
    """Return the device for communication among groups.

    Args:
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        torch.device: The device of backend.
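
    Examples:
        >>> # Illustrative: a CUDA device for the NCCL backend, CPU for
        >>> # GLOO and MPI; the index shown is hypothetical.
        >>> get_comm_device()  # doctest: +SKIP
        device(type='cuda', index=0)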
    """
    backend = get_backend(group)
    if backend == 'hccl':
        import torch_npu  # noqa: F401
        return torch.device('npu', torch.npu.current_device())
    elif backend == torch_dist.Backend.NCCL:
        return torch.device('cuda', torch.cuda.current_device())
    elif backend == 'cncl':
        import torch_mlu  # noqa: F401
        return torch.device('mlu', torch.mlu.current_device())
    elif backend == 'smddp':
        return torch.device('cuda', torch.cuda.current_device())
    else:
        # GLOO and MPI backends use cpu device by default
        return torch.device('cpu')


def cast_data_device(
        data: Union[Tensor, Mapping, Iterable],
        device: torch.device,
        out: Optional[Union[Tensor, Mapping, Iterable]] = None
) -> Union[Tensor, Mapping, Iterable]:
    """Recursively convert Tensor in ``data`` to ``device``.

    If ``data`` is already on the ``device``, it will not be cast again.

    Args:
        data (Tensor or list or dict): Inputs to be cast.
        device (torch.device): Destination device type.
        out (Tensor or list or dict, optional): If ``out`` is specified, its
            value will be equal to ``data``. Defaults to None.

    Returns:
        Tensor or list or dict: ``data`` cast to ``device``.
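
    Examples:
        >>> # Illustrative: casting to CPU works anywhere; casting to
        >>> # 'cuda' requires a visible GPU.
        >>> data = {'img': torch.tensor([0., 1.]), 'label': torch.tensor([1])}
        >>> out = cast_data_device(data, torch.device('cpu'))
        >>> out['img'].device
        device(type='cpu')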
    """
    if out is not None:
        if type(data) != type(out):
            raise TypeError(
                'out should be the same type with data, but got data is '
                f'{type(data)} and out is {type(out)}')

        if isinstance(out, set):
            raise TypeError('out should not be a set')

    if isinstance(data, Tensor):
        if get_data_device(data) == device:
            data_on_device = data
        else:
            data_on_device = data.to(device)

        if out is not None:
            # modify the value of out inplace
            out.copy_(data_on_device)

        return data_on_device
    elif isinstance(data, Mapping):
        data_on_device = {}
        if out is not None:
            data_len = len(data)
            out_len = len(out)
            if data_len != out_len:
                raise ValueError('length of data and out should be same, '
                                 f'but got {data_len} and {out_len}')

            for k, v in data.items():
                data_on_device[k] = cast_data_device(v, device, out[k])
        else:
            for k, v in data.items():
                data_on_device[k] = cast_data_device(v, device)

        if len(data_on_device) == 0:
            raise ValueError('data should not be empty')

        # Keep the type of the output consistent with the input, e.g. an
        # OrderedDict input yields an OrderedDict output.
        return type(data)(data_on_device)
    elif isinstance(data, Iterable) and not isinstance(
            data, str) and not isinstance(data, np.ndarray):
        data_on_device = []
        if out is not None:
            for v1, v2 in zip(data, out):
                data_on_device.append(cast_data_device(v1, device, v2))
        else:
            for v in data:
                data_on_device.append(cast_data_device(v, device))

        if len(data_on_device) == 0:
            raise ValueError('data should not be empty')

        return type(data)(data_on_device)
    else:
        raise TypeError('data should be a Tensor, list of tensor or dict, '
                        f'but got {data}')