leideng's picture
download
raw
12.3 kB
from __future__ import annotations

import os
import random
from collections import deque
from contextlib import nullcontext
from enum import Enum
from typing import TYPE_CHECKING, Optional, Type

import numpy as np
import torch
import torch.distributed as dist

from sglang.srt.utils import is_npu

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req
#########################
# Constants & Enums
#########################
# Placeholder bootstrap host address.
# NOTE(review): presumably used when no real bootstrap server exists (e.g. the
# fake transfer backend / tests) — confirm against callers.
FAKE_BOOTSTRAP_HOST = "2.2.2.2"
class DisaggregationMode(Enum):
    """Role this server plays in prefill/decode (PD) disaggregation.

    ``NULL`` means disaggregation is off; ``PREFILL`` and ``DECODE`` name the
    two sides of a disaggregated deployment. Values are lowercase strings.
    """

    NULL = "null"
    PREFILL = "prefill"
    DECODE = "decode"
#########################
# Synchronization
#########################
# Test hook: probability in [0, 1] that poll_and_all_reduce reports a poller as
# failed instead of polling it. Read once at import time; unset means 0.0.
FAILURE_PROB = float(os.environ.get("DISAGGREGATION_TEST_FAILURE_PROB", "0"))
def poll_and_all_reduce(pollers, gloo_group):
    """Poll each poller locally, then MIN-all-reduce the states across ranks.

    Every poller's ``poll()`` result is converted with ``int`` and packed into a
    CPU ``uint8`` tensor. That tensor is all-reduced with ``ReduceOp.MIN`` over
    ``gloo_group``, so all ranks end up with the same minimum value per poller.
    NOTE(review): relies on the KVPoll integer values fitting in uint8 and on
    smaller values meaning "less progressed / failed" — confirm in KVPoll.

    When ``FAILURE_PROB`` is positive, each poller is independently reported as
    ``KVPoll.Failed`` with that probability, and its ``poll()`` is then not
    called. This simulates transfer failures in tests.

    Returns:
        A list of Python ints, one per poller, in ``pollers`` order.
    """
    if FAILURE_PROB > 0:
        from sglang.srt.disaggregation.base import KVPoll

        states = []
        for poller in pollers:
            # Draw first; only poll for real when the simulated failure misses.
            if random.random() < FAILURE_PROB:
                states.append(int(KVPoll.Failed))
            else:
                states.append(int(poller.poll()))
    else:
        states = [int(poller.poll()) for poller in pollers]

    reduced = torch.tensor(states, dtype=torch.uint8, device="cpu")
    dist.all_reduce(reduced, op=dist.ReduceOp.MIN, group=gloo_group)
    return reduced.tolist()
#########################
# Metadata Buffers
#########################
class ReqToMetadataIdxAllocator:
    """A memory pool that maps a request to its first output token location.

    Slot indices ``0 .. size-1`` are handed out in FIFO order. A freed slot
    goes to the back of the queue, so it is reused only after every slot that
    was already free.
    """

    def __init__(
        self,
        size: int,
    ):
        self.size = size
        # Queue of currently unused slot indices (oldest free slot on the left).
        self.free_slots = deque(range(size))

    def available_size(self):
        """Return the number of slots that can still be allocated."""
        return len(self.free_slots)

    def alloc(self) -> Optional[int]:
        """Pop the oldest free slot, or return ``None`` if none are left."""
        if not self.free_slots:
            return None
        return self.free_slots.popleft()

    def free(self, free_index: int):
        """Give ``free_index`` back to the pool.

        No double-free or range check is performed.
        """
        self.free_slots.append(free_index)
