| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Literal, Optional, Union

import torch
from torch import distributed as dist


if TYPE_CHECKING:
    from torch.distributed import ProcessGroup
|
|
|
|
def all_gather(
    tensor: "torch.Tensor",
    world_size: int,
    group: Optional["ProcessGroup"] = None,
) -> "torch.Tensor":
    """
    Gathers the tensor from all ranks and concats them along the first dim.

    Args:
        tensor: The local tensor. Every rank in ``group`` must contribute a tensor with the
            same number of elements and the same dtype.
        world_size: The number of ranks in ``group``. It sizes the receive buffer, so it must
            equal the actual group size.
        group: The process group to gather over. ``None`` (the default) uses the default group.

    Returns:
        A tensor on ``tensor.device`` holding every rank's contribution in rank order, with
        shape ``(world_size * tensor.size(0), *tensor.shape[1:])``, or ``(world_size,)``
        when ``tensor`` is 0-d.
    """
    # Allocate the receive buffer on the input's own device. A bare "cuda" would resolve to the
    # *current* CUDA device, which is not necessarily where ``tensor`` lives, and would rule out
    # CPU backends entirely.
    output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
    # Collective backends expect a contiguous input; ``contiguous()`` is a no-op if it already is.
    dist.all_gather_into_tensor(output_tensor, tensor.contiguous(), group=group)
    # The buffer is flat; restore the trailing dims so ranks are stacked along dim 0.
    return output_tensor.view(-1, *tensor.size()[1:])
|
|
|
|
def all_reduce(
    data: Union[int, float, List[Union[int, float]], "torch.Tensor"],
    op: Literal["mean", "sum", "max"] = "mean",
    group: Optional["ProcessGroup"] = None,
) -> Union[int, float, List[Union[int, float]]]:
    """
    Performs all reduce in the given process group.

    Args:
        data: A Python scalar, a list of scalars, or a tensor. Non-tensor inputs are first
            converted to a ``torch.float`` tensor. Tensor inputs are reduced in place, so the
            caller's tensor is modified.
        op: ``"sum"`` and ``"max"`` map to the matching ``dist.ReduceOp``. ``"mean"`` is a SUM
            followed by division by the group's world size.
        group: The process group to reduce over. ``None`` uses the default group.

    Returns:
        A Python scalar when the result has exactly one element, otherwise ``tensor.tolist()``
        (a nested list for multi-dimensional tensors).

    Raises:
        KeyError: If ``op`` is not one of ``"mean"``, ``"sum"`` or ``"max"``.
    """
    if not isinstance(data, torch.Tensor):
        # Prefer CUDA, matching the previous hard-coded "cuda", but fall back to CPU so
        # CPU-only (e.g. gloo) process groups work too.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        data = torch.tensor(data, dtype=torch.float, device=device)

    reduce_ops = {"mean": dist.ReduceOp.SUM, "sum": dist.ReduceOp.SUM, "max": dist.ReduceOp.MAX}
    dist.all_reduce(data, op=reduce_ops[op], group=group)
    if op == "mean":
        world_size = dist.get_world_size(group=group)
        if data.is_floating_point() or data.is_complex():
            data /= world_size  # in place, as before
        else:
            # Integer/bool tensors cannot hold a true-division result in place; divide
            # out-of-place instead, which yields a floating-point tensor.
            data = data / world_size

    if data.numel() == 1:
        return data.item()
    else:
        return data.tolist()
|
|
|
|
@contextmanager
def main_process_first(local_only: bool = True) -> Iterator[None]:
    """
    A context manager for torch distributed environment to do something on the main process firstly.

    When ``WORLD_SIZE`` is greater than 1, every rank calls ``dist.barrier()`` exactly once:
    non-main ranks call it before entering the ``with`` body, and main ranks call it after
    leaving it. Non-main ranks therefore only start the body once every main rank has finished
    it. The main rank's barrier sits in ``finally``, so it is reached even if the body raises.
    With ``WORLD_SIZE`` unset or 1, the body simply runs.

    Args:
        local_only: If True, a rank is "main" when ``LOCAL_RANK == 0``; otherwise only the rank
            with ``RANK == 0`` is main.

    Raises:
        TypeError: If ``WORLD_SIZE > 1`` but the required ``LOCAL_RANK`` / ``RANK`` env var is unset.
    """
    if int(os.getenv("WORLD_SIZE", "1")) > 1:
        is_main_process = int(os.getenv("LOCAL_RANK")) == 0 if local_only else int(os.getenv("RANK")) == 0
        try:
            if not is_main_process:
                dist.barrier()  # wait for the main rank(s) to finish the body first
            yield
        finally:
            if is_main_process:
                dist.barrier()  # release the waiting non-main ranks
    else:
        yield
|
|
|
|
def execute_in_order(task: Callable, *, local_only: bool = True, **kwargs) -> Any:
    """
    Executes the task in the order of rank.

    With more than one process, ``task(**kwargs)`` runs on one rank per turn, in ascending rank
    order. Every rank calls ``dist.barrier()`` ``world_size + 1`` times: one before the first turn
    and one after each turn, whether or not it ran the task that turn. With a single process,
    the task is just called.

    Args:
        task: The callable to run. It receives only ``kwargs``.
        local_only: If True, order by ``LOCAL_RANK`` within ``LOCAL_WORLD_SIZE``; otherwise order
            by ``RANK`` within ``WORLD_SIZE``. Missing sizes default to 1 and missing ranks to 0.
        **kwargs: Forwarded to ``task``.

    Returns:
        Whatever ``task`` returned on this rank.

    Raises:
        RuntimeError: If this rank is outside ``[0, world_size)``, so the task was never run. It
            is raised only after all barriers, so the other ranks are not left waiting.
    """
    if local_only:
        world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
        rank = int(os.getenv("LOCAL_RANK", "0"))
    else:
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        rank = int(os.getenv("RANK", "0"))

    if world_size <= 1:
        return task(**kwargs)

    result: Any = None
    dist.barrier()
    for turn in range(world_size):
        if turn == rank:
            result = task(**kwargs)
        # Every rank must reach every barrier, otherwise the collective hangs.
        dist.barrier()

    if not 0 <= rank < world_size:
        raise RuntimeError(f"Rank {rank} is outside [0, {world_size}); the task was never executed.")

    return result
|
|