from typing import Callable, Optional, Tuple, List, Any, Union

from easydict import EasyDict
import os
import numpy as np
import torch
import torch.distributed as dist

from .default_helper import error_wrapper

# from .slurm_helper import get_master_addr


def get_rank() -> int:
    """
    Overview:
        Get the rank of the current process in the total world_size.
    """
    # return int(os.environ.get('SLURM_PROCID', 0))
    return error_wrapper(dist.get_rank, 0)()


def get_world_size() -> int:
    """
    Overview:
        Get the world_size (the total number of processes in data-parallel training).
    """
    # return int(os.environ.get('SLURM_NTASKS', 1))
    return error_wrapper(dist.get_world_size, 1)()
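

# Usage sketch (illustrative, not part of the module): both helpers fall back
# to single-process defaults (rank 0, world_size 1) via ``error_wrapper`` when
# the default process group is not initialized, so they are safe to call in
# non-distributed runs:
#
#     rank, world_size = get_rank(), get_world_size()
#     print('process {} of {}'.format(rank, world_size))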

broadcast = dist.broadcast
allgather = dist.all_gather
broadcast_object_list = dist.broadcast_object_list


def allreduce(x: torch.Tensor) -> None:
    """
    Overview:
        All-reduce the tensor ``x`` in the world and average it by world_size (in-place).
    Arguments:
        - x (:obj:`torch.Tensor`): the tensor to be reduced
    """
    dist.all_reduce(x)
    x.div_(get_world_size())
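

# Usage sketch (illustrative; assumes the default process group is already
# initialized, e.g. via ``dist_init``, and that the tensor lives on the
# device required by the backend, i.e. CUDA for nccl):
#
#     grad = torch.ones(4).cuda()
#     allreduce(grad)  # in-place: sum across ranks, then divide by world_size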


def allreduce_async(name: str, x: torch.Tensor) -> None:
    """
    Overview:
        All-reduce the tensor ``x`` in the world asynchronously: ``x`` is first
        divided by world_size, then summed across ranks with ``async_op=True``.
    Arguments:
        - name (:obj:`str`): the name of the tensor (unused; kept for interface compatibility)
        - x (:obj:`torch.Tensor`): the tensor to be reduced
    """
    x.div_(get_world_size())
    dist.all_reduce(x, async_op=True)
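

# Usage sketch (illustrative): the async work handle is discarded by this
# helper, so ``x`` must not be read until the communication has finished. With
# the nccl backend one common way is a device synchronization; with gloo you
# would need the work handle, which this helper does not return:
#
#     allreduce_async('grad', grad)
#     synchronize()  # alias of torch.cuda.synchronize, defined below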


def reduce_data(x: Union[int, float, torch.Tensor], dst: int) -> Union[int, float, torch.Tensor]:
    """
    Overview:
        Reduce (sum) the data ``x`` to the destination process ``dst``; only the
        return value on rank ``dst`` holds the reduced result.
    Arguments:
        - x (:obj:`Union[int, float, torch.Tensor]`): the data to be reduced
        - dst (:obj:`int`): the rank of the destination process
    """
    if np.isscalar(x):
        x_tensor = torch.as_tensor([x]).cuda()
        dist.reduce(x_tensor, dst)
        return x_tensor.item()
    elif isinstance(x, torch.Tensor):
        dist.reduce(x, dst)
        return x
    else:
        raise TypeError("not supported type: {}".format(type(x)))
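

# Usage sketch (illustrative): gather a per-rank scalar onto rank 0; the
# value returned on the other ranks is not meaningful:
#
#     total = reduce_data(local_count, dst=0)
#     if get_rank() == 0:
#         print('sum over all ranks:', total)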


def allreduce_data(x: Union[int, float, torch.Tensor], op: str) -> Union[int, float, torch.Tensor]:
    """
    Overview:
        All-reduce the data ``x`` in the world, summing or averaging across ranks.
    Arguments:
        - x (:obj:`Union[int, float, torch.Tensor]`): the data to be reduced
        - op (:obj:`str`): the operation to perform on data, supports ``['sum', 'avg']``
    """
    assert op in ['sum', 'avg'], op
    if np.isscalar(x):
        x_tensor = torch.as_tensor([x]).cuda()
        dist.all_reduce(x_tensor)
        if op == 'avg':
            x_tensor.div_(get_world_size())
        return x_tensor.item()
    elif isinstance(x, torch.Tensor):
        dist.all_reduce(x)
        if op == 'avg':
            x.div_(get_world_size())
        return x
    else:
        raise TypeError("not supported type: {}".format(type(x)))
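

# Usage sketch (illustrative): average a per-rank metric so that every rank
# sees the same value:
#
#     mean_loss = allreduce_data(local_loss, 'avg')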

synchronize = torch.cuda.synchronize


def get_group(group_size: Optional[int]) -> Any:
    """
    Overview:
        Split the world into groups of ``group_size`` ranks each and return the
        group that the current rank belongs to.
    Arguments:
        - group_size (:obj:`Optional[int]`): the number of ranks per group; if ``None``,
          the whole world forms a single group
    """
    rank = get_rank()
    world_size = get_world_size()
    if group_size is None:
        group_size = world_size
    assert (world_size % group_size == 0)
    return simple_group_split(world_size, rank, world_size // group_size)


def dist_mode(func: Callable) -> Callable:
    """
    Overview:
        Wrap the function so that distributed training is initialized before each
        call and finalized after it.
    Arguments:
        - func (:obj:`Callable`): the function to be wrapped
    """

    def wrapper(*args, **kwargs):
        dist_init()
        ret = func(*args, **kwargs)
        dist_finalize()
        return ret

    return wrapper
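

# Usage sketch (illustrative):
#
#     @dist_mode
#     def main():
#         ...  # distributed training logic runs between dist_init and dist_finalize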


def dist_init(backend: str = 'nccl',
              addr: Optional[str] = None,
              port: Optional[str] = None,
              rank: Optional[int] = None,
              world_size: Optional[int] = None) -> Tuple[int, int]:
    """
    Overview:
        Initialize the distributed training setting.
    Arguments:
        - backend (:obj:`str`): The backend of the distributed training, supports ``['nccl', 'gloo']``
        - addr (:obj:`str`): The address of the master node
        - port (:obj:`str`): The port of the master node
        - rank (:obj:`int`): The rank of the current process
        - world_size (:obj:`int`): The total number of processes
    Returns:
        - rank (:obj:`int`): The rank of the current process
        - world_size (:obj:`int`): The total number of processes
    """
    assert backend in ['nccl', 'gloo'], backend
    os.environ['MASTER_ADDR'] = addr or os.environ.get('MASTER_ADDR', "localhost")
    os.environ['MASTER_PORT'] = port or os.environ.get('MASTER_PORT', "10314")  # hard-coded default port
    if rank is None:
        local_id = os.environ.get('SLURM_LOCALID', os.environ.get('RANK', None))
        if local_id is None:
            raise RuntimeError("please indicate rank explicitly in dist_init method")
        else:
            rank = int(local_id)
    if world_size is None:
        ntasks = os.environ.get('SLURM_NTASKS', os.environ.get('WORLD_SIZE', None))
        if ntasks is None:
            raise RuntimeError("please indicate world_size explicitly in dist_init method")
        else:
            world_size = int(ntasks)
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    num_gpus = torch.cuda.device_count()
    if num_gpus > 0:  # skip device binding on CPU-only (gloo) setups
        torch.cuda.set_device(rank % num_gpus)
    world_size = get_world_size()
    rank = get_rank()
    return rank, world_size
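

# Usage sketch (illustrative): launch with ``torchrun`` (or under SLURM) so
# that ``RANK``/``WORLD_SIZE`` (or ``SLURM_LOCALID``/``SLURM_NTASKS``) are set
# in the environment, then simply call:
#
#     rank, world_size = dist_init(backend='nccl')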


def dist_finalize() -> None:
    """
    Overview:
        Finalize distributed training resources.
    """
    # This operation usually hangs, so we skip it for now.
    # dist.destroy_process_group()
    pass


class DDPContext:
    """
    Overview:
        A context manager for DDP distribution, which initializes the distributed
        setting on enter and finalizes it on exit.
    Interfaces:
        ``__init__``, ``__enter__``, ``__exit__``
    """

    def __init__(self) -> None:
        """
        Overview:
            Initialize the ``DDPContext``.
        """
        pass

    def __enter__(self) -> None:
        """
        Overview:
            Initialize the distributed setting (see ``dist_init``).
        """
        dist_init()

    def __exit__(self, *args, **kwargs) -> Any:
        """
        Overview:
            Finalize the distributed setting (see ``dist_finalize``).
        """
        dist_finalize()
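

# Usage sketch (illustrative; ``train`` is a hypothetical entry point):
#
#     with DDPContext():
#         train(cfg)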


def simple_group_split(world_size: int, rank: int, num_groups: int) -> Any:
    """
    Overview:
        Split the world into ``num_groups`` equal groups and return the group
        that ``rank`` belongs to.
    Arguments:
        - world_size (:obj:`int`): The world size
        - rank (:obj:`int`): The rank
        - num_groups (:obj:`int`): The number of groups

    .. note::
        If ``world_size`` is not divisible by ``num_groups``, ``np.split`` raises
        ``ValueError: array split does not result in an equal division``.
    """
    groups = []
    rank_list = np.split(np.arange(world_size), num_groups)
    rank_list = [list(map(int, x)) for x in rank_list]
    for i in range(num_groups):
        groups.append(dist.new_group(rank_list[i]))
    group_size = world_size // num_groups
    return groups[rank // group_size]
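

# Usage sketch (illustrative): with world_size=8 and num_groups=2, ranks 0-3
# and ranks 4-7 form two groups; since ``dist.new_group`` is collective, every
# rank must make this call:
#
#     group = simple_group_split(8, get_rank(), 2)
#     dist.all_reduce(x, group=group)  # reduce only within the caller's group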


def to_ddp_config(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Convert the config to a DDP config by dividing the batch size and the
        collected sample/episode counts evenly among processes (rounding up).
    Arguments:
        - cfg (:obj:`EasyDict`): The config to be converted
    """
    w = get_world_size()
    if 'batch_size' in cfg.policy:
        cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
    if 'batch_size' in cfg.policy.learn:
        cfg.policy.learn.batch_size = int(np.ceil(cfg.policy.learn.batch_size / w))
    if 'n_sample' in cfg.policy.collect:
        cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample / w))
    if 'n_episode' in cfg.policy.collect:
        cfg.policy.collect.n_episode = int(np.ceil(cfg.policy.collect.n_episode / w))
    return cfg
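

# Usage sketch (illustrative): with world_size=4 and a global learn batch size
# of 256, each process ends up with a per-rank batch size of 64:
#
#     cfg = to_ddp_config(cfg)  # cfg.policy.learn.batch_size: 256 -> 64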