class MetadataBuffers:
def __init__(
self,
size: int,
hidden_size: int,
hidden_states_dtype: torch.dtype,
max_top_logprobs_num: int = 128,
custom_mem_pool: torch.cuda.MemPool = None,
):
self.custom_mem_pool = custom_mem_pool
device = "cpu"
if is_npu():
# For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
device = "npu"
elif self.custom_mem_pool:
# TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free
device = "cpu"
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
else nullcontext()
):
# TODO: abort top_logprobs_num > 128 in PD
# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
self.cached_tokens = torch.zeros(
(size, 16), dtype=torch.int32, device=device
)
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
self.output_token_logprobs_idx = torch.zeros(
(size, 16), dtype=torch.int32, device=device
)
self.output_top_logprobs_val = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.float32, device=device
)
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device=device
)
# For PD + spec decode
self.output_topk_p = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
self.output_topk_index = torch.zeros(
(size, 16), dtype=torch.int64, device=device
)
self.output_hidden_states = torch.zeros(
(size, hidden_size), dtype=hidden_states_dtype, device=device
)
def get_buf_infos(self):
ptrs = [
self.output_ids.data_ptr(),
self.cached_tokens.data_ptr(),
self.output_token_logprobs_val.data_ptr(),
self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(),
self.output_top_logprobs_idx.data_ptr(),
self.output_topk_p.data_ptr(),
self.output_topk_index.data_ptr(),
self.output_hidden_states.data_ptr(),
]
data_lens = [
self.output_ids.nbytes,
self.cached_tokens.nbytes,
self.output_token_logprobs_val.nbytes,
self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes,
self.output_top_logprobs_idx.nbytes,
self.output_topk_p.nbytes,
self.output_topk_index.nbytes,
self.output_hidden_states.nbytes,
]
item_lens = [
self.output_ids[0].nbytes,
self.cached_tokens[0].nbytes,
self.output_token_logprobs_val[0].nbytes,
self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes,
self.output_top_logprobs_idx[0].nbytes,
self.output_topk_p[0].nbytes,
self.output_topk_index[0].nbytes,
self.output_hidden_states[0].nbytes,
]
return ptrs, data_lens, item_lens
def get_buf(self, idx: int):
return (
self.output_ids[idx],
self.cached_tokens[idx],
self.output_token_logprobs_val[idx],
self.output_token_logprobs_idx[idx],
self.output_top_logprobs_val[idx],
self.output_top_logprobs_idx[idx],
self.output_topk_p[idx],
self.output_topk_index[idx],
self.output_hidden_states[idx],
)
def set_buf(self, req: Req):
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
if req.return_logprob:
if req.output_token_logprobs_val: # not none or empty list
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
req.output_token_logprobs_val[0]
)
if req.output_token_logprobs_idx: # not none or empty list
self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
req.output_token_logprobs_idx[0]
)
if req.output_top_logprobs_val: # not none or empty list
self.output_top_logprobs_val[req.metadata_buffer_index][
: len(req.output_top_logprobs_val[0])
] = torch.tensor(
req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
)
if req.output_top_logprobs_idx: # not none or empty list
self.output_top_logprobs_idx[req.metadata_buffer_index][
: len(req.output_top_logprobs_idx[0])
] = torch.tensor(
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
)
# For PD + spec decode
if req.hidden_states_tensor is not None:
# speculative_eagle_topk should not be greater than 16 currently
topk = req.output_topk_p.size(0)
self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
req.output_topk_p
)
self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
req.output_topk_index
)
self.output_hidden_states[req.metadata_buffer_index].copy_(
req.hidden_states_tensor
)
#########################
# Transfer Backend
#########################
class TransferBackend(Enum):
    """KV-cache transfer implementations that ``get_kv_class`` can resolve.

    Each value is the backend's lowercase name.
    """

    MOONCAKE = "mooncake"
    NIXL = "nixl"
    ASCEND = "ascend"
    FAKE = "fake"
class KVClassType(Enum):
    """Roles a transfer backend supplies a class for (see ``get_kv_class``).

    Each value is the role's snake_case name.
    """

    KVARGS = "kvargs"
    MANAGER = "manager"
    SENDER = "sender"
    RECEIVER = "receiver"
    BOOTSTRAP_SERVER = "bootstrap_server"
