from functools import lru_cache
from typing import Any, Callable, List, Optional, Tuple

import numpy as np
import torch

from .default_helper import error_wrapper
from .fake_linklink import FakeLink
from .import_helper import try_import_link


@lru_cache()
def get_link():
    # Cache the result so the (possibly expensive) import is attempted only once.
    return try_import_link()


@lru_cache()
def is_fake_link():
    return isinstance(get_link(), FakeLink)


def get_rank() -> int:
    """
    Overview:
        Get the rank of the current process in the ``linklink`` group; return 0 when ``FakeLink`` is in use.

    .. note::
        Reference ``import_helper.try_import_link`` and ``linklink.get_rank``.
    """
    if is_fake_link():
        return 0
    return error_wrapper(get_link().get_rank, 0, "[WARNING]: call linklink error, return default_ret.")()


def get_world_size() -> int:
    """
    Overview:
        Get the ``world_size`` of the ``linklink`` group; return 1 when ``FakeLink`` is in use.

    .. note::
        Reference ``import_helper.try_import_link`` and ``linklink.get_world_size``.
    """
    if is_fake_link():
        return 1
    return error_wrapper(get_link().get_world_size, 1, "[WARNING]: call linklink error, return default_ret.")()


def broadcast(value: torch.Tensor, rank: int) -> None:
    """
    Overview:
        Call ``linklink.broadcast``; raise ``NotImplementedError`` when ``FakeLink`` is in use.
    Arguments:
        - value (:obj:`torch.Tensor`): the tensor to broadcast, modified in place on receiving ranks
        - rank (:obj:`int`): the rank of the source process to broadcast from
    """
    if is_fake_link():
        raise NotImplementedError
    get_link().broadcast(value, rank)
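

# Usage sketch (assumes a real linklink backend is initialized): rank 0 owns the
# values, every other rank receives them in place.
#
#     t = torch.zeros(4)
#     if get_rank() == 0:
#         t.fill_(1.0)
#     broadcast(t, 0)   # afterwards t == 1.0 on every rank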


def allreduce(data: torch.Tensor, op: str = 'sum') -> None:
    """
    Overview:
        Call ``linklink.allreduce`` on the data; with ``'sum'`` the result is divided by the
        world size, i.e. the tensor ends up holding the mean across processes.
    Arguments:
        - data (:obj:`torch.Tensor`): the tensor to reduce in place
        - op (:obj:`str`): the reduce operation, support ``['sum', 'max']``
    """
    link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max}
    if op not in link_op_map:
        raise KeyError("not support allreduce op type: {}".format(op))
    link_op = link_op_map[op]
    if is_fake_link():
        # Single (fake) process: nothing to reduce.
        return
    get_link().allreduce(data, reduce_op=link_op)
    if op == 'sum':
        # Divide by the world size so 'sum' yields the mean across processes.
        data.div_(get_world_size())
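

# Usage sketch (assumes a real linklink backend is initialized): with op='sum' the
# call both sums and averages, so it can be used directly for gradient averaging.
#
#     grad = torch.full((4, ), float(get_rank()))
#     allreduce(grad, op='sum')   # in place; grad now holds the mean over all ranks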


def allreduce_async(data: torch.Tensor, op: str = 'sum') -> None:
    """
    Overview:
        Call ``linklink.allreduce_async`` on the data; with ``'sum'`` the tensor is pre-divided
        by the world size so the asynchronous reduction yields the mean.
    Arguments:
        - data (:obj:`torch.Tensor`): the tensor to reduce in place
        - op (:obj:`str`): the reduce operation, support ``['sum', 'max']``
    """
    link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max}
    if op not in link_op_map:
        raise KeyError("not support allreduce op type: {}".format(op))
    link_op = link_op_map[op]
    if is_fake_link():
        return
    if op == 'sum':
        # Pre-divide: the reduction completes asynchronously, so average before launching it.
        data.div_(get_world_size())
    get_link().allreduce_async(data, reduce_op=link_op)


def get_group(group_size: Optional[int]) -> Any:
    """
    Overview:
        Split the world into groups of ``group_size`` ranks each and return the group that
        contains the current rank; ``None`` means one group spanning the whole world.
    Arguments:
        - group_size (:obj:`int`): the number of ranks in each group
    """
    rank = get_rank()
    world_size = get_world_size()
    if group_size is None:
        group_size = world_size
    assert world_size % group_size == 0
    return simple_group_split(world_size, rank, world_size // group_size)


def dist_mode(func: Callable) -> Callable:
    """
    Overview:
        Wrap a function so that ``dist_init`` runs before each call and ``dist_finalize`` after it.
    Arguments:
        - func (:obj:`Callable`): the function to wrap
    """

    def wrapper(*args, **kwargs):
        dist_init()
        func(*args, **kwargs)
        dist_finalize()

    return wrapper
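

# Usage sketch with a hypothetical entry point ``main``: the decorator brackets the
# call with dist_init()/dist_finalize().
#
#     @dist_mode
#     def main():
#         print('rank {} / world_size {}'.format(get_rank(), get_world_size()))
#
#     main()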


def dist_init(method: str = 'slurm', device_id: int = 0) -> Tuple[int, int]:
    """
    Overview:
        Initialize the ``linklink`` backend and bind this process to a CUDA device.
    Arguments:
        - method (:obj:`str`): the launch method, support ``['slurm', 'single_node']``
        - device_id (:obj:`int`): the device to use when ``method`` is ``'single_node'``
    Returns:
        - rank (:obj:`int`): the rank of the current process
        - world_size (:obj:`int`): the total number of processes
    """
    get_link().initialize()
    world_size = get_link().get_world_size()
    rank = get_link().get_rank()
    if method == 'slurm':
        # proc_id = int(os.environ['SLURM_PROCID'])
        # ntasks = int(os.environ['SLURM_NTASKS'])
        # node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        # Round-robin: pin each rank to one of the GPUs visible on its node.
        torch.cuda.set_device(rank % num_gpus)
    elif method == 'single_node':
        torch.cuda.set_device(device_id)
    return rank, world_size
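

# Usage sketches. Under SLURM (assuming linklink picks up the SLURM environment
# inside initialize()), each rank is pinned to GPU (rank % num_gpus) on its node:
#
#     rank, world_size = dist_init(method='slurm')
#
# On one machine, choose the device explicitly:
#
#     rank, world_size = dist_init(method='single_node', device_id=0)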


def dist_finalize() -> None:
    """
    Overview:
        Finalize ``linklink``, see ``linklink.finalize()``.
    """
    get_link().finalize()


class DistContext:
    """
    Overview:
        A context manager for ``linklink`` distribution.
    Interfaces:
        ``__init__``, ``__enter__``, ``__exit__``
    """

    def __init__(self) -> None:
        """
        Overview:
            Initialize the ``DistContext``; no state is needed.
        """
        pass

    def __enter__(self) -> None:
        """
        Overview:
            Initialize ``linklink`` distribution.
        """
        dist_init()

    def __exit__(self, *args, **kwargs) -> Any:
        """
        Overview:
            Finalize ``linklink`` distribution.
        Arguments:
            - args (:obj:`Tuple`): the arguments passed to the ``__exit__`` function.
            - kwargs (:obj:`Dict`): the keyword arguments passed to the ``__exit__`` function.
        """
        dist_finalize()
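

# Usage sketch: the same init/finalize bracket as ``dist_mode``, but as a block.
#
#     with DistContext():
#         t = torch.ones(4)
#         allreduce(t, op='sum')   # t now holds the mean over all ranks (here: 1.0)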


def simple_group_split(world_size: int, rank: int, num_groups: int) -> Any:
    """
    Overview:
        Split ``world_size`` ranks evenly into ``num_groups`` groups and return the
        ``linklink`` group that contains ``rank``.
    Arguments:
        - world_size (:obj:`int`): the total number of ranks
        - rank (:obj:`int`): the rank of the current process
        - num_groups (:obj:`int`): the number of groups

    .. note::
        If ``world_size`` is not divisible by ``num_groups``, ``np.split`` raises
        ``ValueError: array split does not result in an equal division``.
    """
    groups = []
    rank_list = np.split(np.arange(world_size), num_groups)
    rank_list = [list(map(int, x)) for x in rank_list]
    for i in range(num_groups):
        groups.append(get_link().new_group(rank_list[i]))
    group_size = world_size // num_groups
    return groups[rank // group_size]
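

# Worked example of the index arithmetic (the new_group calls still require a real
# backend): with world_size=8 and num_groups=2, np.split yields rank lists
# [0, 1, 2, 3] and [4, 5, 6, 7]; group_size is 8 // 2 = 4, so rank 5 returns
# groups[5 // 4] == groups[1], the group containing ranks 4-7.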


def synchronize() -> None:
    """
    Overview:
        Synchronize all processes, see ``linklink.synchronize()``.
    """
    get_link().synchronize()