| from __future__ import annotations | |
| import os | |
| import random | |
| from collections import deque | |
| from contextlib import nullcontext | |
| from enum import Enum | |
| from typing import TYPE_CHECKING, Optional, Type | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| from sglang.srt.utils import is_npu | |
| if TYPE_CHECKING: | |
| from sglang.srt.managers.schedule_batch import Req | |
#########################
# Constants & Enums
#########################

# Bootstrap host string; judging by its name this is a placeholder address
# rather than a real host (NOTE(review): confirm against its call sites).
FAKE_BOOTSTRAP_HOST = "2.2.2.2"


class DisaggregationMode(Enum):
    """Role of this server in prefill/decode (PD) disaggregation.

    ``NULL`` presumably means disaggregation is not in use — confirm at call
    sites.
    """

    NULL = "null"
    PREFILL = "prefill"
    DECODE = "decode"
#########################
# Synchronization
#########################

# Testing hook: probability with which each poll is reported as failed.
# Read from DISAGGREGATION_TEST_FAILURE_PROB and converted to float explicitly;
# 0 (the default) disables failure simulation.
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))


def poll_and_all_reduce(pollers, gloo_group):
    """Poll each poller locally, then take the element-wise MIN across ranks.

    Args:
        pollers: objects exposing ``poll()`` whose result converts to ``int``.
        gloo_group: process group passed to ``dist.all_reduce``.

    Returns:
        A list with one reduced status (int) per poller. Statuses are packed
        into a uint8 tensor, so they must fit in 0..255.
    """
    simulate_failures = FAILURE_PROB > 0
    if simulate_failures:
        from sglang.srt.disaggregation.base import KVPoll

    states = []
    for poller in pollers:
        # A simulated failure short-circuits: the real poll() is not called.
        if simulate_failures and random.random() < FAILURE_PROB:
            states.append(int(KVPoll.Failed))
        else:
            states.append(int(poller.poll()))

    reduced = torch.tensor(states, dtype=torch.uint8, device="cpu")
    dist.all_reduce(reduced, op=dist.ReduceOp.MIN, group=gloo_group)
    return reduced.tolist()
#########################
# Metadata Buffers
#########################


class ReqToMetadataIdxAllocator:
    """A memory pool that maps a request to its first output token location.

    Hands out integer slot indices from ``[0, size)``. Freed slots are
    appended to the back of the free queue, so reuse is FIFO.
    """

    def __init__(
        self,
        size: int,
    ):
        self.size = size
        # deque: O(1) popleft in alloc() and O(1) append in free().
        self.free_slots = deque(range(size))

    def available_size(self) -> int:
        """Return the number of currently free slots."""
        return len(self.free_slots)

    def alloc(self) -> Optional[int]:
        """Take the oldest free slot, or return None when the pool is exhausted."""
        if not self.free_slots:
            return None
        return self.free_slots.popleft()

    def free(self, free_index: int):
        """Return ``free_index`` to the pool.

        No double-free or range check is performed; the caller must only free
        indices previously obtained from ``alloc()``.
        """
        self.free_slots.append(free_index)