def get_kv_class(
    transfer_backend: TransferBackend, class_type: KVClassType
) -> Optional[Type]:
    """Resolve the class that fills ``class_type``'s role for ``transfer_backend``.

    Backend modules are imported lazily, inside the matching branch, so only
    the selected backend's module is loaded. In particular, the fake backend
    is no longer imported on every call.

    Args:
        transfer_backend: The KV transfer backend to resolve against.
        class_type: The role to resolve: kv args, manager, sender, receiver,
            or bootstrap server.

    Returns:
        The resolved class, or ``None`` when the backend provides no class for
        ``class_type``. The FAKE backend defines only kv args, sender and
        receiver.

    Raises:
        ValueError: If ``transfer_backend`` is not a supported backend.
    """
    if transfer_backend == TransferBackend.MOONCAKE:
        from sglang.srt.disaggregation.mooncake import (
            MooncakeKVBootstrapServer,
            MooncakeKVManager,
            MooncakeKVReceiver,
            MooncakeKVSender,
        )

        backend_classes = {
            KVClassType.MANAGER: MooncakeKVManager,
            KVClassType.SENDER: MooncakeKVSender,
            KVClassType.RECEIVER: MooncakeKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
        }
    elif transfer_backend == TransferBackend.ASCEND:
        from sglang.srt.disaggregation.ascend import (
            AscendKVBootstrapServer,
            AscendKVManager,
            AscendKVReceiver,
            AscendKVSender,
        )

        backend_classes = {
            KVClassType.MANAGER: AscendKVManager,
            KVClassType.SENDER: AscendKVSender,
            KVClassType.RECEIVER: AscendKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer,
        }
    elif transfer_backend == TransferBackend.NIXL:
        from sglang.srt.disaggregation.nixl import (
            NixlKVBootstrapServer,
            NixlKVManager,
            NixlKVReceiver,
            NixlKVSender,
        )

        backend_classes = {
            KVClassType.MANAGER: NixlKVManager,
            KVClassType.SENDER: NixlKVSender,
            KVClassType.RECEIVER: NixlKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
        }
    elif transfer_backend == TransferBackend.FAKE:
        from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

        # The fake backend has no manager or bootstrap server; lookups for
        # those roles fall through to ``None`` below.
        backend_classes = {
            KVClassType.SENDER: FakeKVSender,
            KVClassType.RECEIVER: FakeKVReceiver,
        }
    else:
        raise ValueError(f"Unsupported transfer backend: {transfer_backend}")

    # Every supported backend shares the generic KVArgs container.
    from sglang.srt.disaggregation.base import KVArgs

    class_mapping = {KVClassType.KVARGS: KVArgs, **backend_classes}
    return class_mapping.get(class_type)
#########################
# KV Pages
#########################
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
    """Map token-level KV indices to the page indices that contain them.

    Assumes every page is full except possibly the last. The first KV index of
    each page therefore sits at positions ``0, page_size, 2*page_size, ...``,
    and ``kv_index // page_size`` is its page index. When ``page_size == 1``
    the input array itself is returned unchanged.
    """
    if page_size != 1:
        # First KV index of each page, converted to that page's index.
        page_heads = kv_indices[::page_size]
        return page_heads // page_size
    # Shortcut: one token per page, so KV indices already are page indices.
    return kv_indices
def kv_to_page_num(num_kv_indices: int, page_size: int):
    """Return how many pages are needed for ``num_kv_indices`` KV entries.

    This is ``ceil(num_kv_indices / page_size)``, computed exactly with integer
    floor division.
    """
    return (num_kv_indices - 1) // page_size + 1
#########################
# Misc
#########################
def is_mla_backend(target_kv_pool) -> bool:
    """Return whether ``target_kv_pool`` is an ``MLATokenToKVPool`` (or subclass).

    The pool class is imported at call time. NOTE(review): presumably this
    avoids an import cycle with the memory-pool module — confirm.
    """
    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool as _MLAPool

    is_mla = isinstance(target_kv_pool, _MLAPool)
    return is_mla
def prepare_abort(req: Req, error_message: str, status_code=None):
    """Mark ``req`` as aborted and clear its input-logprob fields.

    Sets ``req.finished_reason`` to ``FINISH_ABORT(error_message, status_code)``.
    If the request asked for logprobs (``req.return_logprob``), each of the six
    input-logprob fields is reset to its own fresh empty list. Returns None.
    """
    from sglang.srt.managers.schedule_batch import FINISH_ABORT

    # populate finish metadata and stream output
    req.finished_reason = FINISH_ABORT(error_message, status_code)

    if not req.return_logprob:
        return
    for field_name in (
        "input_token_logprobs_val",
        "input_token_logprobs_idx",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    ):
        # `[]` is evaluated per iteration, so each field gets a distinct list.
        setattr(req, field_name, [])

Xet Storage Details

Size:
12.3 kB
·
Xet hash:
b639c5d10b666a116a3e1a081ddca0f483f66b99efdf4074a0f79229a9906c67

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.