| # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/parallel_state.py | |
| # Copyright 2023 The vLLM team. | |
| # Adapted from | |
| # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py | |
| # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | |
| """Distributed state. | |
| It takes over the control of the distributed environment from PyTorch. | |
| The typical workflow is: | |
| - call `init_distributed_environment` to initialize the distributed environment. | |
| - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to | |
| initialize the model parallel groups. | |
| - any code dealing with the distributed stuff | |
| - call `destroy_model_parallel` to destroy the model parallel groups. | |
| - call `destroy_distributed_environment` to destroy the distributed environment. | |
| If you only need to use the distributed environment without model/pipeline | |
| parallelism, you can skip the model parallel initialization and destruction | |
| steps. | |
| """ | |
| import contextlib | |
| import gc | |
| import logging | |
| import os | |
| import pickle | |
| import weakref | |
| from collections import namedtuple | |
| from contextlib import contextmanager, nullcontext | |
| from dataclasses import dataclass | |
| from datetime import timedelta | |
| from multiprocessing import shared_memory | |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
| from unittest.mock import patch | |
| import torch | |
| import torch.distributed | |
| from torch.distributed import Backend, ProcessGroup | |
| from sglang.srt.environ import envs | |
| from sglang.srt.utils import ( | |
| direct_register_custom_op, | |
| get_bool_env_var, | |
| get_int_env_var, | |
| get_local_ip_auto, | |
| is_cpu, | |
| is_cuda_alike, | |
| is_hip, | |
| is_npu, | |
| is_shm_available, | |
| is_xpu, | |
| supports_custom_op, | |
| ) | |
| _is_npu = is_npu() | |
| _is_cpu = is_cpu() | |
| _is_xpu = is_xpu() | |
| _supports_custom_op = supports_custom_op() | |
| TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) | |
| # use int value instead of ReduceOp.SUM to support torch compile | |
| REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM) | |
| class GraphCaptureContext: | |
| stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream | |
| class P2PWork: | |
| work: Optional[torch.distributed.Work] | |
| payload: Optional[torch.Tensor] | |
| def _split_tensor_dict( | |
| tensor_dict: Dict[str, Union[torch.Tensor, Any]] | |
| ) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: | |
| """Split the tensor dictionary into two parts: | |
| 1. A list of (key, value) pairs. If the value is a tensor, it is replaced | |
| by its metadata. | |
| 2. A list of tensors. | |
| """ | |
| metadata_list: List[Tuple[str, Any]] = [] | |
| tensor_list: List[torch.Tensor] = [] | |
| for key, value in tensor_dict.items(): | |
| if isinstance(value, torch.Tensor): | |
| # Note: we cannot use `value.device` here, | |
| # because it contains not only the device type but also the device | |
| # index (e.g. "cuda:0"). We only need the device type. | |
| # receiving side will set the device index. | |
| device = value.device.type | |
| metadata_list.append( | |
| (key, TensorMetadata(device, value.dtype, value.size())) | |
| ) | |
| tensor_list.append(value) | |
| else: | |
| metadata_list.append((key, value)) | |
| return metadata_list, tensor_list | |
| _group_name_counter: Dict[str, int] = {} | |
| def _get_unique_name(name: str) -> str: | |
| """Get a unique name for the group. | |
| Example: | |
| _get_unique_name("tp") -> "tp:0" | |
| _get_unique_name("tp") -> "tp:1" | |
| """ | |
| if name not in _group_name_counter: | |
| _group_name_counter[name] = 0 | |
| newname = f"{name}:{_group_name_counter[name]}" | |
| _group_name_counter[name] += 1 | |
| return newname | |
| _groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} | |
| def _register_group(group: "GroupCoordinator") -> None: | |
| _groups[group.unique_name] = weakref.ref(group) | |
| if _supports_custom_op: | |
| def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: | |
| assert group_name in _groups, f"Group {group_name} is not found." | |
| group = _groups[group_name]() | |
| if group is None: | |
| raise ValueError(f"Group {group_name} is destroyed.") | |
| group._all_reduce_in_place(tensor) | |
| def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: | |
| return | |
| direct_register_custom_op( | |
| op_name="inplace_all_reduce", | |
| op_func=inplace_all_reduce, | |
| mutates_args=["tensor"], | |
| fake_impl=inplace_all_reduce_fake, | |
| ) | |
| def outplace_all_reduce( | |
| tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str | |
| ) -> torch.Tensor: | |
| assert group_name in _groups, f"Group {group_name} is not found." | |
| group = _groups[group_name]() | |
| if group is None: | |
| raise ValueError(f"Group {group_name} is destroyed.") | |
| return group._all_reduce_out_place(tensor, outplace_all_reduce_method) | |
| def outplace_all_reduce_fake( | |
| tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str | |
| ) -> torch.Tensor: | |
| return torch.empty_like(tensor) | |
| direct_register_custom_op( | |
| op_name="outplace_all_reduce", | |
| op_func=outplace_all_reduce, | |
| mutates_args=[], | |
| fake_impl=outplace_all_reduce_fake, | |
| ) | |
| def reg_all_gather_into_tensor( | |
| output: torch.Tensor, input: torch.Tensor, group_name: str | |
| ) -> None: | |
| assert group_name in _groups, f"Group {group_name} is not found." | |
| group = _groups[group_name]() | |
| if group is None: | |
| raise ValueError(f"Group {group_name} is destroyed.") | |
| group._all_gather_into_tensor(output, input) | |
| def reg_all_gather_into_tensor_fake( | |
| output: torch.Tensor, input: torch.Tensor, group_name: str | |
| ) -> None: | |
| pass | |
| direct_register_custom_op( | |
| op_name="reg_all_gather_into_tensor", | |
| op_func=reg_all_gather_into_tensor, | |
| mutates_args=["output"], | |
| fake_impl=reg_all_gather_into_tensor_fake, | |
| ) | |
| class GroupCoordinator: | |
| """ | |
| PyTorch ProcessGroup wrapper for a group of processes. | |
| PyTorch ProcessGroup is bound to one specific communication backend, | |
| e.g. NCCL, Gloo, MPI, etc. | |
| GroupCoordinator takes charge of all the communication operations among | |
| the processes in the group. It can route the communication to | |
| a specific implementation (e.g. switch allreduce implementation | |
| based on the tensor size and cuda graph mode). | |
| """ | |
| # available attributes: | |
| rank: int # global rank | |
| ranks: List[int] # global ranks in the group | |
| world_size: int # size of the group | |
| # difference between `local_rank` and `rank_in_group`: | |
| # if we have a group of size 4 across two nodes: | |
| # Process | Node | Rank | Local Rank | Rank in Group | |
| # 0 | 0 | 0 | 0 | 0 | |
| # 1 | 0 | 1 | 1 | 1 | |
| # 2 | 1 | 2 | 0 | 2 | |
| # 3 | 1 | 3 | 1 | 3 | |
| local_rank: int # local rank used to assign devices | |
| rank_in_group: int # rank inside the group | |
| cpu_group: ProcessGroup # group for CPU communication | |
| device_group: ProcessGroup # group for device communication | |
| use_pynccl: bool # a hint of whether to use PyNccl | |
| use_pymscclpp: bool # a hint of whether to use PyMsccl | |
| use_custom_allreduce: bool # a hint of whether to use CustomAllreduce | |
| use_torch_symm_mem: bool # a hint of whether to use SymmMemAllReduce | |
| use_message_queue_broadcaster: ( | |
| bool # a hint of whether to use message queue broadcaster | |
| ) | |
| # communicators are only created for world size > 1 | |
| pynccl_comm: Optional[Any] # PyNccl communicator | |
| ca_comm: Optional[Any] # Custom allreduce communicator | |
| symm_mem_comm: Optional[Any] # Symm mem communicator | |
| mq_broadcaster: Optional[Any] # shared memory broadcaster | |
| def __init__( | |
| self, | |
| group_ranks: List[List[int]], | |
| local_rank: int, | |
| torch_distributed_backend: Union[str, Backend], | |
| use_pynccl: bool, | |
| use_pymscclpp: bool, | |
| use_custom_allreduce: bool, | |
| use_torch_symm_mem: bool, | |
| use_hpu_communicator: bool, | |
| use_xpu_communicator: bool, | |
| use_npu_communicator: bool, | |
| use_message_queue_broadcaster: bool = False, | |
| group_name: Optional[str] = None, | |
| pynccl_use_current_stream: bool = False, | |
| torch_compile: Optional[bool] = None, | |
| gloo_timeout: timedelta = timedelta(seconds=120 * 60), | |
| ): | |
| # Set group info | |
| group_name = group_name or "anonymous" | |
| self.unique_name = _get_unique_name(group_name) | |
| _register_group(self) | |
| # Set rank info | |
| self.rank = torch.distributed.get_rank() | |
| self.local_rank = local_rank | |
| self.device_group = None | |
| self.cpu_group = None | |
| self.local_size = get_int_env_var("LOCAL_SIZE", 0) | |
| for ranks in group_ranks: | |
| device_group = torch.distributed.new_group( | |
| ranks, backend=torch_distributed_backend | |
| ) | |
| # a cpu_group to allow direct coordination between processes through | |
| # the CPU. The backend is chosen based on `torch_distributed_backend` | |
| if "mooncake" in torch_distributed_backend: | |
| cpu_group = torch.distributed.new_group(ranks, backend="mooncake-cpu") | |
| else: | |
| cpu_group = torch.distributed.new_group( | |
| ranks, backend="gloo", timeout=gloo_timeout | |
| ) | |
| if self.rank in ranks: | |
| self.ranks = ranks | |
| self.world_size = len(ranks) | |
| self.rank_in_group = ranks.index(self.rank) | |
| self.device_group = device_group | |
| self.cpu_group = cpu_group | |
| assert self.cpu_group is not None | |
| assert self.device_group is not None | |
| if is_cuda_alike(): | |
| device_id = ( | |
| 0 if envs.SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS.get() else local_rank | |
| ) | |
| self.device = torch.device(f"cuda:{device_id}") | |
| elif _is_npu: | |
| self.device = torch.device(f"npu:{local_rank}") | |
| else: | |
| self.device = torch.device("cpu") | |
| self.device_module = torch.get_device_module(self.device) | |
| # Import communicators | |
| self.use_pynccl = use_pynccl | |
| self.pynccl_use_current_stream = pynccl_use_current_stream | |
| self.use_pymscclpp = use_pymscclpp | |
| self.use_custom_allreduce = use_custom_allreduce | |
| self.use_torch_symm_mem = use_torch_symm_mem | |
| self.use_hpu_communicator = use_hpu_communicator | |
| self.use_xpu_communicator = use_xpu_communicator | |
| self.use_npu_communicator = use_npu_communicator | |
| self.use_message_queue_broadcaster = use_message_queue_broadcaster | |
| # Lazy import to avoid documentation build error | |
| from sglang.srt.distributed.device_communicators.custom_all_reduce import ( | |
| CustomAllreduce, | |
| ) | |
| from sglang.srt.distributed.device_communicators.pymscclpp import ( | |
| PyMscclppCommunicator, | |
| ) | |
| from sglang.srt.distributed.device_communicators.pynccl import ( | |
| PyNcclCommunicator, | |
| ) | |
| from sglang.srt.distributed.device_communicators.symm_mem import ( | |
| SymmMemCommunicator, | |
| ) | |
| if is_hip(): | |
| from sglang.srt.distributed.device_communicators.quick_all_reduce import ( | |
| QuickAllReduce, | |
| qr_rocm_arch_available, | |
| ) | |
| self.pynccl_comm: Optional[PyNcclCommunicator] = None | |
| if use_pynccl and self.world_size > 1: | |
| self.pynccl_comm = PyNcclCommunicator( | |
| group=self.cpu_group, | |
| device=self.device, | |
| use_current_stream=pynccl_use_current_stream, | |
| ) | |
| self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None | |
| if use_pymscclpp and self.world_size > 1: | |
| self.pymscclpp_comm = PyMscclppCommunicator( | |
| group=self.cpu_group, | |
| device=self.device, | |
| ) | |
| self.ca_comm: Optional[CustomAllreduce] = None | |
| self.qr_comm: Optional[QuickAllReduce] = None | |
| if use_custom_allreduce and self.world_size > 1: | |
| # Initialize a custom fast all-reduce implementation. | |
| if torch_compile is not None and torch_compile: | |
| # For piecewise CUDA graph, the requirement for custom allreduce is larger to | |
| # avoid illegal cuda memory access. | |
| ca_max_size = 256 * 1024 * 1024 | |
| else: | |
| ca_max_size = 8 * 1024 * 1024 | |
| try: | |
| self.ca_comm = CustomAllreduce( | |
| group=self.cpu_group, | |
| device=self.device, | |
| max_size=ca_max_size, | |
| ) | |
| except Exception as e: | |
| logger.warning( | |
| f"Setup Custom allreduce failed with {e}. To silence this " | |
| "warning, specify --disable-custom-all-reduce explicitly." | |
| ) | |
| if is_hip(): | |
| try: | |
| # Initialize a custom quick all-reduce implementation for AMD | |
| # when rocm >= gfx942. Quick reduce is designed as a | |
| # complement to custom allreduce. | |
| # Based on quickreduce (https://github.com/mk1-project/quickreduce). | |
| if qr_rocm_arch_available(): | |
| self.qr_comm = QuickAllReduce( | |
| group=self.cpu_group, device=self.device | |
| ) | |
| except Exception as e: | |
| logger.warning(f"Failed to initialize QuickAllReduce: {e}") | |
| self.symm_mem_comm: Optional[SymmMemCommunicator] = None | |
| if self.use_torch_symm_mem and self.world_size > 1: | |
| self.symm_mem_comm = SymmMemCommunicator( | |
| group=self.cpu_group, | |
| device=self.device, | |
| ) | |
| # Create communicator for other hardware backends | |
| from sglang.srt.distributed.device_communicators.hpu_communicator import ( | |
| HpuCommunicator, | |
| ) | |
| from sglang.srt.distributed.device_communicators.npu_communicator import ( | |
| NpuCommunicator, | |
| ) | |
| from sglang.srt.distributed.device_communicators.xpu_communicator import ( | |
| XpuCommunicator, | |
| ) | |
| self.hpu_communicator: Optional[HpuCommunicator] = None | |
| if use_hpu_communicator and self.world_size > 1: | |
| self.hpu_communicator = HpuCommunicator(group=self.device_group) | |
| self.xpu_communicator: Optional[XpuCommunicator] = None | |
| if use_xpu_communicator and self.world_size > 1: | |
| self.xpu_communicator = XpuCommunicator(group=self.device_group) | |
| self.npu_communicator: Optional[NpuCommunicator] = None | |
| if use_npu_communicator and self.world_size > 1: | |
| self.npu_communicator = NpuCommunicator(group=self.device_group) | |
| # Create message queue | |
| from sglang.srt.distributed.device_communicators.shm_broadcast import ( | |
| MessageQueue, | |
| ) | |
| self.mq_broadcaster: Optional[MessageQueue] = None | |
| if use_message_queue_broadcaster and self.world_size > 1: | |
| self.mq_broadcaster = MessageQueue.create_from_process_group( | |
| self.cpu_group, 1 << 22, 6 | |
| ) | |
| def __repr__(self): | |
| return ( | |
| f"ranks={self.ranks} rank={self.rank} local_rank={self.local_rank} use_pynccl={self.use_pynccl} " | |
| f"device_group={self.device_group} cpu_group={self.cpu_group} unique_name={self.unique_name} " | |
| f"world_size={self.world_size} rank_in_group={self.rank_in_group}" | |
| ) | |
| def first_rank(self): | |
| """Return the global rank of the first process in the group""" | |
| return self.ranks[0] | |
| def last_rank(self): | |
| """Return the global rank of the last process in the group""" | |
| return self.ranks[-1] | |
| def is_first_rank(self): | |
| """Return whether the caller is the first process in the group""" | |
| return self.rank == self.first_rank | |
| def is_last_rank(self): | |
| """Return whether the caller is the last process in the group""" | |
| return self.rank == self.last_rank | |
| def next_rank(self): | |
| """Return the global rank of the process that follows the caller""" | |
| rank_in_group = self.rank_in_group | |
| world_size = self.world_size | |
| return self.ranks[(rank_in_group + 1) % world_size] | |
| def prev_rank(self): | |
| """Return the global rank of the process that precedes the caller""" | |
| rank_in_group = self.rank_in_group | |
| world_size = self.world_size | |
| return self.ranks[(rank_in_group - 1) % world_size] | |
| def graph_capture( | |
| self, | |
| graph_capture_context: Optional[GraphCaptureContext] = None, | |
| stream: Optional[torch.cuda.Stream] = None, | |
| ): | |
| if graph_capture_context is None: | |
| if stream is None: | |
| stream = self.device_module.Stream() | |
| graph_capture_context = GraphCaptureContext(stream) | |
| else: | |
| stream = graph_capture_context.stream | |
| # We don't need the context of custom quick allreduce because the ipc access | |
| # is already collected in init() and we can capture the quick allreduce directly. | |
| ca_comm = self.ca_comm | |
| maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() | |
| # ensure all initialization operations complete before attempting to | |
| # capture the graph on another stream | |
| curr_stream = self.device_module.current_stream() | |
| if curr_stream != stream: | |
| stream.wait_stream(curr_stream) | |
| with self.device_module.stream(stream), maybe_ca_context: | |
| # In graph mode, we have to be very careful about the collective | |
| # operations. The current status is: | |
| # allreduce \ Mode | Eager | Graph | | |
| # -------------------------------------------- | |
| # quick allreduce | enabled | enabled | | |
| # custom allreduce | enabled | enabled | | |
| # PyNccl | disabled| enabled | | |
| # PyMscclpp | disabled| enabled | | |
| # TorchSymmMem | disabled| enabled | | |
| # torch.distributed | enabled | disabled| | |
| # | |
| # Note: When custom quick allreduce is enabled, a runtime check | |
| # will be performed. If the tensor size is too small, it will | |
| # automatically fall back to the next available option. | |
| # Note that custom allreduce will have a runtime check, if the | |
| # tensor size is too large, it will fallback to the next | |
| # available option. | |
| # Note that the PyMsccl needs to register the tensor in ahead, | |
| # which will introduce large overhead in the eager case, | |
| # therefore it is only supported in the graph case. | |
| # In summary: We select the appropriate allreduce method for | |
| # each mode based on the algorithm order in the table and | |
| # their usage conditions. | |
| pynccl_comm = self.pynccl_comm | |
| maybe_pynccl_context: Any | |
| if not pynccl_comm: | |
| maybe_pynccl_context = nullcontext() | |
| else: | |
| maybe_pynccl_context = pynccl_comm.change_state( | |
| enable=True, stream=torch.cuda.current_stream() | |
| ) | |
| pymscclpp_comm = self.pymscclpp_comm | |
| maybe_pymscclpp_context: Any | |
| if not pymscclpp_comm: | |
| maybe_pymscclpp_context = nullcontext() | |
| else: | |
| maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True) | |
| with maybe_pynccl_context, maybe_pymscclpp_context: | |
| yield graph_capture_context | |
| def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: | |
| """ | |
| User-facing all-reduce function before we actually call the | |
| all-reduce operation. | |
| We need this because Dynamo does not support passing an arbitrary | |
| object (`self` in this case) to a custom op. We need to pass the | |
| group name as a string, and then look up the group coordinator from | |
| the group name, dispatch the all-reduce operation to the group | |
| coordinator. | |
| In addition, PyTorch custom ops do not support mutation or returning | |
| a new tensor in the same op. So we need to figure out if the op is | |
| in-place or out-of-place ahead of time. | |
| """ | |
| # Bypass the function if we are using only 1 GPU. | |
| if self.world_size == 1: | |
| return input_ | |
| if input_.is_cpu: | |
| if is_shm_available(input_.dtype, self.world_size, self.local_size): | |
| torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM) | |
| else: | |
| torch.distributed.all_reduce(input_, group=self.device_group) | |
| return input_ | |
| if not _supports_custom_op: | |
| self._all_reduce_in_place(input_) | |
| return input_ | |
| if self.hpu_communicator is not None and not self.hpu_communicator.disabled: | |
| return self.hpu_communicator.all_reduce(input_) | |
| if self.xpu_communicator is not None and not self.xpu_communicator.disabled: | |
| return self.xpu_communicator.all_reduce(input_) | |
| if self.npu_communicator is not None and not self.npu_communicator.disabled: | |
| return self.npu_communicator.all_reduce(input_) | |
| if ( | |
| self.pynccl_comm is not None | |
| and hasattr(input_, "symmetric_memory") | |
| and input_.symmetric_memory | |
| ): | |
| with self.pynccl_comm.change_state( | |
| enable=True, stream=torch.cuda.current_stream() | |
| ): | |
| self.pynccl_comm.all_reduce(input_) | |
| return input_ | |
| outplace_all_reduce_method = None | |
| if ( | |
| self.ca_comm is not None | |
| and not self.ca_comm.disabled | |
| and self.ca_comm.should_custom_ar(input_) | |
| ): | |
| outplace_all_reduce_method = "ca" | |
| elif ( | |
| self.qr_comm is not None | |
| and not self.qr_comm.disabled | |
| and self.qr_comm.should_quick_allreduce(input_) | |
| ): | |
| outplace_all_reduce_method = "qr" | |
| elif ( | |
| self.pymscclpp_comm is not None | |
| and not self.pymscclpp_comm.disabled | |
| and self.pymscclpp_comm.should_mscclpp_allreduce(input_) | |
| ): | |
| outplace_all_reduce_method = "pymscclpp" | |
| elif ( | |
| self.symm_mem_comm is not None | |
| and not self.symm_mem_comm.disabled | |
| and self.symm_mem_comm.should_symm_mem_allreduce(input_) | |
| ): | |
| outplace_all_reduce_method = "symm_mem" | |
| if outplace_all_reduce_method is not None: | |
| return torch.ops.sglang.outplace_all_reduce( | |
| input_, | |
| group_name=self.unique_name, | |
| outplace_all_reduce_method=outplace_all_reduce_method, | |
| ) | |
| else: | |
| torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name) | |
| return input_ | |
| def _all_reduce_out_place( | |
| self, input_: torch.Tensor, outplace_all_reduce_method: str | |
| ) -> torch.Tensor: | |
| ca_comm = self.ca_comm | |
| qr_comm = self.qr_comm | |
| pymscclpp_comm = self.pymscclpp_comm | |
| symm_mem_comm = self.symm_mem_comm | |
| assert any([qr_comm, ca_comm, pymscclpp_comm]) | |
| if outplace_all_reduce_method == "ca": | |
| assert not ca_comm.disabled | |
| out = ca_comm.custom_all_reduce(input_) | |
| elif outplace_all_reduce_method == "qr": | |
| assert not qr_comm.disabled | |
| out = qr_comm.quick_all_reduce(input_) | |
| elif outplace_all_reduce_method == "symm_mem": | |
| assert not symm_mem_comm.disabled | |
| out = symm_mem_comm.all_reduce(input_) | |
| else: | |
| assert not pymscclpp_comm.disabled | |
| out = pymscclpp_comm.all_reduce(input_) | |
| assert out is not None | |
| return out | |
| def _all_reduce_in_place(self, input_: torch.Tensor) -> None: | |
| pynccl_comm = self.pynccl_comm | |
| symm_mem_comm = self.symm_mem_comm | |
| if pynccl_comm is not None and not pynccl_comm.disabled: | |
| pynccl_comm.all_reduce(input_) | |
| elif symm_mem_comm is not None and not symm_mem_comm.disabled: | |
| symm_mem_comm.all_reduce(input_) | |
| else: | |
| torch.distributed.all_reduce(input_, group=self.device_group) | |
| def reduce_scatter_tensor( | |
| self, | |
| output: torch.Tensor, | |
| input: torch.Tensor, | |
| ) -> None: | |
| # TODO(ch-wan): support other backends | |
| torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group) | |
| return output | |
| def reduce_scatter( | |
| self, | |
| output: torch.Tensor, | |
| input_list: List[torch.Tensor], | |
| ) -> None: | |
| # TODO(ch-wan): support other backends | |
| torch.distributed.reduce_scatter(output, input_list, group=self.device_group) | |
| return output | |
| def reduce_scatterv( | |
| self, | |
| input_: torch.Tensor, | |
| output: Optional[torch.Tensor] = None, | |
| sizes: Optional[List[int]] = None, | |
| ) -> torch.Tensor: | |
| world_size = self.world_size | |
| pynccl_comm = self.pynccl_comm | |
| with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()): | |
| assert ( | |
| pynccl_comm is not None and not pynccl_comm.disabled | |
| ), "pynccl is required for reduce_scatterv" | |
| if sizes is not None: | |
| assert len(sizes) == world_size | |
| assert input_.shape[0] == sum(sizes) | |
| chunk_size = sizes[self.rank_in_group] | |
| else: | |
| assert input_.shape[0] % world_size == 0 | |
| chunk_size = input_.shape[0] // world_size | |
| output_shape = (chunk_size,) + input_.shape[1:] | |
| if output is None: | |
| output = torch.empty( | |
| output_shape, dtype=input_.dtype, device=input_.device | |
| ) | |
| else: | |
| assert output.shape == output_shape | |
| pynccl_comm.reduce_scatter(output, input_, sizes=sizes) | |
| return output | |
| def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): | |
| pynccl_comm = self.pynccl_comm | |
| if pynccl_comm is not None and not pynccl_comm.disabled: | |
| pynccl_comm.all_gather(output, input) | |
| else: | |
| torch.distributed.all_gather_into_tensor( | |
| output, input, group=self.device_group | |
| ) | |
| def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): | |
| if _is_npu or _is_xpu or not _supports_custom_op: | |
| self._all_gather_into_tensor(output, input) | |
| else: | |
| torch.ops.sglang.reg_all_gather_into_tensor( | |
| output, input, group_name=self.unique_name | |
| ) | |
| def all_gather( | |
| self, | |
| input_: torch.Tensor, | |
| dim: int = -1, | |
| output_tensor_list: Optional[List[torch.Tensor]] = None, | |
| ) -> torch.Tensor: | |
| world_size = self.world_size | |
| # Bypass the function if we are using only 1 GPU. | |
| if world_size == 1: | |
| if output_tensor_list is not None: | |
| logger.warning( | |
| "Performing in-place all-gather with a group size of 1. " | |
| "This may be unnecessary; consider bypassing it for better efficiency." | |
| ) | |
| output_tensor_list[0].copy_(input_) | |
| return None | |
| else: | |
| return input_ | |
| if output_tensor_list is not None: | |
| # TODO(ch-wan): support other backends | |
| return torch.distributed.all_gather( | |
| output_tensor_list, input_, group=self.device_group | |
| ) | |
| assert ( | |
| -input_.dim() <= dim < input_.dim() | |
| ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" | |
| # For HPUs, use HPU communicator. | |
| hpu_comm = self.hpu_communicator | |
| if hpu_comm is not None and not hpu_comm.disabled: | |
| return hpu_comm.all_gather(input_, dim) | |
| # For NPUs, use NPU communicator. | |
| npu_comm = self.npu_communicator | |
| if npu_comm is not None and not npu_comm.disabled: | |
| return npu_comm.all_gather(input_, dim) | |
| if dim < 0: | |
| # Convert negative dim to positive. | |
| dim += input_.dim() | |
| input_size = input_.size() | |
| # NOTE: we have to use concat-style all-gather here, | |
| # stack-style all-gather has compatibility issues with | |
| # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 | |
| output_size = (input_size[0] * world_size,) + input_size[1:] | |
| # Allocate output tensor. | |
| output_tensor = torch.empty( | |
| output_size, dtype=input_.dtype, device=input_.device | |
| ) | |
| # All-gather. | |
| if input_.is_cpu: | |
| if is_shm_available(input_.dtype, self.world_size, self.local_size): | |
| return torch.ops.sgl_kernel.shm_allgather(input_, dim) | |
| else: | |
| torch.distributed.all_gather_into_tensor( | |
| output_tensor, input_, group=self.device_group | |
| ) | |
| else: | |
| self.all_gather_into_tensor(output_tensor, input_) | |
| # Reshape | |
| output_tensor = output_tensor.reshape((world_size,) + input_size) | |
| output_tensor = output_tensor.movedim(0, dim) | |
| output_tensor = output_tensor.reshape( | |
| input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] | |
| ) | |
| return output_tensor | |
| def all_gatherv( | |
| self, | |
| input_: Union[torch.Tensor, List[torch.Tensor]], | |
| sizes: Optional[List[int]] = None, | |
| ) -> Union[torch.Tensor, List[torch.Tensor]]: | |
| """ | |
| Supports varying sizes per rank and input tensor list. | |
| `sizes`: a list of len(world_size) with the number of items per rank to gather. | |
| """ | |
| world_size = self.world_size | |
| pynccl_comm = self.pynccl_comm | |
| with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()): | |
| assert ( | |
| pynccl_comm is not None and not pynccl_comm.disabled | |
| ), "pynccl is required for all_gatherv" | |
| def _all_gather_single( | |
| input_: torch.Tensor, sizes: Optional[List[int]] = None | |
| ): | |
| input_size = input_.size() | |
| if sizes is not None: | |
| assert len(sizes) == world_size | |
| assert input_.shape[0] == sizes[self.rank_in_group] | |
| output_size = (sum(sizes),) + input_size[1:] | |
| # 'sizes' is not needed if all inputs in the same group have the same shape | |
| if all(s == sizes[0] for s in sizes): | |
| sizes = None | |
| else: | |
| output_size = (input_size[0] * world_size,) + input_size[1:] | |
| # Allocate output tensor. | |
| output_tensor = torch.empty( | |
| output_size, dtype=input_.dtype, device=input_.device | |
| ) | |
| pynccl_comm.all_gather(output_tensor, input_, sizes=sizes) | |
| return output_tensor | |
| if isinstance(input_, torch.Tensor): | |
| return _all_gather_single(input_, sizes) | |
| output_list = [] | |
| pynccl_comm.group_start() | |
| for inp in input_: | |
| output_list.append(_all_gather_single(inp, sizes=sizes)) | |
| pynccl_comm.group_end() | |
| return output_list | |
| def gather( | |
| self, input_: torch.Tensor, dst: int = 0, dim: int = -1 | |
| ) -> Optional[torch.Tensor]: | |
| """ | |
| NOTE: We assume that the input tensor is on the same device across | |
| all the ranks. | |
| NOTE: `dst` is the local rank of the destination rank. | |
| """ | |
| world_size = self.world_size | |
| # Bypass the function if we are using only 1 GPU. | |
| if world_size == 1: | |
| return input_ | |
| assert ( | |
| -input_.dim() <= dim < input_.dim() | |
| ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" | |
| if dim < 0: | |
| # Convert negative dim to positive. | |
| dim += input_.dim() | |
| if self.xpu_communicator is not None and not self.xpu_communicator.disabled: | |
| return self.xpu_communicator.gather(input_, self.rank_in_group, dst, dim) | |
| # Allocate output tensor. | |
| if self.rank_in_group == dst: | |
| gather_list = [torch.empty_like(input_) for _ in range(world_size)] | |
| else: | |
| gather_list = None | |
| # Gather. | |
| torch.distributed.gather( | |
| input_, gather_list, dst=self.ranks[dst], group=self.device_group | |
| ) | |
| if self.rank_in_group == dst: | |
| output_tensor = torch.cat(gather_list, dim=dim) | |
| else: | |
| output_tensor = None | |
| return output_tensor | |
| def broadcast(self, input_: torch.Tensor, src: int = 0): | |
| """Broadcast the input tensor. | |
| NOTE: `src` is the local rank of the source rank. | |
| """ | |
| assert src < self.world_size, f"Invalid src rank ({src})" | |
| # Bypass the function if we are using only 1 GPU. | |
| if self.world_size == 1: | |
| return input_ | |
| # Broadcast. | |
| torch.distributed.broadcast( | |
| input_, src=self.ranks[src], group=self.device_group | |
| ) | |
| return input_ | |
| def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): | |
| """Broadcast the input object. | |
| NOTE: `src` is the local rank of the source rank. | |
| """ | |
| assert src < self.world_size, f"Invalid src rank ({src})" | |
| # Bypass the function if we are using only 1 GPU. | |
| if self.world_size == 1: | |
| return obj | |
| if self.mq_broadcaster is not None: | |
| assert src == 0, "Message queue broadcaster only supports src=0" | |
| return self.mq_broadcaster.broadcast_object(obj) | |
| if self.rank_in_group == src: | |
| torch.distributed.broadcast_object_list( | |
| [obj], src=self.ranks[src], group=self.cpu_group | |
| ) | |
| return obj | |
| else: | |
| recv = [None] | |
| torch.distributed.broadcast_object_list( | |
| recv, src=self.ranks[src], group=self.cpu_group | |
| ) | |
| return recv[0] | |
| def broadcast_object_list( | |
| self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None | |
| ): | |
| """Broadcast the input object list. | |
| NOTE: `src` is the local rank of the source rank. | |
| """ | |
| assert src < self.world_size, f"Invalid src rank ({src})" | |
| # Bypass the function if we are using only 1 GPU. | |
| if self.world_size == 1: | |
| return obj_list | |
| # Broadcast. | |
| torch.distributed.broadcast_object_list( | |
| obj_list, src=self.ranks[src], group=self.device_group | |
| ) | |
| return obj_list | |
| def all_gather_object(self, obj: Any) -> List[Any]: | |
| objs = [None] * self.world_size | |
| torch.distributed.all_gather_object(objs, obj, group=self.cpu_group) | |
| return objs | |
| def send_object( | |
| self, | |
| obj: Any, | |
| dst: int, | |
| async_send: bool = False, | |
| ) -> List[P2PWork]: | |
| """ | |
| Send the input object list to the destination rank. | |
| This function uses the CPU group for all communications. | |
| TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group), | |
| use other functions (e.g., send), or implement a new function (e.g., send_object_device). | |
| NOTE: `dst` is the local rank of the destination rank. | |
| """ | |
| assert dst < self.world_size, f"Invalid dst rank ({dst})" | |
| assert dst != self.rank_in_group, ( | |
| "Invalid destination rank. Destination rank is the same " | |
| "as the current rank." | |
| ) | |
| send_func = torch.distributed.isend if async_send else torch.distributed.send | |
| # Serialize object to tensor and get the size as well | |
| object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) | |
| size_tensor = torch.tensor( | |
| [object_tensor.numel()], dtype=torch.long, device="cpu" | |
| ) | |
| # Send object size | |
| p2p_work = [] | |
| size_work = send_func( | |
| size_tensor, | |
| self.ranks[dst], | |
| group=self.cpu_group, | |
| ) | |
| if async_send: | |
| p2p_work.append(P2PWork(size_work, size_tensor)) | |
| object_work = send_func( | |
| object_tensor, | |
| self.ranks[dst], | |
| group=self.cpu_group, | |
| ) | |
| if async_send: | |
| p2p_work.append(P2PWork(object_work, object_tensor)) | |
| return p2p_work | |
| def recv_object( | |
| self, | |
| src: int, | |
| ) -> Any: | |
| """Receive the input object list from the source rank.""" | |
| """NOTE: `src` is the local rank of the source rank.""" | |
| assert src < self.world_size, f"Invalid src rank ({src})" | |
| assert ( | |
| src != self.rank_in_group | |
| ), "Invalid source rank. Source rank is the same as the current rank." | |
| size_tensor = torch.empty(1, dtype=torch.long, device="cpu") | |
| # Receive object size | |
| # We have to use irecv here to make it work for both isend and send. | |
| work = torch.distributed.irecv( | |
| size_tensor, src=self.ranks[src], group=self.cpu_group | |
| ) | |
| work.wait() | |
| # Tensor to receive serialized objects into. | |
| object_tensor: Any = torch.empty( # type: ignore[call-overload] | |
| size_tensor.item(), # type: ignore[arg-type] | |
| dtype=torch.uint8, | |
| device="cpu", | |
| ) | |
| work = torch.distributed.irecv( | |
| object_tensor, src=self.ranks[src], group=self.cpu_group | |
| ) | |
| work.wait() | |
| obj = pickle.loads(object_tensor.numpy()) | |
| return obj | |
| def broadcast_tensor_dict( | |
| self, | |
| tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, | |
| src: int = 0, | |
| group: Optional[ProcessGroup] = None, | |
| metadata_group: Optional[ProcessGroup] = None, | |
| ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: | |
| """Broadcast the input tensor dictionary. | |
| NOTE: `src` is the local rank of the source rank. | |
| """ | |
| # Bypass the function if we are using only 1 GPU. | |
| if not torch.distributed.is_initialized() or self.world_size == 1: | |
| return tensor_dict | |
| group = self.device_group | |
| metadata_group = self.cpu_group | |
| assert src < self.world_size, f"Invalid src rank ({src})" | |
| rank_in_group = self.rank_in_group | |
| if rank_in_group == src: | |
| metadata_list: List[Tuple[Any, Any]] = [] | |
| assert isinstance( | |
| tensor_dict, dict | |
| ), f"Expecting a dictionary, got {type(tensor_dict)}" | |
| metadata_list, tensor_list = _split_tensor_dict(tensor_dict) | |
| # `metadata_list` lives in CPU memory. | |
| # `broadcast_object_list` has serialization & deserialization, | |
| # all happening on CPU. Therefore, we can use the CPU group. | |
| self.broadcast_object(metadata_list, src=src) | |
| async_handles = [] | |
| for tensor in tensor_list: | |
| if tensor.numel() == 0: | |
| # Skip broadcasting empty tensors. | |
| continue | |
| if tensor.is_cpu: | |
| # use metadata_group for CPU tensors | |
| handle = torch.distributed.broadcast( | |
| tensor, src=self.ranks[src], group=metadata_group, async_op=True | |
| ) | |
| else: | |
| # use group for GPU tensors | |
| handle = torch.distributed.broadcast( | |
| tensor, src=self.ranks[src], group=group, async_op=True | |
| ) | |
| async_handles.append(handle) | |
| for async_handle in async_handles: | |
| async_handle.wait() | |
| else: | |
| metadata_list = self.broadcast_object(None, src=src) | |
| tensor_dict = {} | |
| async_handles = [] | |
| for key, value in metadata_list: | |
| if isinstance(value, TensorMetadata): | |
| tensor = torch.empty( | |
| value.size, dtype=value.dtype, device=value.device | |
| ) | |
| if tensor.numel() == 0: | |
| # Skip broadcasting empty tensors. | |
| tensor_dict[key] = tensor | |
| continue | |
| if tensor.is_cpu: | |
| # use metadata_group for CPU tensors | |
| handle = torch.distributed.broadcast( | |
| tensor, | |
| src=self.ranks[src], | |
| group=metadata_group, | |
| async_op=True, | |
| ) | |
| else: | |
| # use group for GPU tensors | |
| handle = torch.distributed.broadcast( | |
| tensor, src=self.ranks[src], group=group, async_op=True | |
| ) | |
| async_handles.append(handle) | |
| tensor_dict[key] = tensor | |
| else: | |
| tensor_dict[key] = value | |
| for async_handle in async_handles: | |
| async_handle.wait() | |
| return tensor_dict | |
| def send_tensor_dict( | |
| self, | |
| tensor_dict: Dict[str, Union[torch.Tensor, Any]], | |
| dst: Optional[int] = None, | |
| all_gather_group: Optional["GroupCoordinator"] = None, | |
| async_send: bool = False, | |
| ) -> Optional[List[P2PWork]]: | |
| """Send the input tensor dictionary. | |
| NOTE: `dst` is the local rank of the source rank. | |
| """ | |
| # Bypass the function if we are using only 1 GPU. | |
| if self.world_size == 1: | |
| return tensor_dict | |
| all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size | |
| all_gather_rank = ( | |
| 0 if all_gather_group is None else all_gather_group.rank_in_group | |
| ) | |
| group = self.device_group | |
| metadata_group = self.cpu_group | |
| if dst is None: | |
| dst = (self.rank_in_group + 1) % self.world_size | |
| assert dst < self.world_size, f"Invalid dst rank ({dst})" | |
| assert isinstance( | |
| tensor_dict, dict | |
| ), f"Expecting a dictionary, got {type(tensor_dict)}" | |
| metadata_list, tensor_list = _split_tensor_dict(tensor_dict) | |
| # Note: While switching to Device-to-Device (D2D) would introduce an extra | |
| # Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks | |
| # show better overall transmission performance with D2D due to: | |
| # 1. Superior D2D transfer bandwidth | |
| # 2. Ability to overlap send and recv operations | |
| # Thus the net performance gain justifies this approach. | |
| send_func = torch.distributed.isend if async_send else torch.distributed.send | |
| p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send) | |
| for tensor in tensor_list: | |
| if tensor.numel() == 0: | |
| # Skip sending empty tensors. | |
| continue | |
| # send-allgather: send only a slice, then do allgather. | |
| if all_gather_group is not None and tensor.numel() % all_gather_size == 0: | |
| tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] | |
| comm_group = metadata_group if tensor.is_cpu else group | |
| work = send_func(tensor, self.ranks[dst], group=comm_group) | |
| if async_send: | |
| p2p_works.append(P2PWork(work, tensor)) | |
| return p2p_works | |
| def recv_tensor_dict( | |
| self, | |
| src: Optional[int] = None, | |
| all_gather_group: Optional["GroupCoordinator"] = None, | |
| ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: | |
| """Recv the input tensor dictionary. | |
| NOTE: `src` is the local rank of the source rank. | |
| """ | |
| # Bypass the function if we are using only 1 GPU. | |
| if not torch.distributed.is_initialized() or self.world_size == 1: | |
| return None | |
| all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size | |
| all_gather_rank = ( | |
| 0 if all_gather_group is None else all_gather_group.rank_in_group | |
| ) | |
| group = self.device_group | |
| metadata_group = self.cpu_group | |
| if src is None: | |
| src = (self.rank_in_group - 1) % self.world_size | |
| assert src < self.world_size, f"Invalid src rank ({src})" | |
| recv_metadata_list = self.recv_object(src=src) | |
| tensor_dict: Dict[str, Any] = {} | |
| for key, value in recv_metadata_list: | |
| if isinstance(value, TensorMetadata): | |
| tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) | |
| if tensor.numel() == 0: | |
| # Skip broadcasting empty tensors. | |
| tensor_dict[key] = tensor | |
| continue | |
| # send-allgather: send only a slice, then do allgather. | |
| use_all_gather = ( | |
| all_gather_group is not None | |
| and tensor.numel() % all_gather_size == 0 | |
| ) | |
| if use_all_gather: | |
| orig_shape = tensor.shape | |
| tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] | |
| # We have to use irecv here to make it work for both isend and send. | |
| comm_group = metadata_group if tensor.is_cpu else group | |
| work = torch.distributed.irecv( | |
| tensor, src=self.ranks[src], group=comm_group | |
| ) | |
| work.wait() | |
| if use_all_gather: | |
| tensor = all_gather_group.all_gather(tensor, dim=0) | |
| tensor = tensor.reshape(orig_shape) | |
| tensor_dict[key] = tensor | |
| else: | |
| tensor_dict[key] = value | |
| return tensor_dict | |
| def barrier(self): | |
| """Barrier synchronization among the group. | |
| NOTE: don't use `device_group` here! `barrier` in NCCL is | |
| terrible because it is internally a broadcast operation with | |
| secretly created GPU tensors. It is easy to mess up the current | |
| device. Use the CPU group instead. | |
| """ | |
| torch.distributed.barrier(group=self.cpu_group) | |
| def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: | |
| """Sends a tensor to the destination rank in a non-blocking way""" | |
| """NOTE: `dst` is the local rank of the destination rank.""" | |
| if dst is None: | |
| dst = (self.rank_in_group + 1) % self.world_size | |
| pynccl_comm = self.pynccl_comm | |
| if pynccl_comm is not None and not pynccl_comm.disabled: | |
| pynccl_comm.send(tensor, dst) | |
| else: | |
| torch.distributed.send(tensor, self.ranks[dst], self.device_group) | |
| def recv( | |
| self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None | |
| ) -> torch.Tensor: | |
| """Receives a tensor from the source rank.""" | |
| """NOTE: `src` is the local rank of the source rank.""" | |
| if src is None: | |
| src = (self.rank_in_group - 1) % self.world_size | |
| tensor = torch.empty(size, dtype=dtype, device=self.device) | |
| pynccl_comm = self.pynccl_comm | |
| if pynccl_comm is not None and not pynccl_comm.disabled: | |
| pynccl_comm.recv(tensor, src) | |
| else: | |
| torch.distributed.recv(tensor, self.ranks[src], self.device_group) | |
| return tensor | |
| def destroy(self): | |
| if self.device_group is not None: | |
| torch.distributed.destroy_process_group(self.device_group) | |
| self.device_group = None | |
| if self.cpu_group is not None: | |
| torch.distributed.destroy_process_group(self.cpu_group) | |
| self.cpu_group = None | |
| if self.pynccl_comm is not None: | |
| self.pynccl_comm = None | |
| if self.ca_comm is not None: | |
| self.ca_comm = None | |
| if self.mq_broadcaster is not None: | |
| self.mq_broadcaster = None | |
| _WORLD: Optional[GroupCoordinator] = None | |
| def get_world_group() -> GroupCoordinator: | |
| assert _WORLD is not None, "world group is not initialized" | |
| return _WORLD | |
| def init_world_group( | |
| ranks: List[int], local_rank: int, backend: str | |
| ) -> GroupCoordinator: | |
| return GroupCoordinator( | |
| group_ranks=[ranks], | |
| local_rank=local_rank, | |
| torch_distributed_backend=backend, | |
| use_pynccl=False, | |
| use_pymscclpp=False, | |
| use_custom_allreduce=False, | |
| use_torch_symm_mem=False, | |
| use_hpu_communicator=False, | |
| use_xpu_communicator=False, | |
| use_npu_communicator=False, | |
| group_name="world", | |
| ) | |
| def init_model_parallel_group( | |
| group_ranks: List[List[int]], | |
| local_rank: int, | |
| backend: str, | |
| use_custom_allreduce: Optional[bool] = None, | |
| use_message_queue_broadcaster: bool = False, | |
| group_name: Optional[str] = None, | |
| use_mscclpp_allreduce: Optional[bool] = None, | |
| pynccl_use_current_stream: bool = True, | |
| use_symm_mem_allreduce: Optional[bool] = None, | |
| torch_compile: Optional[bool] = None, | |
| ) -> GroupCoordinator: | |
| if use_custom_allreduce is None: | |
| use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE | |
| if use_mscclpp_allreduce is None: | |
| use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE | |
| if use_symm_mem_allreduce is None: | |
| use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE | |
| return GroupCoordinator( | |
| group_ranks=group_ranks, | |
| local_rank=local_rank, | |
| torch_distributed_backend=backend, | |
| use_pynccl=not (_is_npu or _is_xpu), | |
| use_pymscclpp=use_mscclpp_allreduce, | |
| use_custom_allreduce=use_custom_allreduce, | |
| use_torch_symm_mem=use_symm_mem_allreduce, | |
| use_hpu_communicator=True, | |
| use_xpu_communicator=True, | |
| use_npu_communicator=True, | |
| use_message_queue_broadcaster=use_message_queue_broadcaster, | |
| group_name=group_name, | |
| pynccl_use_current_stream=pynccl_use_current_stream, | |
| torch_compile=torch_compile, | |
| ) | |
| _TP: Optional[GroupCoordinator] = None | |
| # duplicate GroupCoordinator for prefill in PD-Multiplexing | |
| _PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None | |
| _ENABLE_PDMUX_P_TP: bool = False | |
| def set_pdmux_status(enable_prefill_multiplexing: bool): | |
| global _ENABLE_PDMUX_P_TP | |
| _ENABLE_PDMUX_P_TP = enable_prefill_multiplexing | |
| def get_tp_group() -> GroupCoordinator: | |
| if _ENABLE_PDMUX_P_TP: | |
| assert ( | |
| _PDMUX_PREFILL_TP_GROUP is not None | |
| ), "tensor model parallel group for PD-Multiplexing Prefill is not initialized" | |
| return _PDMUX_PREFILL_TP_GROUP | |
| assert _TP is not None, "tensor model parallel group is not initialized" | |
| return _TP | |
| _MOE_EP: Optional[GroupCoordinator] = None | |
| _MOE_TP: Optional[GroupCoordinator] = None | |
| def get_moe_ep_group() -> GroupCoordinator: | |
| assert _MOE_EP is not None, "expert model parallel group is not initialized" | |
| return _MOE_EP | |
| def get_moe_tp_group() -> GroupCoordinator: | |
| assert _MOE_TP is not None, "expert model parallel group is not initialized" | |
| return _MOE_TP | |
| # kept for backward compatibility | |
| get_tensor_model_parallel_group = get_tp_group | |
| _PP: Optional[GroupCoordinator] = None | |
| def get_pp_group() -> GroupCoordinator: | |
| assert _PP is not None, "pipeline model parallel group is not initialized" | |
| return _PP | |
| # kept for backward compatibility | |
| get_pipeline_model_parallel_group = get_pp_group | |
| def graph_capture(stream: Optional[torch.cuda.Stream] = None): | |
| """ | |
| `graph_capture` is a context manager which should surround the code that | |
| is capturing the CUDA graph. Its main purpose is to ensure that the | |
| some operations will be run after the graph is captured, before the graph | |
| is replayed. It returns a `GraphCaptureContext` object which contains the | |
| necessary data for the graph capture. Currently, it only contains the | |
| stream that the graph capture is running on. This stream is set to the | |
| current CUDA stream when the context manager is entered and reset to the | |
| default stream when the context manager is exited. This is to ensure that | |
| the graph capture is running on a separate stream from the default stream, | |
| in order to explicitly distinguish the kernels to capture | |
| from other kernels possibly launched on background in the default stream. | |
| """ | |
| with get_tp_group().graph_capture( | |
| stream=stream | |
| ) as context, get_pp_group().graph_capture(context): | |
| yield context | |
| logger = logging.getLogger(__name__) | |
| _ENABLE_CUSTOM_ALL_REDUCE = True | |
| _ENABLE_MSCCLPP_ALL_REDUCE = False | |
| _ENABLE_SYMM_MEM_ALL_REDUCE = False | |
| def set_custom_all_reduce(enable: bool): | |
| global _ENABLE_CUSTOM_ALL_REDUCE | |
| _ENABLE_CUSTOM_ALL_REDUCE = enable | |
| def set_mscclpp_all_reduce(enable: bool): | |
| global _ENABLE_MSCCLPP_ALL_REDUCE | |
| _ENABLE_MSCCLPP_ALL_REDUCE = enable | |
| def set_symm_mem_all_reduce(enable: bool): | |
| global _ENABLE_SYMM_MEM_ALL_REDUCE | |
| _ENABLE_SYMM_MEM_ALL_REDUCE = enable | |
| def init_distributed_environment( | |
| world_size: int = -1, | |
| rank: int = -1, | |
| distributed_init_method: str = "env://", | |
| local_rank: int = -1, | |
| backend: str = "nccl", | |
| timeout: Optional[int] = None, | |
| ): | |
| logger.debug( | |
| "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", | |
| world_size, | |
| rank, | |
| local_rank, | |
| distributed_init_method, | |
| backend, | |
| ) | |
| if "mooncake" in backend: | |
| try: | |
| from mooncake import ep as mooncake_ep | |
| except ImportError as e: | |
| raise ImportError( | |
| "Please install mooncake by following the instructions at " | |
| "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 | |
| "to run SGLang with Mooncake Backend." | |
| ) from e | |
| mooncake_ep.set_host_ip(get_local_ip_auto()) | |
| if not torch.distributed.is_initialized(): | |
| assert distributed_init_method is not None, ( | |
| "distributed_init_method must be provided when initializing " | |
| "distributed environment" | |
| ) | |
| if timeout is not None: | |
| assert isinstance(timeout, (int)), "timeout must be a number" | |
| assert timeout > 0, "timeout must be positive" | |
| timeout = timedelta(seconds=timeout) | |
| # this backend is used for WORLD | |
| torch.distributed.init_process_group( | |
| backend=backend, | |
| init_method=distributed_init_method, | |
| world_size=world_size, | |
| rank=rank, | |
| timeout=timeout, | |
| ) | |
| # set the local rank | |
| # local_rank is not available in torch ProcessGroup, | |
| # see https://github.com/pytorch/pytorch/issues/122816 | |
| if local_rank == -1: | |
| # local rank not set, this usually happens in single-node | |
| # setting, where we can use rank as local rank | |
| if distributed_init_method == "env://": | |
| local_rank = int(os.environ.get("LOCAL_RANK", "0")) | |
| else: | |
| local_rank = rank | |
| global _WORLD | |
| if _WORLD is None: | |
| ranks = list(range(torch.distributed.get_world_size())) | |
| _WORLD = init_world_group(ranks, local_rank, backend) | |
| else: | |
| assert ( | |
| _WORLD.world_size == torch.distributed.get_world_size() | |
| ), "world group already initialized with a different world size" | |
| def initialize_model_parallel( | |
| tensor_model_parallel_size: int = 1, | |
| expert_model_parallel_size: int = 1, | |
| pipeline_model_parallel_size: int = 1, | |
| backend: Optional[str] = None, | |
| duplicate_tp_group: bool = False, | |
| torch_compile: Optional[bool] = None, | |
| ) -> None: | |
| """ | |
| Initialize model parallel groups. | |
| Arguments: | |
| tensor_model_parallel_size: number of GPUs used for tensor model | |
| parallelism. | |
| pipeline_model_parallel_size: number of GPUs used for pipeline model | |
| parallelism. | |
| Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we | |
| use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize | |
| the model pipeline. The present function will | |
| create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: | |
| 4 tensor model-parallel groups: | |
| [g0, g1], [g2, g3], [g4, g5], [g6, g7] | |
| 2 pipeline model-parallel groups: | |
| [g0, g2, g4, g6], [g1, g3, g5, g7] | |
| Note that for efficiency, the caller should make sure adjacent ranks | |
| are on the same DGX box. For example if we are using 2 DGX-1 boxes | |
| with a total of 16 GPUs, rank 0 to 7 belong to the first box and | |
| ranks 8 to 15 belong to the second box. | |
| """ | |
| # Get world size and rank. Ensure some consistencies. | |
| assert torch.distributed.is_initialized() | |
| world_size: int = torch.distributed.get_world_size() | |
| backend = backend or torch.distributed.get_backend(get_world_group().device_group) | |
| if world_size != tensor_model_parallel_size * pipeline_model_parallel_size: | |
| raise RuntimeError( | |
| f"world_size ({world_size}) is not equal to " | |
| f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " | |
| f"pipeline_model_parallel_size ({pipeline_model_parallel_size})" | |
| ) | |
| # Build the tensor model-parallel groups. | |
| num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size | |
| global _TP | |
| assert _TP is None, "tensor model parallel group is already initialized" | |
| group_ranks = [] | |
| for i in range(num_tensor_model_parallel_groups): | |
| ranks = list( | |
| range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) | |
| ) | |
| group_ranks.append(ranks) | |
| # message queue broadcaster is only used in tensor model parallel group | |
| _TP = init_model_parallel_group( | |
| group_ranks, | |
| get_world_group().local_rank, | |
| backend, | |
| use_message_queue_broadcaster=get_bool_env_var( | |
| "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" | |
| ), | |
| group_name="tp", | |
| pynccl_use_current_stream=duplicate_tp_group, | |
| torch_compile=torch_compile, | |
| ) | |
| if duplicate_tp_group: | |
| global _PDMUX_PREFILL_TP_GROUP | |
| assert ( | |
| _PDMUX_PREFILL_TP_GROUP is None | |
| ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized" | |
| _PDMUX_PREFILL_TP_GROUP = init_model_parallel_group( | |
| group_ranks, | |
| get_world_group().local_rank, | |
| backend, | |
| use_message_queue_broadcaster=get_bool_env_var( | |
| "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" | |
| ), | |
| group_name="pdmux_prefill_tp", | |
| pynccl_use_current_stream=True, | |
| torch_compile=torch_compile, | |
| ) | |
| if _TP.pynccl_comm: | |
| _TP.pynccl_comm.disabled = False | |
| _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False | |
| moe_ep_size = expert_model_parallel_size | |
| moe_tp_size = tensor_model_parallel_size // moe_ep_size | |
| global _MOE_EP | |
| assert _MOE_EP is None, "expert model parallel group is already initialized" | |
| if moe_ep_size == tensor_model_parallel_size: | |
| _MOE_EP = _TP | |
| else: | |
| # TODO(ch-wan): use split_group to save memory | |
| group_ranks = [] | |
| for i in range(num_tensor_model_parallel_groups): | |
| for j in range(moe_tp_size): | |
| st = i * tensor_model_parallel_size + j | |
| en = (i + 1) * tensor_model_parallel_size + j | |
| ranks = list(range(st, en, moe_tp_size)) | |
| group_ranks.append(ranks) | |
| _MOE_EP = init_model_parallel_group( | |
| group_ranks, | |
| get_world_group().local_rank, | |
| backend, | |
| group_name="moe_ep", | |
| ) | |
| global _MOE_TP | |
| assert _MOE_TP is None, "expert model parallel group is already initialized" | |
| if moe_tp_size == tensor_model_parallel_size: | |
| _MOE_TP = _TP | |
| else: | |
| # TODO(ch-wan): use split_group to save memory | |
| group_ranks = [] | |
| for i in range(num_tensor_model_parallel_groups): | |
| for j in range(moe_ep_size): | |
| st = i * tensor_model_parallel_size + j * moe_tp_size | |
| en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size | |
| ranks = list(range(st, en)) | |
| group_ranks.append(ranks) | |
| _MOE_TP = init_model_parallel_group( | |
| group_ranks, | |
| get_world_group().local_rank, | |
| backend, | |
| group_name="moe_tp", | |
| ) | |
| # Build the pipeline model-parallel groups. | |
| num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size | |
| global _PP | |
| assert _PP is None, "pipeline model parallel group is already initialized" | |
| group_ranks = [] | |
| for i in range(num_pipeline_model_parallel_groups): | |
| ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) | |
| group_ranks.append(ranks) | |
| # pipeline parallel does not need custom allreduce | |
| _PP = init_model_parallel_group( | |
| group_ranks, | |
| get_world_group().local_rank, | |
| backend, | |
| use_custom_allreduce=False, | |
| group_name="pp", | |
| ) | |
| def ensure_model_parallel_initialized( | |
| tensor_model_parallel_size: int, | |
| expert_model_parallel_size: int, | |
| pipeline_model_parallel_size: int, | |
| backend: Optional[str] = None, | |
| ) -> None: | |
| """Helper to initialize model parallel groups if they are not initialized, | |
| or ensure tensor-parallel and pipeline-parallel sizes are equal to expected | |
| values if the model parallel groups are initialized. | |
| """ | |
| backend = backend or torch.distributed.get_backend(get_world_group().device_group) | |
| if not model_parallel_is_initialized(): | |
| initialize_model_parallel( | |
| tensor_model_parallel_size, | |
| expert_model_parallel_size, | |
| pipeline_model_parallel_size, | |
| backend, | |
| ) | |
| return | |
| assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( | |
| "tensor parallel group already initialized, but of unexpected size: " | |
| f"{get_tensor_model_parallel_world_size()=} vs. " | |
| f"{tensor_model_parallel_size=}" | |
| ) | |
| pp_world_size = get_pp_group().world_size | |
| assert pp_world_size == pipeline_model_parallel_size, ( | |
| "pipeline parallel group already initialized, but of unexpected size: " | |
| f"{pp_world_size=} vs. " | |
| f"{pipeline_model_parallel_size=}" | |
| ) | |
| def model_parallel_is_initialized(): | |
| """Check if tensor and pipeline parallel groups are initialized.""" | |
| return _TP is not None and _PP is not None | |
| _TP_STATE_PATCHED = False | |
| def patch_tensor_parallel_group(tp_group: GroupCoordinator): | |
| """Patch the tp group temporarily until this function ends. | |
| This method is for draft workers of speculative decoding to run draft model | |
| with different tp degree from that of target model workers. | |
| Args: | |
| tp_group (GroupCoordinator): the tp group coordinator | |
| """ | |
| global _TP_STATE_PATCHED | |
| assert not _TP_STATE_PATCHED, "Should not call when it's already patched" | |
| _TP_STATE_PATCHED = True | |
| old_tp_group = get_tp_group() | |
| global _TP | |
| _TP = tp_group | |
| try: | |
| yield | |
| finally: | |
| # restore the original state | |
| _TP_STATE_PATCHED = False | |
| _TP = old_tp_group | |
| def get_world_size(): | |
| """Return world size for the world group.""" | |
| return get_world_group().world_size | |
| def get_world_rank(): | |
| """Return my rank for the world group.""" | |
| return get_world_group().rank_in_group | |
| def get_tensor_model_parallel_world_size(): | |
| """Return world size for the tensor model parallel group.""" | |
| return get_tp_group().world_size | |
| def get_tensor_model_parallel_rank(): | |
| """Return my rank for the tensor model parallel group.""" | |
| return get_tp_group().rank_in_group | |
| def get_pipeline_model_parallel_world_size(): | |
| """Return world size for the pipeline model parallel group.""" | |
| return get_pp_group().world_size | |
| def get_pipeline_model_parallel_rank(): | |
| """Return my rank for the pipeline model parallel group.""" | |
| return get_pp_group().rank_in_group | |
| def get_moe_expert_parallel_world_size(): | |
| """Return world size for the moe expert parallel group.""" | |
| return get_moe_ep_group().world_size | |
| def get_moe_expert_parallel_rank(): | |
| """Return my rank for the moe expert parallel group.""" | |
| return get_moe_ep_group().rank_in_group | |
| def get_moe_tensor_parallel_world_size(): | |
| """Return world size for the moe tensor parallel group.""" | |
| return get_moe_tp_group().world_size | |
| def get_moe_tensor_parallel_rank(): | |
| """Return my rank for the moe tensor parallel group.""" | |
| return get_moe_tp_group().rank_in_group | |
| def destroy_model_parallel(): | |
| """Set the groups to none and destroy them.""" | |
| global _TP | |
| if _TP: | |
| _TP.destroy() | |
| _TP = None | |
| global _PP | |
| if _PP: | |
| _PP.destroy() | |
| _PP = None | |
| global _PDMUX_PREFILL_TP_GROUP | |
| if _PDMUX_PREFILL_TP_GROUP: # type: ignore[union-attr] | |
| _PDMUX_PREFILL_TP_GROUP.destroy() | |
| _PDMUX_PREFILL_TP_GROUP = None | |
| def destroy_distributed_environment(): | |
| global _WORLD | |
| if _WORLD: | |
| _WORLD.destroy() | |
| _WORLD = None | |
| if torch.distributed.is_initialized(): | |
| torch.distributed.destroy_process_group() | |
| def cleanup_dist_env_and_memory(shutdown_ray: bool = False): | |
| destroy_model_parallel() | |
| destroy_distributed_environment() | |
| with contextlib.suppress(AssertionError): | |
| torch.distributed.destroy_process_group() | |
| if shutdown_ray: | |
| import ray # Lazy import Ray | |
| ray.shutdown() | |
| gc.collect() | |
| if not _is_cpu: | |
| if hasattr(torch, "cuda") and torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| if hasattr(torch._C, "_host_emptyCache"): | |
| torch._C._host_emptyCache() | |
| else: | |
| logger.warning( | |
| "torch._C._host_emptyCache() only available in Pytorch >=2.5" | |
| ) | |
| elif hasattr(torch, "xpu") and torch.xpu.is_available(): | |
| torch.xpu.empty_cache() | |
| elif hasattr(torch, "npu") and torch.npu.is_available(): | |
| torch.npu.empty_cache() | |
| def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: | |
| """ | |
| This is a collective operation that returns if each rank is in the same node | |
| as the source rank. It tests if processes are attached to the same | |
| memory system (shared access to shared memory). | |
| """ | |
| assert ( | |
| torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL | |
| ), "in_the_same_node_as should be tested with a non-NCCL group." | |
| # local rank inside the group | |
| rank = torch.distributed.get_rank(group=pg) | |
| world_size = torch.distributed.get_world_size(group=pg) | |
| # local tensor in each process to store the result | |
| is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) | |
| # global ranks of the processes in the group | |
| ranks = torch.distributed.get_process_group_ranks(pg) | |
| magic_message = b"magic_message" | |
| shm = None | |
| try: | |
| with contextlib.suppress(OSError): | |
| if rank == source_rank: | |
| # create a shared memory segment | |
| shm = shared_memory.SharedMemory(create=True, size=128) | |
| shm.buf[: len(magic_message)] = magic_message | |
| torch.distributed.broadcast_object_list( | |
| [shm.name], src=ranks[source_rank], group=pg | |
| ) | |
| is_in_the_same_node[rank] = 1 | |
| else: | |
| # try to open the shared memory segment | |
| recv = [None] | |
| torch.distributed.broadcast_object_list( | |
| recv, src=ranks[source_rank], group=pg | |
| ) | |
| name = recv[0] | |
| # fix to https://stackoverflow.com/q/62748654/9191338 | |
| # Python incorrectly tracks shared memory even if it is not | |
| # created by the process. The following patch is a workaround. | |
| with patch( | |
| "multiprocessing.resource_tracker.register", | |
| lambda *args, **kwargs: None, | |
| ): | |
| shm = shared_memory.SharedMemory(name=name) | |
| if shm.buf[: len(magic_message)] == magic_message: | |
| is_in_the_same_node[rank] = 1 | |
| except Exception as e: | |
| logger.error("Error ignored in is_in_the_same_node: %s", e) | |
| finally: | |
| if shm: | |
| shm.close() | |
| torch.distributed.barrier(group=pg) | |
| # clean up the shared memory segment | |
| with contextlib.suppress(OSError): | |
| if rank == source_rank and shm: | |
| shm.unlink() | |
| torch.distributed.all_reduce(is_in_the_same_node, group=pg) | |
| return [x == 1 for x in is_in_the_same_node.tolist()] | |
| vllm_get_pp_group = None | |
| vllm_get_tp_group = None | |
| vllm_get_world_group = None | |
| def monkey_patch_vllm_parallel_state(reverse: bool = False): | |
| try: | |
| import vllm.distributed.parallel_state as vllm_parrlel_state | |
| except ImportError: | |
| return | |
| global vllm_get_pp_group, vllm_get_tp_group, vllm_get_world_group | |
| if vllm_get_pp_group is None: | |
| vllm_get_pp_group = vllm_parrlel_state.get_pp_group | |
| vllm_get_tp_group = vllm_parrlel_state.get_tp_group | |
| vllm_get_world_group = vllm_parrlel_state.get_world_group | |
| if reverse: | |
| setattr(vllm_parrlel_state, "get_pp_group", vllm_get_pp_group) | |
| setattr(vllm_parrlel_state, "get_tp_group", vllm_get_tp_group) | |
| setattr(vllm_parrlel_state, "get_world_group", vllm_get_world_group) | |
| else: | |
| setattr(vllm_parrlel_state, "get_pp_group", get_pp_group) | |
| setattr(vllm_parrlel_state, "get_tp_group", get_tp_group) | |
| setattr(vllm_parrlel_state, "get_world_group", get_world_group) | |
Xet Storage Details
- Size:
- 70.9 kB
- Xet hash:
- 9eab697b448ea3076c0ea974bf699ccded92dee8295890d48992258a6945880c
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.