class MetadataBuffers:
    """Per-request metadata tensors that are registered for transfer.

    Every buffer has ``size`` rows; row ``i`` belongs to the request whose
    ``metadata_buffer_index`` is ``i``.
    """

    def __init__(
        self,
        size: int,
        hidden_size: int,
        hidden_states_dtype: torch.dtype,
        max_top_logprobs_num: int = 128,
        custom_mem_pool: torch.cuda.MemPool = None,
    ):
        self.custom_mem_pool = custom_mem_pool

        if is_npu():
            # For the Ascend backend, output tokens are placed in the NPU and
            # transferred over the D2D channel.
            device = "npu"
        else:
            # Host memory, with or without a custom mem pool.
            # TODO(shangming): use 'cuda' when a custom mem pool is given, once
            # Mooncake's nvlink_transport is bug-free.
            device = "cpu"

        # The minimal size for RDMA is 64 bytes: 16 x 4-byte int32/float32
        # elements per row reaches exactly that (the int64 row is 128 bytes).
        row_width = 16

        def zeros(cols: int, dtype: torch.dtype) -> torch.Tensor:
            return torch.zeros((size, cols), dtype=dtype, device=device)

        pool_ctx = (
            torch.cuda.use_mem_pool(self.custom_mem_pool)
            if self.custom_mem_pool
            else nullcontext()
        )
        with pool_ctx:
            # We transfer the metadata of the first output token to decode.
            # TODO: abort top_logprobs_num > 128 in PD
            self.output_ids = zeros(row_width, torch.int32)
            self.cached_tokens = zeros(row_width, torch.int32)
            self.output_token_logprobs_val = zeros(row_width, torch.float32)
            self.output_token_logprobs_idx = zeros(row_width, torch.int32)
            self.output_top_logprobs_val = zeros(max_top_logprobs_num, torch.float32)
            self.output_top_logprobs_idx = zeros(max_top_logprobs_num, torch.int32)
            # For PD + spec decode
            self.output_topk_p = zeros(row_width, torch.float32)
            self.output_topk_index = zeros(row_width, torch.int64)
            self.output_hidden_states = zeros(hidden_size, hidden_states_dtype)

    def _metadata_tensors(self):
        """Return every metadata buffer in one fixed order.

        Both ``get_buf_infos`` and ``get_buf`` report buffers in this order, so
        it must stay stable.
        """
        return (
            self.output_ids,
            self.cached_tokens,
            self.output_token_logprobs_val,
            self.output_token_logprobs_idx,
            self.output_top_logprobs_val,
            self.output_top_logprobs_idx,
            self.output_topk_p,
            self.output_topk_index,
            self.output_hidden_states,
        )

    def get_buf_infos(self):
        """Return ``(ptrs, data_lens, item_lens)`` for every buffer.

        ``ptrs`` are raw data pointers, ``data_lens`` the total byte size of
        each buffer and ``item_lens`` the byte size of a single row.
        """
        tensors = self._metadata_tensors()
        ptrs = [t.data_ptr() for t in tensors]
        data_lens = [t.nbytes for t in tensors]
        item_lens = [t[0].nbytes for t in tensors]
        return ptrs, data_lens, item_lens

    def get_buf(self, idx: int):
        """Return the row ``idx`` of every buffer, as a tuple in buffer order."""
        return tuple(t[idx] for t in self._metadata_tensors())

    def set_buf(self, req: Req):
        """Write the first-output-token metadata of ``req`` into its slot."""
        slot = req.metadata_buffer_index

        self.output_ids[slot][0] = req.output_ids[0]
        self.cached_tokens[slot][0] = req.cached_tokens

        if req.return_logprob:
            # Each source is skipped when it is None or an empty list.
            token_vals = req.output_token_logprobs_val
            if token_vals:
                self.output_token_logprobs_val[slot][0] = token_vals[0]

            token_idxs = req.output_token_logprobs_idx
            if token_idxs:
                self.output_token_logprobs_idx[slot][0] = token_idxs[0]

            top_vals = req.output_top_logprobs_val
            if top_vals:
                first_vals = top_vals[0]
                self.output_top_logprobs_val[slot][: len(first_vals)] = torch.tensor(
                    first_vals, dtype=torch.float32, device="cpu"
                )

            top_idxs = req.output_top_logprobs_idx
            if top_idxs:
                first_idxs = top_idxs[0]
                self.output_top_logprobs_idx[slot][: len(first_idxs)] = torch.tensor(
                    first_idxs, dtype=torch.int32, device="cpu"
                )

        # For PD + spec decode
        if req.hidden_states_tensor is not None:
            # speculative_eagle_topk should not be greater than 16 currently
            # (the topk rows are 16 wide).
            topk = req.output_topk_p.size(0)
            self.output_topk_p[slot, :topk].copy_(req.output_topk_p)
            self.output_topk_index[slot, :topk].copy_(req.output_topk_index)
            self.output_hidden_states[slot].copy_(req.hidden_states_tensor)
#########################
# Transfer Backend
#########################


class TransferBackend(Enum):
    """KV-cache transfer backends that ``get_kv_class`` can resolve."""

    MOONCAKE = "mooncake"
    NIXL = "nixl"
    ASCEND = "ascend"
    FAKE = "fake"


class KVClassType(Enum):
    """Kinds of backend-specific classes that ``get_kv_class`` can resolve."""

    KVARGS = "kvargs"
    MANAGER = "manager"
    SENDER = "sender"
    RECEIVER = "receiver"
    BOOTSTRAP_SERVER = "bootstrap_server"


def get_kv_class(
    transfer_backend: TransferBackend, class_type: KVClassType
) -> Optional[Type]:
    """Resolve the class implementing ``class_type`` for ``transfer_backend``.

    Each backend module is imported lazily, inside the branch for that
    backend, so only the selected backend's module is loaded.

    Returns:
        The class, or ``None`` when the backend does not provide that kind
        (the FAKE backend only provides KVARGS, SENDER and RECEIVER).

    Raises:
        ValueError: if ``transfer_backend`` is not a supported backend.
    """
    if transfer_backend == TransferBackend.MOONCAKE:
        from sglang.srt.disaggregation.mooncake import (
            MooncakeKVBootstrapServer,
            MooncakeKVManager,
            MooncakeKVReceiver,
            MooncakeKVSender,
        )

        backend_classes = {
            KVClassType.MANAGER: MooncakeKVManager,
            KVClassType.SENDER: MooncakeKVSender,
            KVClassType.RECEIVER: MooncakeKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
        }
    elif transfer_backend == TransferBackend.ASCEND:
        from sglang.srt.disaggregation.ascend import (
            AscendKVBootstrapServer,
            AscendKVManager,
            AscendKVReceiver,
            AscendKVSender,
        )

        backend_classes = {
            KVClassType.MANAGER: AscendKVManager,
            KVClassType.SENDER: AscendKVSender,
            KVClassType.RECEIVER: AscendKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer,
        }
    elif transfer_backend == TransferBackend.NIXL:
        from sglang.srt.disaggregation.nixl import (
            NixlKVBootstrapServer,
            NixlKVManager,
            NixlKVReceiver,
            NixlKVSender,
        )

        backend_classes = {
            KVClassType.MANAGER: NixlKVManager,
            KVClassType.SENDER: NixlKVSender,
            KVClassType.RECEIVER: NixlKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
        }
    elif transfer_backend == TransferBackend.FAKE:
        from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

        backend_classes = {
            KVClassType.SENDER: FakeKVSender,
            KVClassType.RECEIVER: FakeKVReceiver,
        }
    else:
        raise ValueError(f"Unsupported transfer backend: {transfer_backend}")

    # KVArgs is shared by every backend.
    from sglang.srt.disaggregation.base import KVArgs

    return {KVClassType.KVARGS: KVArgs, **backend_classes}.get(class_type)
#########################
# KV Pages
#########################


def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
    """Map token-level KV indices to page indices.

    Every page is guaranteed to be full except the last one, so taking every
    ``page_size``-th index and dividing by ``page_size`` yields one page
    index per page. With ``page_size == 1`` the input array itself is
    returned (no copy).
    """
    if page_size != 1:
        first_of_each_page = kv_indices[::page_size]
        return first_of_each_page // page_size
    return kv_indices
def kv_to_page_num(num_kv_indices: int, page_size: int):
    """Return ``ceil(num_kv_indices / page_size)`` using integer arithmetic."""
    padded = num_kv_indices + (page_size - 1)
    return padded // page_size
#########################
# Misc
#########################


def is_mla_backend(target_kv_pool) -> bool:
    """Return True when ``target_kv_pool`` is an ``MLATokenToKVPool`` (or subclass)."""
    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool as mla_pool_cls

    is_mla = isinstance(target_kv_pool, mla_pool_cls)
    return is_mla
def prepare_abort(req: Req, error_message: str, status_code=None):
    """Mark ``req`` as aborted.

    Sets ``req.finished_reason`` to ``FINISH_ABORT(error_message, status_code)``.
    If the request asked for logprobs, its input-logprob fields are reset to
    fresh empty lists.
    """
    from sglang.srt.managers.schedule_batch import FINISH_ABORT

    # populate finish metadata and stream output
    req.finished_reason = FINISH_ABORT(error_message, status_code)
    if not req.return_logprob:
        return

    for field_name in (
        "input_token_logprobs_val",
        "input_token_logprobs_idx",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    ):
        setattr(req, field_name, [])
Xet Storage Details
- Size:
- 12.3 kB
- Xet hash:
- b639c5d10b666a116a3e1a081ddca0f483f66b99efdf4074a0f79229a9906c67
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.