# Source: leideng/QCFuse, srt/managers/schedule_batch.py (90.4 kB)
from __future__ import annotations
import enum
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.
The following is the flow of data structures for a batch:
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
It will be transformed from CPU scheduler to GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
It contains low-level tensor data. Most of the data consists of GPU tensors.
TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
"""
import copy
import dataclasses
import logging
import re
import time
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.environ import envs
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import SWAChunkCache
from sglang.srt.mem_cache.common import (
alloc_for_decode,
alloc_for_extend,
evict_from_tree_cache,
)
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.utils import flatten_nested_list
from sglang.srt.utils.cache_blender_info import (
BatchBlendInfo,
BlendStyle,
SelectMode,
AttParams,
HackBlendKVPool,
ContextBlendPool,
)
from sglang.srt.utils.digest_index_manager import DigestIndexManager
from sglang.srt.utils.kv_ssd_manager import KVSSDManager
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
# On the first incremental-detokenization call for a request, up to this many
# trailing prompt tokens (before `read_offset`) are included as "surrounding"
# context; see Req.init_incremental_detokenize.
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Module-level logger.
logger = logging.getLogger(__name__)
class BaseFinishReason:
    """Common base for the reasons a request stopped generating.

    Subclasses describe one concrete stop condition and serialize it through
    ``to_json``.
    """

    def __init__(self, is_error: bool = False):
        # True for reasons that represent a failure (e.g. FINISH_ABORT).
        self.is_error = is_error

    def to_json(self):
        """Return a JSON-serializable dict describing this reason (subclasses must override)."""
        raise NotImplementedError()
class FINISH_MATCHED_TOKEN(BaseFinishReason):
    """Generation stopped because a stop / EOS token id was produced."""

    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        # The token id (or ids) that triggered the stop.
        self.matched = matched

    def to_json(self):
        # "stop" mirrors the finish_reason value used by the OpenAI API.
        return dict(type="stop", matched=self.matched)
class FINISH_MATCHED_STR(BaseFinishReason):
    """Generation stopped because a stop string appeared in the output."""

    def __init__(self, matched: str):
        super().__init__()
        # The stop string that was found.
        self.matched = matched

    def to_json(self):
        # "stop" mirrors the finish_reason value used by the OpenAI API.
        return dict(type="stop", matched=self.matched)
class FINISHED_MATCHED_REGEX(BaseFinishReason):
    """Generation stopped because a stop regex matched the output tail."""

    def __init__(self, matched: str):
        super().__init__()
        # The regex pattern (source string) that matched.
        self.matched = matched

    def to_json(self):
        # "stop" mirrors the finish_reason value used by the OpenAI API.
        return dict(type="stop", matched=self.matched)
class FINISH_LENGTH(BaseFinishReason):
    """Generation stopped because the output reached its length limit."""

    def __init__(self, length: int):
        super().__init__()
        # The length limit that was hit.
        self.length = length

    def to_json(self):
        # "length" mirrors the finish_reason value used by the OpenAI API.
        return dict(type="length", length=self.length)
class FINISH_ABORT(BaseFinishReason):
    """The request was aborted; always reported as an error."""

    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        # Fall back to a generic message when none (or an empty one) is given.
        self.message = message if message else "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return dict(
            type="abort",
            message=self.message,
            status_code=self.status_code,
            err_type=self.err_type,
        )
class Modality(Enum):
    """Kinds of multimodal input that a data item can carry."""

    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()

    @staticmethod
    def from_str(modality_str: str):
        """Parse a modality name case-insensitively (``"image"`` -> ``Modality.IMAGE``).

        Raises:
            ValueError: if the name matches no member.
        """
        member_name = modality_str.upper()
        if member_name in Modality.__members__:
            return Modality[member_name]
        raise ValueError(
            f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
        )

    @staticmethod
    def all():
        """Return the base modalities: image, video and audio.

        NOTE(review): MULTI_IMAGES is not part of this list — presumably because it
        is treated as a variant of IMAGE; confirm this is intended.
        """
        return list((Modality.IMAGE, Modality.VIDEO, Modality.AUDIO))
@dataclasses.dataclass
class MultimodalDataItem:
    """
    One MultimodalDataItem contains all inputs for one modality.
    For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
    One for images and one for audio.
    We put the common fields first and the model-specific fields in model_specific_data.

    Attribute reads that do not resolve normally fall back to
    ``model_specific_data`` (see ``__getattr__``). ``item[key] = value`` and
    ``item.set(key, value)`` overwrite an existing instance attribute when one
    exists, and otherwise store into ``model_specific_data``.
    """

    modality: Modality
    # Hash of `feature` (or of `precomputed_embeddings` when `feature` is None);
    # computed lazily by set_pad_value() when still None.
    hash: int = None
    # `hash % 2**30`, assigned by set_pad_value().
    pad_value: int = None
    # NOTE(review): presumably the positions of this item's tokens in the prompt — confirm.
    offsets: Optional[list] = None
    # the raw features returned by processor, e.g. pixel_values or audio_features
    feature: Union[torch.Tensor, np.ndarray] = None
    # the precomputed embeddings, passed as final encoder embeddings
    # One and only one of the feature and precomputed_embeddings will be empty
    precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
    # Model-specific data stored in a dictionary
    model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __getattr__(self, name: str):
        # Only invoked when normal attribute lookup fails. Reading through
        # self.__dict__ (not self.model_specific_data) avoids infinite recursion
        # if model_specific_data itself has not been set yet.
        if (
            "model_specific_data" in self.__dict__
            and name in self.__dict__["model_specific_data"]
        ):
            return self.__dict__["model_specific_data"][name]
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )

    def __setitem__(self, key: str, value: Any):
        # Existing instance attributes are overwritten in place; any other key
        # goes into model_specific_data.
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self.model_specific_data[key] = value

    def set(self, key: str, value: Any):
        """Method-call alias for ``self[key] = value``."""
        self.__setitem__(key, value)

    @staticmethod
    def is_empty_list(l):
        """Return True if ``l`` is None or holds no non-None items after flattening
        with ``flatten_nested_list``."""
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """
        from sglang.srt.managers.mm_utils import hash_feature

        if self.hash is None:
            # Hash the raw feature when present, otherwise the precomputed embeddings.
            if self.feature is not None:
                hashed_feature = self.feature
            else:
                hashed_feature = self.precomputed_embeddings
            self.hash = hash_feature(hashed_feature)
        assert self.hash is not None
        # Python's % with a positive modulus keeps pad_value in [0, 2**30).
        self.pad_value = self.hash % (1 << 30)

    def is_modality(self, modality: Modality) -> bool:
        return self.modality == modality

    def is_audio(self):
        return self.modality == Modality.AUDIO

    def is_image(self):
        # MULTI_IMAGES counts as an image item.
        return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]

    def is_video(self):
        return self.modality == Modality.VIDEO

    def is_valid(self) -> bool:
        """Return True if the item is an image, video or audio item."""
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        """Build an item from a dict of field values.

        A string ``modality`` is looked up by member name (``Modality[name]``,
        case-sensitive, e.g. ``"IMAGE"``). Keys other than the dataclass fields
        make the constructor raise TypeError.
        """
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret

    def merge(self, other):
        # Fold `other` into this item in place, then recompute the pad value from
        # the combined hash.
        # NOTE(review): `+=` concatenates when `feature`/`offsets` are lists, but for
        # torch.Tensor / np.ndarray features it is an element-wise add, and it raises
        # TypeError when either `offsets` is None — confirm callers only merge
        # list-valued items.
        self.feature += other.feature
        self.offsets += other.offsets
        self.hash = hash((self.hash, other.hash))
        self.set_pad_value()
@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    # QWen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        """Build an instance from a dict holding at least ``"mm_items"``.

        Only items whose ``is_valid()`` is True are kept, and each kept item gets
        ``set_pad_value()`` called. Optional token-id / mrope keys present in
        ``obj`` are copied onto the result.
        """
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "video_token_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        return any(item.is_image() for item in self.mm_items)

    def contains_video_inputs(self) -> bool:
        return any(item.is_video() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        """Return True if at least one item is valid (image, video or audio)."""
        return any(item.is_valid() for item in self.mm_items)

    def merge(self, other: MultimodalInputs):
        """
        Merge ``other`` into this instance in place (used when requests are merged).

        - ``mm_items`` and ``image_pad_len`` are concatenated.
        - ``mrope_positions`` are concatenated along dim 1 and
          ``mrope_position_delta`` along dim 0.
        - Token-id fields (names containing ``"_id"``) that are unset here are
          copied from ``other``.

        For the concatenated fields, a value that is None on either side leaves
        this instance's value unchanged.
        """
        # Sequence-valued args that are concatenated. Only combine when both sides
        # are set: `list + None` would raise TypeError.
        for arg in ("mm_items", "image_pad_len"):
            self_arg = getattr(self, arg, None)
            other_arg = getattr(other, arg, None)
            if self_arg is not None and other_arg is not None:
                setattr(self, arg, self_arg + other_arg)

        if self.mrope_positions is not None and other.mrope_positions is not None:
            self.mrope_positions = torch.cat(
                [self.mrope_positions, other.mrope_positions], dim=1
            )

        if (
            self.mrope_position_delta is not None
            and other.mrope_position_delta is not None
        ):
            self.mrope_position_delta = torch.cat(
                [self.mrope_position_delta, other.mrope_position_delta], dim=0
            )

        # Fill in token ids (e.g. im_token_id, audio_start_id) this instance lacks.
        for key, val in other.__dict__.items():
            if "_id" in key and getattr(self, key, None) is None:
                setattr(self, key, val)

        # other args would be kept intact
class RequestStage(str, enum.Enum):
    """Stages of a request's lifecycle, used as labels for per-stage latency.

    Req.add_latency reports ``stage.value`` to the metrics collector. Because the
    enum also subclasses ``str``, each member compares equal to its string value.
    """

    # prefill
    PREFILL_WAITING = "prefill_waiting"

    # disaggregation prefill
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # disaggregation decode
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"
class Req:
    """The input and output status of a request.

    Holds the prompt/output token ids, sampling parameters, KV-cache bookkeeping
    (prefix match, pool indices), logprob buffers, stop-condition state, metrics,
    and disaggregation / cache-blend settings for one generation request.
    """

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: List[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: Optional[List[int]] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_id: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        token_type_ids: Optional[List[int]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
        disagg_mode: Optional[DisaggregationMode] = None,
        data_parallel_rank: Optional[int] = None,
        vocab_size: Optional[int] = None,
        priority: Optional[int] = None,
        metrics_collector: Optional[SchedulerMetricsCollector] = None,
        extra_key: Optional[str] = None,
        http_worker_ipc: Optional[str] = None,
        blend_loc_list: Optional[List[int]] = None,
        blend_style: Optional[str] = None,
        method: Optional[str] = None,
        start: Optional[int] = None,
        ratio: Optional[float] = None,
        attn_start: Optional[int] = 0,
        attn_end: Optional[int] = -1,
        is_contextblend: Optional[bool] = False,
        context_cache_source: Optional[str] = "query",
        context_n_sink: Optional[int] = None,
        digest_ratio: Optional[float] = 0.3,
        digest_index_method: Optional[str] = "kvzip",
        critical_layers: Optional[List[int]] = None,
        # SSD pipeline parameters
        ssd_cache_path_chunk: Optional[str] = None,
        ssd_cache_path_query: Optional[str] = None,
    ):
        """Create a request.

        Most arguments are stored as attributes of the same name. Notes:
        - ``origin_input_ids_unpadded`` falls back to ``origin_input_ids`` when falsy.
        - ``lora_id``, when given, is appended to ``extra_key``.
        - If ``sampling_params.custom_params`` is a dict, ``sampling_params`` is
          shallow-copied and a ``"__req__"`` entry pointing at this request is added.
        - The cache-blend arguments (``blend_*``, ``method``, ``start``, ``ratio``,
          ``attn_*``, ``is_contextblend``, ``context_*``, ``digest_*``,
          ``critical_layers``) and SSD paths are stored as-is, except
          ``blend_style``, which is parsed with ``BlendStyle.parse``.
        """
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = []
        self.session_id = session_id
        self.input_embeds = input_embeds

        # For cross-encoder models
        self.token_type_ids = token_type_ids

        # The length of KV that have been removed in local attention chunked prefill
        self.evicted_seqlen_local = 0

        # For multi-http worker
        self.http_worker_ipc = http_worker_ipc

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            # Copy so the caller's SamplingParams is not mutated.
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # extra key for classifying the request (e.g. cache_salt)
        if lora_id is not None:
            extra_key = (
                extra_key or ""
            ) + lora_id  # lora_id is concatenated to the extra key
        self.extra_key = extra_key
        self.lora_id = lora_id

        # Memory pool info
        self.req_pool_idx: Optional[int] = None
        self.mamba_pool_idx: Optional[torch.Tensor] = None  # shape (1)

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # finished position (in output_ids), used when checking stop conditions with speculative decoding
        self.finished_len = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: Optional[str] = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids
        self.vocab_size = vocab_size
        self.priority = priority

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node: Any = None
        self.last_host_node: Any = None
        self.host_hit_length = 0
        # The node to lock until for swa radix tree lock ref
        self.swa_uuid_for_lock: Optional[int] = None
        # The prefix length of the last prefix matching
        self.last_matched_prefix_len: int = 0

        # Whether or not if it is chunked. It increments whenever
        # it is chunked, and decrement whenever chunked request is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has been already sent to detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
            self.output_token_ids_logprobs_val: List[
                Union[List[float], torch.Tensor]
            ] = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = None
            self.output_token_logprobs_idx = None
            self.output_top_logprobs_val = None
            self.output_top_logprobs_idx = None
            self.output_token_ids_logprobs_val = None
            self.output_token_ids_logprobs_idx = None
        self.hidden_states: List[List[float]] = []
        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
        self.output_topk_p = None
        self.output_topk_index = None

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of cached tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # The number of accepted tokens in speculative decoding for this request.
        # This is used to compute the acceptance rate and average acceptance length per request.
        self.spec_accepted_tokens = 0

        # For metrics
        self.metrics_collector = metrics_collector
        self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
        self.has_log_time_stats: bool = False
        self.last_tic = time.monotonic()

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # For data parallel rank routing
        self.data_parallel_rank: Optional[int] = data_parallel_rank

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

        # For cache blend
        self.blend_style: Optional[BlendStyle] = BlendStyle.parse(blend_style)
        self.blend_loc_list = blend_loc_list
        self.method = method
        self.start = start
        self.ratio = ratio
        self.attn_start = attn_start
        self.attn_end = attn_end
        self.is_contextblend = is_contextblend
        self.context_cache_source = context_cache_source
        self.context_n_sink = context_n_sink
        self.digest_ratio = digest_ratio
        self.digest_index_method = digest_index_method
        self.critical_layers = critical_layers

        # SSD pipeline parameters
        self.ssd_cache_path_chunk = ssd_cache_path_chunk
        self.ssd_cache_path_query = ssd_cache_path_query

    @property
    def seqlen(self):
        """Total number of tokens: prompt plus generated output."""
        return len(self.origin_input_ids) + len(self.output_ids)

    @property
    def is_prefill_only(self) -> bool:
        """Check if this request is prefill-only (no token generation needed)."""
        # NOTE: when spec is enabled, prefill_only optimizations are disabled
        spec_alg = get_global_server_args().speculative_algorithm
        return self.sampling_params.max_new_tokens == 0 and spec_alg is None

    @property
    def output_ids_through_stop(self) -> List[int]:
        """Get the output ids through the stop condition. Stop position is included."""
        if self.finished_len is not None:
            return self.output_ids[: self.finished_len]
        return self.output_ids

    def add_latency(self, stage: RequestStage):
        """Report the time since the previous stage mark under ``stage`` and reset the mark.

        No-op when no metrics collector is attached.
        """
        if self.metrics_collector is None:
            return

        now = time.monotonic()
        self.metrics_collector.observe_per_stage_req_latency(
            stage.value, now - self.last_tic
        )
        self.last_tic = now

    def extend_image_inputs(self, image_inputs):
        """Attach ``image_inputs``, merging into any existing multimodal inputs."""
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether request reached finished condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        """Rebuild ``fill_ids`` and, if a tree cache is given, re-match the cached prefix.

        The prefix used for matching is capped at ``len(fill_ids) - 1`` tokens and,
        when logprobs are requested, at ``logprob_start_len``.
        """
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)
        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
        max_prefix_len = input_len - 1
        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
        max_prefix_len = max(max_prefix_len, 0)
        token_ids = self.fill_ids[:max_prefix_len]

        if tree_cache is not None:
            (
                self.prefix_indices,
                self.last_node,
                self.last_host_node,
                self.host_hit_length,
            ) = tree_cache.match_prefix(
                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
                **(
                    {"req": self, "cow_mamba": True}
                    if isinstance(tree_cache, MambaRadixCache)
                    else {}
                ),
            )
            self.last_matched_prefix_len = len(self.prefix_indices)
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        """Return ``(surr_and_decode_ids, read_offset - surr_offset)`` for incremental detokenization.

        On the first call the offsets are initialized from the unpadded prompt.
        Later calls append only the output ids added since the previous call.
        """
        first_iter = self.surr_offset is None or self.read_offset is None

        output_ids = self.output_ids_through_stop

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )
            self.surr_and_decode_ids = (
                self.origin_input_ids_unpadded[self.surr_offset :] + output_ids
            )
            self.cur_decode_ids_len = len(output_ids)
        else:
            self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len :])
            self.cur_decode_ids_len = len(output_ids)

        return self.surr_and_decode_ids, self.read_offset - self.surr_offset

    def tail_str(self) -> str:
        """Decode the trailing output tokens that could contain a stop string / regex match.

        Returns an empty string when the request has neither stop strings nor
        stop regexes, since there is nothing to match against.
        """
        # Check stop strings and stop regex patterns together
        if not (
            len(self.sampling_params.stop_strs) > 0
            or len(self.sampling_params.stop_regex_strs) > 0
        ):
            # Without this guard `max_len_tail_str` would be unbound below.
            return ""

        max_len_tail_str = max(
            self.sampling_params.stop_str_max_len + 1,
            self.sampling_params.stop_regex_max_len + 1,
        )
        tail_len = min((max_len_tail_str + 1), len(self.output_ids))
        return self.tokenizer.decode(self.output_ids[-tail_len:])

    def check_match_stop_str_prefix(self) -> bool:
        """
        Check if the suffix of tail_str overlaps with any stop_str prefix
        """
        if not self.sampling_params.stop_strs:
            return False

        tail_str = self.tail_str()

        # Early return if tail_str is empty
        if not tail_str:
            return False

        for stop_str in self.sampling_params.stop_strs:
            # Skip empty stop strings.
            if not stop_str:
                continue
            # Check if stop_str is contained in tail_str (fastest check first)
            if stop_str in tail_str:
                return True

            # Check if a tail_str suffix matches a stop_str prefix; used for stream output
            min_len = min(len(tail_str), len(stop_str))
            for i in range(1, min_len + 1):
                if tail_str[-i:] == stop_str[:i]:
                    return True

        return False

    def _check_token_based_finish(self, new_accepted_tokens: List[int]) -> bool:
        """Finish on the first stop / EOS token among ``new_accepted_tokens``.

        Sets ``finished_len`` so the matched token is included. Disabled by ``ignore_eos``.
        """
        if self.sampling_params.ignore_eos:
            return False

        # Check stop token ids
        matched_eos = False

        for i, token_id in enumerate(new_accepted_tokens):
            if self.sampling_params.stop_token_ids:
                matched_eos |= token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= token_id in self.tokenizer.additional_stop_token_ids
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=token_id)
                matched_pos = len(self.output_ids) - len(new_accepted_tokens) + i
                self.finished_len = matched_pos + 1
                return True

        return False

    def _check_str_based_finish(self):
        """Finish if a stop string or stop regex matches the decoded output."""
        if (
            len(self.sampling_params.stop_strs) > 0
            or len(self.sampling_params.stop_regex_strs) > 0
        ):
            tail_str = self.tail_str()

            # Check stop strings
            if len(self.sampling_params.stop_strs) > 0:
                for stop_str in self.sampling_params.stop_strs:
                    if stop_str in tail_str or stop_str in self.decoded_text:
                        self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                        return True

            # Check stop regex
            if len(self.sampling_params.stop_regex_strs) > 0:
                for stop_regex_str in self.sampling_params.stop_regex_strs:
                    if re.search(stop_regex_str, tail_str):
                        self.finished_reason = FINISHED_MATCHED_REGEX(
                            matched=stop_regex_str
                        )
                        return True

        return False

    def _check_vocab_boundary_finish(self, new_accepted_tokens: List[int] = None):
        """Finish if a new token id lies outside ``[0, vocab_size)``.

        Such ids are reported as "NaN happened". The offending output token is
        overwritten with a stop / EOS token id when one is configured
        (presumably so later detokenization sees a valid id). Nothing is checked
        when there are no tokens or ``vocab_size`` is unknown.
        """
        if not new_accepted_tokens or self.vocab_size is None:
            return False

        for i, token_id in enumerate(new_accepted_tokens):
            # Valid ids are 0 .. vocab_size - 1, so vocab_size itself is out of range.
            if token_id >= self.vocab_size or token_id < 0:
                offset = len(self.output_ids) - len(new_accepted_tokens) + i
                if self.sampling_params.stop_token_ids:
                    self.output_ids[offset] = next(
                        iter(self.sampling_params.stop_token_ids)
                    )
                if self.eos_token_ids:
                    self.output_ids[offset] = next(iter(self.eos_token_ids))
                self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
                self.finished_len = offset + 1
                return True

        return False

    def check_finished(self, new_accepted_len: int = 1):
        """Set ``finished_reason`` if any stop condition holds after the latest tokens.

        Checks, in order: pending abort, max_new_tokens, grammar termination,
        stop/EOS tokens, out-of-vocab ids, and stop strings / regexes.
        """
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            self.finished_len = self.sampling_params.max_new_tokens
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        new_accepted_tokens = self.output_ids[-new_accepted_len:]

        if self._check_token_based_finish(new_accepted_tokens):
            return

        if self._check_vocab_boundary_finish(new_accepted_tokens):
            return

        if self._check_str_based_finish():
            return

    def reset_for_retract(self):
        """Clear prefix-match, extend, and chunking state so the request can be rescheduled."""
        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
        self.last_node = None
        self.swa_uuid_for_lock = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.mamba_pool_idx = None
        self.already_computed = 0

    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        """Save a CPU copy of this request's KV cache (all tokens but the last) in ``kv_cache_cpu``."""
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        """Restore the CPU copy saved by offload_kv_cache, then drop it."""
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu

    def log_time_stats(self):
        """Log this request's time stats once; later calls are no-ops."""
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
        logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        """Mark the request finished with a 400 BadRequestError and strip its heavy inputs."""
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.return_logprob = False
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )
@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
"""Store all information of a batch on the scheduler."""
# Request, memory pool, and cache
reqs: List[Req]
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
tree_cache: BasePrefixCache = None
is_hybrid: bool = False
# Batch configs
model_config: ModelConfig = None
forward_mode: ForwardMode = None
enable_overlap: bool = False
# Tell whether the current running batch is full so that we can skip
# the check of whether to prefill new requests.
# This is an optimization to reduce the overhead of the prefill check.
batch_is_full: bool = False
# For chunked prefill in PP
chunked_req: Optional[Req] = None
# Sampling info
sampling_info: SamplingBatchInfo = None
# Batched arguments to model runner
input_ids: torch.Tensor = None # shape: [b], int64
input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32
token_type_ids: torch.Tensor = None # shape: [b], int64
req_pool_indices: torch.Tensor = None # shape: [b], int64
seq_lens: torch.Tensor = None # shape: [b], int64
seq_lens_cpu: torch.Tensor = None # shape: [b], int64
# The output locations of the KV cache
out_cache_loc: torch.Tensor = None # shape: [b], int64
output_ids: torch.Tensor = None # shape: [b], int64
# For multimodal inputs
multimodal_inputs: Optional[List] = None
# The sum of all sequence lengths
seq_lens_sum: int = None
# The original sequence lengths, Qwen-1M related
orig_seq_lens: torch.Tensor = None # shape: [b], int32
# For DP attention
global_num_tokens: Optional[List[int]] = None
global_num_tokens_for_logprob: Optional[List[int]] = None
is_extend_in_batch: bool = False
can_run_dp_cuda_graph: bool = False
tbo_split_seq_index: Optional[int] = None
global_forward_mode: Optional[ForwardMode] = None
# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
token_ids_logprobs: Optional[List[List[int]]] = None
# For logits and logprob post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
# For extend and mixed chunekd prefill
prefix_lens: List[int] = None
extend_lens: List[int] = None
extend_num_tokens: Optional[int] = None
decoding_reqs: List[Req] = None
extend_logprob_start_lens: List[int] = None
# It comes empty list if logprob is not required.
extend_input_logprob_token_ids: Optional[torch.Tensor] = None
# For encoder-decoder architectures
encoder_cached: Optional[List[bool]] = None
encoder_lens: Optional[torch.Tensor] = None
encoder_lens_cpu: Optional[List[int]] = None
encoder_out_cache_loc: Optional[torch.Tensor] = None
# Stream
has_stream: bool = False
# Has grammar
has_grammar: bool = False
# Device
device: str = "cuda"
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
# spec_info: Optional[SpecInput] = None
spec_info: Optional[SpecInput] = None
# Whether to return hidden states
return_hidden_states: bool = False
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False
# hicache pointer for synchronizing data loading from CPU to GPU
hicache_consumer_index: int = -1
blend_info_list: Optional[BatchBlendInfo] = None
@classmethod
def init_new(
    cls,
    reqs: List[Req],
    req_to_token_pool: ReqToTokenPool,
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
    tree_cache: BasePrefixCache,
    model_config: ModelConfig,
    enable_overlap: bool,
    spec_algorithm: SpeculativeAlgorithm,
    chunked_req: Optional[Req] = None,
):
    """Create a batch over ``reqs``.

    Batch-wide flags are aggregated from the requests. ``return_logprob``,
    ``has_stream``, ``has_grammar`` and ``return_hidden_states`` are set if
    any request asks for them. ``is_prefill_only`` is set only if all
    requests are prefill-only.

    When the allocator is a ``SWATokenToKVPoolAllocator``, the batch is
    flagged ``is_hybrid``. In that case the tree cache must be ``None`` or an
    SWA-aware cache (``SWARadixCache`` / ``SWAChunkCache``).
    """
    wants_logprob = any(r.return_logprob for r in reqs)

    hybrid = isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
    if hybrid:
        assert tree_cache is None or isinstance(
            tree_cache, (SWARadixCache, SWAChunkCache)
        ), (
            "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
        )

    return cls(
        reqs=reqs,
        req_to_token_pool=req_to_token_pool,
        token_to_kv_pool_allocator=token_to_kv_pool_allocator,
        tree_cache=tree_cache,
        is_hybrid=hybrid,
        model_config=model_config,
        enable_overlap=enable_overlap,
        return_logprob=wants_logprob,
        has_stream=any(r.stream for r in reqs),
        has_grammar=any(r.grammar for r in reqs),
        device=req_to_token_pool.device,
        spec_algorithm=spec_algorithm,
        return_hidden_states=any(r.return_hidden_states for r in reqs),
        is_prefill_only=all(r.is_prefill_only for r in reqs),
        chunked_req=chunked_req,
    )
def batch_size(self):
    """Return how many requests the batch currently holds."""
    num_reqs = len(self.reqs)
    return num_reqs
def is_empty(self):
    """Return True when the batch holds no requests."""
    return not len(self.reqs)
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
    """Separate encoder (image) tokens from decoder tokens in an extend batch.

    Used for encoder-decoder models. For each request whose
    ``multimodal_inputs.num_image_tokens`` is set, the first
    ``num_image_tokens`` positions of its extend input belong to the encoder.

    This method:
      * fills ``encoder_lens_cpu`` / ``encoder_lens`` with per-request encoder
        lengths. It also fills ``encoder_cached``, which is True when there is
        no image, in decode mode, or when the cached prefix already covers the
        encoder tokens;
      * strips encoder tokens from ``input_ids`` and ``seq_lens`` (both are
        mutated in place) and from ``extend_lens`` / ``extend_num_tokens`` /
        ``prefix_lens``;
      * splits ``out_cache_loc`` into decoder slots (kept in ``out_cache_loc``)
        and encoder slots (``encoder_out_cache_loc``).

    Args:
        input_ids: Per-request extend token ids. This is a list of lists,
            despite the annotation. Mutated in place.
        seq_lens: Per-request sequence lengths. Mutated in place.

    Raises:
        AssertionError: if the remaining decoder slots do not match
            ``extend_num_tokens``, or a request has a partial prefix over its
            encoder tokens.
    """
    self.encoder_lens_cpu = []
    self.encoder_cached = []
    for req in self.reqs:
        im = req.multimodal_inputs
        if im is None or im.num_image_tokens is None:
            # No image input
            self.encoder_lens_cpu.append(0)
            self.encoder_cached.append(True)
        else:
            self.encoder_lens_cpu.append(im.num_image_tokens)
            self.encoder_cached.append(
                self.forward_mode.is_decode()
                or len(req.prefix_indices) >= im.num_image_tokens
            )

    self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
        self.device, non_blocking=True
    )

    # Strip encoder infos. `pt` walks through out_cache_loc request by request.
    pt = 0
    decoder_out_cache_loc = []
    encoder_out_cache_loc = []
    for i, req in enumerate(self.reqs):
        encoder_len = self.encoder_lens_cpu[i]
        seq_lens[i] -= encoder_len

        if len(req.prefix_indices) < encoder_len:
            # NOTE: the encoder part should be considered as a whole
            assert len(req.prefix_indices) == 0
            input_ids[i] = input_ids[i][encoder_len:]
            encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
            decoder_out_cache_loc.append(
                self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
            )
            self.extend_lens[i] -= encoder_len
            self.extend_num_tokens -= encoder_len
        else:
            # The encoder tokens sit entirely inside the cached prefix.
            decoder_out_cache_loc.append(
                self.out_cache_loc[pt : pt + req.extend_input_len]
            )
            self.prefix_lens[i] -= encoder_len

        pt += req.extend_input_len

    # Reassign. chain.from_iterable flattens in linear time, whereas
    # sum(lists, []) re-copies the accumulator on every step (quadratic).
    self.input_ids = torch.tensor(
        list(chain.from_iterable(input_ids)), dtype=torch.int64
    ).to(self.device, non_blocking=True)
    self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
        self.device, non_blocking=True
    )
    self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)

    if decoder_out_cache_loc:
        self.out_cache_loc = torch.cat(decoder_out_cache_loc)
    else:
        self.out_cache_loc = torch.zeros(0, dtype=torch.int64, device=self.device)

    if encoder_out_cache_loc:
        self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
    else:
        self.encoder_out_cache_loc = torch.zeros(
            0, dtype=torch.int64, device=self.device
        )

    assert len(self.out_cache_loc) == self.extend_num_tokens, (
        f"Expected {self.extend_num_tokens}, got {len(self.out_cache_loc)}"
    )
# NOTE: KV blending is currently supported only in extend mode (prepare_for_extend).
def prepare_for_extend(self):
    """Turn the current requests into an EXTEND (prefill) batch.

    Steps, in order:
      1. Compute per-request extend inputs (the part of ``fill_ids`` past the
         cached ``prefix_indices``), sequence lengths and prefix lengths, and
         build the corresponding tensors on ``self.device``.
      2. If ``reqs[0].blend_style`` is set, build ``self.blend_info_list``
         (a ``BatchBlendInfo``). When SSD cache paths are given, also configure
         ``KVSSDManager`` and start its loader / prefetch tasks. Only
         ``reqs[0]``'s blend settings are read, so presumably every request
         in the batch shares them (TODO: confirm with the scheduler). The
         batch is assumed non-empty because ``reqs[0]`` is accessed.
      3. Allocate KV-cache slots via ``alloc_for_extend`` and fill
         per-request bookkeeping (pool index, cached-token counters,
         logprob start offsets and input-logprob token ids).
      4. Publish the batch-level tensors and rebuild ``self.sampling_info``.

    Raises:
        ValueError: for an unsupported ``context_cache_source``; for a
            ContextBlend KVCOMPUTE batch whose chunks are not in the
            augmented layout; or for a QCOMPUTE ContextBlend request with SSD
            paths and ``context_cache_source='query'`` but no
            ``critical_layers``.
    """
    self.forward_mode = ForwardMode.EXTEND
    # Init tensors
    reqs = self.reqs
    # Tokens computed this step: everything in fill_ids after the cached prefix.
    input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
    extend_num_tokens = sum(len(ids) for ids in input_ids)
    seq_lens = [len(r.fill_ids) for r in reqs]
    orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
    prefix_lens = [len(r.prefix_indices) for r in reqs]
    extend_lens = [r.extend_input_len for r in reqs]
    # Only requests that actually carry token_type_ids contribute here.
    token_type_ids = [
        r.token_type_ids for r in reqs if r.token_type_ids is not None
    ]
    input_ids_tensor = torch.tensor(
        list(chain.from_iterable(input_ids)), dtype=torch.int64
    ).to(self.device, non_blocking=True)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
        self.device, non_blocking=True
    )
    # Same lengths, but without the .to(device) copy.
    seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
    orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
        self.device, non_blocking=True
    )
    token_type_ids_tensor = None
    if len(token_type_ids) > 0:
        # NOTE(review): sum(lists, []) is quadratic in the number of lists;
        # list(chain.from_iterable(...)) would be linear.
        token_type_ids_tensor = torch.tensor(
            sum(token_type_ids, []), dtype=torch.int64
        ).to(self.device, non_blocking=True)
    # Set batch fields needed by alloc_for_extend
    self.prefix_lens = prefix_lens
    self.extend_lens = extend_lens
    self.seq_lens = seq_lens_tensor
    self.seq_lens_cpu = seq_lens_cpu
    self.extend_num_tokens = extend_num_tokens

    def _build_att_params():
        # Attention shape parameters taken from the model config. head_dim
        # falls back to hidden_size // num_attention_heads when it is unset.
        return AttParams(
            num_heads=self.model_config.num_attention_heads,
            head_dim=self.model_config.head_dim
            if self.model_config.head_dim is not None
            else self.model_config.hidden_size
            // self.model_config.num_attention_heads,
            num_kv_heads=self.model_config.get_num_kv_heads(
                get_tensor_model_parallel_world_size()
            ),
            num_layers=self.model_config.num_hidden_layers,
        )

    # ---- KV-blend setup: driven entirely by the first request's settings ----
    if reqs[0].blend_style is None:
        self.blend_info_list = None
    else:
        first_style = reqs[0].blend_style
        has_ssd_paths = (
            getattr(reqs[0], "ssd_cache_path_chunk", None) is not None
            or getattr(reqs[0], "ssd_cache_path_query", None) is not None
        )
        attn_start = reqs[0].attn_start
        attn_end = reqs[0].attn_end
        # Per-request chunk boundaries (copied). torch.diff below turns them
        # into chunk lengths.
        raw_blend_locs = [list(r.blend_loc_list) for r in reqs]
        forward_blend_locs = raw_blend_locs
        # Digest (augmented-layout) metadata. These stay None unless
        # is_query_offline.
        digest_original_locs = None
        digest_keep_indices = None
        digest_aug_sys_range = None
        digest_aug_doc_ranges = None
        digest_aug_zip_ranges = None
        # ContextBlend KVCOMPUTE with a query SSD path: chunks are expected in
        # the augmented layout and are remapped through DigestIndexManager.
        is_query_offline = (
            first_style == BlendStyle.KVCOMPUTE
            and reqs[0].is_contextblend
            and has_ssd_paths
            and getattr(reqs[0], "ssd_cache_path_query", None) is not None
        )
        if is_query_offline:
            transformed = []
            # NOTE(review): `req` is unused in this loop.
            for req, raw_locs in zip(reqs, raw_blend_locs):
                transformed.append(
                    DigestIndexManager.prepare_augmented_locs_for_request(
                        raw_locs
                    )
                )
            if not transformed or any(x is None for x in transformed):
                raise ValueError(
                    "ContextBlend query KVCOMPUTE expects augmented chunks "
                    "with layout sys, doc, zipprompt, ..., query."
                )
            forward_blend_locs = [x["forward_locs"] for x in transformed]
            original_lens = []
            keep_indices = []
            aug_doc_ranges = []
            aug_zip_ranges = []
            # seq_offset is the sum of the preceding requests'
            # len(origin_input_ids). It shifts per-request indices/ranges into
            # batch-global token positions.
            seq_offset = 0
            for item, req in zip(transformed, reqs):
                original_req_locs = item["original_locs"]
                original_lens.extend(
                    original_req_locs[i + 1] - original_req_locs[i]
                    for i in range(len(original_req_locs) - 1)
                )
                keep_indices.extend(
                    seq_offset + int(idx) for idx in item["keep_indices"]
                )
                # Only the first request's sys range is recorded.
                if digest_aug_sys_range is None:
                    start, end = item["aug_sys_range"]
                    digest_aug_sys_range = (
                        seq_offset + int(start),
                        seq_offset + int(end),
                    )
                aug_doc_ranges.extend(
                    (seq_offset + int(start), seq_offset + int(end))
                    for start, end in item["aug_doc_ranges"]
                )
                aug_zip_ranges.extend(
                    (seq_offset + int(start), seq_offset + int(end))
                    for start, end in item["aug_zip_ranges"]
                )
                seq_offset += len(req.origin_input_ids)
            digest_original_locs = DigestIndexManager._cumsum_lens(
                original_lens
            )
            digest_keep_indices = keep_indices
            digest_aug_doc_ranges = aug_doc_ranges
            digest_aug_zip_ranges = aug_zip_ranges
        # Flatten every request's chunks into one batch-wide description:
        #   chunk_lens     - length of each chunk, request by request
        #   chunk_loc_list - 0 followed by the running sum of chunk_lens, so
        #                    chunk k spans [chunk_loc_list[k], chunk_loc_list[k+1])
        #   req_len_list   - running count of chunks per request, i.e. the
        #                    chunk index at which each request ends
        blend_locs = [
            torch.tensor(locs, dtype=torch.int64, device=self.device)
            for locs in forward_blend_locs
        ]
        chunk_lens = torch.cat([torch.diff(loc) for loc in blend_locs])
        chunk_loc_list = F.pad(torch.cumsum(chunk_lens, 0), (1, 0))
        req_len_list = torch.cumsum(
            torch.tensor(
                [len(loc) - 1 for loc in blend_locs],
                dtype=torch.int32,
                device=self.device,
            ),
            0,
        )
        self.blend_info_list = BatchBlendInfo()
        self.blend_info_list.blend_style = first_style
        self.blend_info_list.chunk_lens = chunk_lens
        self.blend_info_list.chunk_loc_list = chunk_loc_list
        self.blend_info_list.req_len_list = req_len_list
        self.blend_info_list.attn_start = attn_start
        # attn_end == -1 means "through the last layer".
        self.blend_info_list.attn_end = (
            attn_end if attn_end != -1 else self.model_config.num_hidden_layers
        )
        self.blend_info_list.is_contextblend = reqs[0].is_contextblend
        # ContextBlend defaults to "query". Non-ContextBlend always uses "none".
        context_cache_source = (
            (getattr(reqs[0], "context_cache_source", None) or "query")
            if reqs[0].is_contextblend
            else "none"
        )
        if context_cache_source not in ("query", "none"):
            raise ValueError(
                f"Unsupported context_cache_source={context_cache_source!r}"
            )
        self.blend_info_list.context_cache_source = context_cache_source
        # Fallback values when the request leaves these unset: 4 sink tokens,
        # digest ratio 0.3, "kvzip" index method.
        self.blend_info_list.context_n_sink = (
            reqs[0].context_n_sink if reqs[0].context_n_sink is not None else 4
        )
        self.blend_info_list.digest_ratio = (
            reqs[0].digest_ratio if reqs[0].digest_ratio is not None else 0.3
        )
        self.blend_info_list.digest_index_method = (
            reqs[0].digest_index_method or "kvzip"
        )
        critical_layers = [
            int(x) for x in (getattr(reqs[0], "critical_layers", None) or [])
        ]
        self.blend_info_list.critical_layers = critical_layers
        self.blend_info_list.critical_layers_set = set(critical_layers)
        # One past the highest critical layer, or None without critical layers.
        self.blend_info_list.qcompute_end = (
            max(critical_layers) + 1 if critical_layers else None
        )
        if digest_original_locs is not None:
            self.blend_info_list.digest_original_chunk_loc_list = torch.tensor(
                digest_original_locs, dtype=torch.int64, device=self.device
            )
            self.blend_info_list.digest_keep_indices = digest_keep_indices
            self.blend_info_list.digest_aug_sys_range = digest_aug_sys_range
            self.blend_info_list.digest_aug_doc_ranges = digest_aug_doc_ranges
            self.blend_info_list.digest_aug_zip_ranges = digest_aug_zip_ranges
        # SSD pipeline: init buffers BEFORE setting q_lens (init_buffers clears q_lens)
        if has_ssd_paths and first_style == BlendStyle.QCOMPUTE:
            num_layers = self.model_config.num_hidden_layers
            HackBlendKVPool.init_buffers(num_layers)
            ContextBlendPool.init_buffers(num_layers)
            KVSSDManager.configure(
                offline=False,
                online=True,
                sample_dir_chunk=reqs[0].ssd_cache_path_chunk,
                sample_dir_query=reqs[0].ssd_cache_path_query,
                num_layers=num_layers,
                device=str(self.device),
            )
        if first_style in [BlendStyle.QCOMPUTE, BlendStyle.KVCOMPUTE]:
            if first_style == BlendStyle.KVCOMPUTE:
                self.blend_info_list.att_params = _build_att_params()
            if first_style == BlendStyle.QCOMPUTE:
                # For each request, look at the chunk right after its first
                # chunk (chunk index req_start + 1):
                #   lengths       - token count of that chunk
                #   q_offsets     - its start, relative to the request start
                #   query_lengths - tokens from the request start through its end
                # Assumes every request has at least two chunks (TODO confirm).
                req_boundaries = torch.cat(
                    [
                        torch.tensor([0], dtype=torch.long, device=self.device),
                        req_len_list.to(dtype=torch.long),
                    ]
                )
                second_idxs = req_boundaries[:-1] + 1
                req_start_idxs = req_boundaries[:-1]
                req_starts = chunk_loc_list[req_start_idxs]
                starts = chunk_loc_list[second_idxs]
                ends = chunk_loc_list[second_idxs + 1]
                lengths = ends - starts
                q_offsets = starts - req_starts
                query_lengths = q_offsets + lengths
                # Published on HackBlendKVPool (after init_buffers above).
                HackBlendKVPool.q_lens = lengths.tolist()
                HackBlendKVPool.q_offsets = q_offsets.tolist()
                HackBlendKVPool.query_k_lens = query_lengths.tolist()
                if lengths.numel() > 0:
                    # Flat batch positions of every token in those chunks.
                    max_len = lengths.max().item()
                    range_row = torch.arange(max_len, device=self.device)
                    mask = range_row < lengths.unsqueeze(1)
                    raw_indices = starts.unsqueeze(1) + range_row
                    self.blend_info_list.quest_indices = raw_indices[mask]
                if query_lengths.numel() > 0:
                    # Flat batch positions from each request start through the
                    # end of its second chunk.
                    max_query_len = query_lengths.max().item()
                    query_range_row = torch.arange(max_query_len, device=self.device)
                    query_mask = query_range_row < query_lengths.unsqueeze(1)
                    raw_query_indices = req_starts.unsqueeze(1) + query_range_row
                    self.blend_info_list.query_indices = raw_query_indices[query_mask]
            else:
                # KVCOMPUTE
                self.blend_info_list.init_attmeta = True
        else:
            # DO_BLEND
            self.blend_info_list.select_mode = SelectMode(reqs[0].method)
            self.blend_info_list.ratio = reqs[0].ratio
            self.blend_info_list.start = reqs[0].start
            self.blend_info_list.att_params = _build_att_params()
            # Derive keep_layers_set: layers already loaded before DO_BLEND that should
            # be preserved across ratio rounds (not cleared between DO_BLEND rounds).
            # Only ATTN mode has preloaded selection layers.
            num_layers = self.model_config.num_hidden_layers
            select_mode = self.blend_info_list.select_mode
            if select_mode == SelectMode.ATTN:
                critical_layers = self.blend_info_list.critical_layers or []
                attn_start_val = self.blend_info_list.attn_start
                attn_end_val = self.blend_info_list.attn_end
                if critical_layers:
                    self.blend_info_list.keep_layers_set = set(
                        int(x) for x in critical_layers
                    )
                elif attn_start_val == 0 and attn_end_val == num_layers:
                    self.blend_info_list.keep_layers_set = set(range(num_layers))
                else:
                    self.blend_info_list.keep_layers_set = {attn_start_val}
            else:
                self.blend_info_list.keep_layers_set = set()
        # SSD pipeline auto-configuration
        if has_ssd_paths:
            num_layers = self.model_config.num_hidden_layers
            attn_start_val = self.blend_info_list.attn_start
            attn_end_val = self.blend_info_list.attn_end
            if first_style == BlendStyle.KVCOMPUTE:
                # Offline: save KV to SSD
                KVSSDManager.configure(
                    offline=True,
                    online=False,
                    sample_dir_chunk=reqs[0].ssd_cache_path_chunk,
                    sample_dir_query=reqs[0].ssd_cache_path_query,
                    num_layers=num_layers,
                )
            elif first_style == BlendStyle.QCOMPUTE:
                # SSD init already done above (before q_lens setup)
                critical_layers = self.blend_info_list.critical_layers or []
                qcompute_end = int(
                    self.blend_info_list.qcompute_end or attn_end_val
                )
                context_source = self.blend_info_list.context_cache_source

                def _selection_chunk_layers():
                    # Layers whose chunk KV is loaded for selection:
                    # critical_layers if given, every layer when the attention
                    # window covers the whole model, otherwise just attn_start.
                    if critical_layers:
                        return [int(x) for x in critical_layers]
                    if attn_start_val == 0 and attn_end_val == num_layers:
                        return list(range(num_layers))
                    return [attn_start_val]

                def _start_selection_chunk_loader():
                    # Start KVSSDManager's chunk-loading task for those layers
                    # (skipped when the list is empty).
                    task_b_layers = _selection_chunk_layers()
                    if task_b_layers:
                        task_b = KVSSDManager.start_task_b(task_b_layers)
                        task_b.start()

                if reqs[0].is_contextblend:
                    if context_source == "query":
                        # Critical-layer KV comes from the query cache on SSD.
                        if not critical_layers:
                            raise ValueError(
                                "context_cache_source='query' requires "
                                "critical_layers so critical KV can be loaded "
                                "from query_cache"
                            )
                        query_path = reqs[0].ssd_cache_path_query
                        query_meta = KVSSDManager.restore_query_metadata(
                            query_path,
                            digest_index_method=self.blend_info_list.digest_index_method,
                            digest_ratio=self.blend_info_list.digest_ratio,
                        )
                        query_task = KVSSDManager.start_query_cache(
                            qcompute_end, critical_layers, query_meta=query_meta
                        )
                        query_task.start()
                    else:
                        _start_selection_chunk_loader()
                else:
                    _start_selection_chunk_loader()
            # DO_BLEND: start layer prefetch thread
            # if first_style in (BlendStyle.DO_BLEND, BlendStyle.DO_BLEND_FINISH):
            else:
                # First DO_BLEND without prior QCOMPUTE needs full online initialization.
                if not KVSSDManager.is_online():
                    HackBlendKVPool.init_buffers(num_layers)
                    KVSSDManager.configure(
                        offline=False,
                        online=True,
                        sample_dir_chunk=reqs[0].ssd_cache_path_chunk,
                        sample_dir_query=reqs[0].ssd_cache_path_query,
                        num_layers=num_layers,
                        device=str(self.device),
                    )
                # Clear stale layer events from QCOMPUTE before DO_BLEND prefetch
                KVSSDManager.reset_layer_events()
                select_mode = self.blend_info_list.select_mode
                start_layer = self.blend_info_list.start
                # Prefetch begins one layer after `start` when ratio <= 0 or in
                # ATTN mode; otherwise it begins at `start` itself.
                if self.blend_info_list.ratio <= 0 or select_mode == SelectMode.ATTN:
                    prefetch_from = start_layer + 1
                else:
                    prefetch_from = start_layer
                task_blend = KVSSDManager.start_do_blend_prefetch(prefetch_from, num_layers)
                task_blend.start()
    # Allocate memory
    out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
        self
    )
    # Set fields
    input_embeds = []
    extend_input_logprob_token_ids = []
    multimodal_inputs = []
    for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
        req.req_pool_idx = req_pool_indices[i]
        assert seq_len - pre_len == req.extend_input_len
        # If input_embeds are available, store them
        if req.input_embeds is not None:
            # If req.input_embeds is already a list, append its content directly
            input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
        multimodal_inputs.append(req.multimodal_inputs)
        # Count prefix tokens newly served from cache since the last
        # computation, then record that the whole sequence is now computed.
        req.cached_tokens += pre_len - req.already_computed
        req.already_computed = seq_len
        req.is_retracted = False
        # Compute the relative logprob_start_len in an extend batch
        #
        # Key variables:
        # - logprob_start_len: Absolute position in full sequence where logprob computation begins
        # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
        # - extend_input_len: Number of tokens that need to be processed in this extend batch
        #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
        #    and prefix_indices are the cached/shared prefix tokens)
        #
        if req.logprob_start_len >= pre_len:
            # Optimization for prefill-only requests: When we only need logprobs at
            # positions beyond the input sequence (to score next-token likelihood), skip all
            # input logprob computation during prefill since no generation will occur.
            if self.is_prefill_only and req.logprob_start_len == len(
                req.origin_input_ids
            ):
                # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
                req.extend_logprob_start_len = req.extend_input_len
            else:
                # Convert absolute logprob_start_len to relative extend_logprob_start_len
                #
                # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
                # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
                # This means: "compute logprobs from position 3 onwards in extend batch"
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
        else:
            # logprob_start_len is before the current extend batch, so start from beginning
            req.extend_logprob_start_len = 0
        if self.return_logprob:
            # Find input logprob token ids.
            # First, find a global index within origin_input_ids and slide it by 1
            # to compute input logprobs. It is because you need the next token
            # to compute input logprobs. E.g., (chunk size 2)
            #
            # input_logprobs = [1, 2, 3, 4]
            # fill_ids = [1, 2]
            # extend_input_logprob_token_id = [2, 3]
            #
            # Note that it can also overflow. In this case, we pad it with 0.
            # input_logprobs = [1, 2, 3, 4]
            # fill_ids = [3, 4]
            # extend_input_logprob_token_id = [4, 0]
            global_start_idx, global_end_idx = (
                len(req.prefix_indices),
                len(req.fill_ids),
            )
            # Apply logprob_start_len
            if global_start_idx < req.logprob_start_len:
                global_start_idx = req.logprob_start_len
            logprob_token_ids = req.origin_input_ids[
                global_start_idx + 1 : global_end_idx + 1
            ]
            extend_input_logprob_token_ids.extend(logprob_token_ids)
            # We will need req.extend_input_len - req.extend_logprob_start_len number of
            # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
            extend_input_logprob_token_ids.extend(
                [0]
                * (
                    req.extend_input_len
                    - req.extend_logprob_start_len
                    - len(logprob_token_ids)
                )
            )
    if self.return_logprob:
        extend_input_logprob_token_ids = torch.tensor(
            extend_input_logprob_token_ids
        )
    else:
        extend_input_logprob_token_ids = None
    self.input_ids = input_ids_tensor
    self.req_pool_indices = req_pool_indices_tensor
    self.orig_seq_lens = orig_seq_lens_tensor
    self.out_cache_loc = out_cache_loc
    self.input_embeds = (
        torch.tensor(input_embeds).to(self.device, non_blocking=True)
        if input_embeds
        else None
    )
    # Move any tensor-valued multimodal features onto the batch device.
    for mm_input in multimodal_inputs:
        if mm_input is None:
            continue
        for mm_item in mm_input.mm_items:
            pixel_values = getattr(mm_item, "feature", None)
            if isinstance(pixel_values, torch.Tensor):
                mm_item.feature = pixel_values.to(self.device, non_blocking=True)
    self.multimodal_inputs = multimodal_inputs
    self.token_type_ids = token_type_ids_tensor
    self.seq_lens_sum = sum(seq_lens)
    if self.return_logprob:
        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
    self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
    self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
    if self.model_config.is_encoder_decoder:
        # Strips encoder tokens from input_ids / seq_lens (mutated in place).
        self.prepare_encoder_info_extend(input_ids, seq_lens)
    # Build sampling info
    self.sampling_info = SamplingBatchInfo.from_schedule_batch(
        self,
        self.model_config.vocab_size,
    )
def prepare_for_split_prefill(self):
    """Prepare the batch exactly like an extend batch, but tag it SPLIT_PREFILL."""
    self.prepare_for_extend()
    # Only the forward mode differs from a regular extend batch.
    self.forward_mode = ForwardMode.SPLIT_PREFILL
def mix_with_running(self, running_batch: "ScheduleBatch"):
    """Fold a running decode batch into this extend batch (mixed chunked prefill).

    The batch becomes MIXED. Every running request contributes exactly one
    new token (``extend_input_len = 1``). Its inputs, KV slots and
    per-request bookkeeping are appended after this batch's own entries.
    """
    self.forward_mode = ForwardMode.MIXED
    num_running = running_batch.batch_size()

    for running_req in running_batch.reqs:
        running_req.fill_ids = running_req.origin_input_ids + running_req.output_ids
        running_req.extend_input_len = 1

    # Concatenate before merging: merge_batch() resets out_cache_loc to None.
    merged_input_ids = torch.cat([self.input_ids, running_batch.input_ids])
    merged_out_cache_loc = torch.cat(
        [self.out_cache_loc, running_batch.out_cache_loc]
    )
    self.merge_batch(running_batch)
    self.input_ids = merged_input_ids
    self.out_cache_loc = merged_out_cache_loc

    # For overlap scheduler, the output_ids has one step delay
    offset = 0 if self.enable_overlap else -1
    # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
    self.prefix_lens.extend(
        len(running_req.origin_input_ids) + len(running_req.output_ids) + offset
        for running_req in running_batch.reqs
    )
    self.extend_lens.extend([1] * num_running)
    self.extend_num_tokens += num_running

    # TODO (lianmin): Revisit this. It should be seq_len - 1
    self.extend_logprob_start_lens.extend([0] * num_running)
def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
    """Count the requests that need a new KV page on the next decode step.

    Args:
        selected_indices: Optional subset of request indices to consider.
            Defaults to every request in the batch.

    Returns:
        The number of considered requests that start a new page. With
        ``page_size == 1`` this is simply the number of considered requests.
    """
    page_size = self.token_to_kv_pool_allocator.page_size
    if selected_indices is None:
        candidates = self.reqs
    else:
        candidates = [self.reqs[i] for i in selected_indices]

    if page_size == 1:
        return len(candidates)

    # In the decoding phase, the length of a request's KV cache should be
    # the total length of the request minus 1 (no shift in overlap mode).
    shift = 0 if self.enable_overlap else 1
    return sum((req.seqlen - shift) % page_size == 0 for req in candidates)
def check_decode_mem(
    self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
):
    """Make room for the next decode step and report whether it fits.

    The token budget is the number of new pages needed (see
    ``new_page_count_next_decode``) times ``buf_multiplier`` times the page
    size. The tree cache is asked to evict that many tokens first. The
    result of ``self._is_available_size_sufficient`` for the budget is then
    returned.
    """
    allocator = self.token_to_kv_pool_allocator
    pages_needed = self.new_page_count_next_decode(selected_indices)
    num_tokens = pages_needed * buf_multiplier * allocator.page_size

    evict_from_tree_cache(self.tree_cache, num_tokens)
    return self._is_available_size_sufficient(num_tokens)
def retract_decode(self, server_args: ServerArgs):
    """Retract the decoding requests when there is not enough memory.

    Requests are released from the tail of an index list until
    ``check_decode_mem`` succeeds. At least one request is always retracted.
    Without speculative decoding, the list is first sorted so that the tail
    holds the requests with the fewest output tokens (ties: longest prompt).

    Returns:
        ``(retracted_reqs, new_estimate_ratio, [])``. ``new_estimate_ratio``
        estimates decoded-token usage over the remaining requests and is
        capped at 1.0.

    Raises:
        ValueError: if no request could be retracted.
    """
    keep = list(range(len(self.reqs)))

    # TODO(lsyin): improve retraction policy for radix cache
    # For spec decoding, filter_batch API can only filter
    # requests from the back, so we can only retract from the back.
    # TODO(sang): Clean up finish path and support better retract
    # policy.
    if not server_args.speculative_algorithm:
        keep.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

    retracted_reqs = []
    must_retract_one = True
    while must_retract_one or not self.check_decode_mem(selected_indices=keep):
        if len(keep) == 1:
            # Corner case: only one request left
            allocator = self.token_to_kv_pool_allocator
            if self.is_hybrid:
                full_available_size = allocator.full_available_size()
                swa_available_size = allocator.swa_available_size()
                assert full_available_size > 0 and swa_available_size > 0, (
                    f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
                )
            else:
                assert allocator.available_size() > 0, (
                    f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
                )
            break

        must_retract_one = False
        victim_idx = keep.pop()
        retracted_reqs.append(self.reqs[victim_idx])
        # release memory and don't insert into the tree because we need the space instantly
        self.release_req(victim_idx, len(keep), server_args)

    if not retracted_reqs:
        # Corner case: only one request left
        raise ValueError(
            "Failed to retract any request. No space left for only one request."
        )

    self.filter_batch(keep_indices=keep)

    # Reqs in batch are filtered
    total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
    total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
    projected_tokens = (
        total_decoded_tokens
        + envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
    )
    # The +1 in the denominator avoids division by zero.
    new_estimate_ratio = min(1.0, projected_tokens / (total_max_new_tokens + 1))

    return retracted_reqs, new_estimate_ratio, []
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
    """Free the resources held by ``self.reqs[idx]`` so it can be retracted.

    In "decode" disaggregation mode the request's KV cache is offloaded
    first. The request is then released from the tree cache without
    insertion. The tree cache is asked to evict
    ``remaing_req_count * SGLANG_RETRACT_DECODE_STEPS`` tokens, and finally
    the request is reset for retraction.
    """
    victim = self.reqs[idx]

    if server_args.disaggregation_mode == "decode":
        victim.offload_kv_cache(
            self.req_to_token_pool, self.token_to_kv_pool_allocator
        )

    # TODO (csy): for preempted requests, we may want to insert into the tree
    self.tree_cache.cache_finished_req(victim, is_insert=False)

    # NOTE(lsyin): we should use the newly evictable memory instantly.
    evict_from_tree_cache(
        self.tree_cache,
        remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get(),
    )

    victim.reset_for_retract()
def prepare_encoder_info_decode(self):
    """Mark the encoder output of every request as cached (decode step)."""
    self.encoder_cached = [True for _ in self.reqs]
def prepare_for_idle(self):
    """Reset this batch into an empty IDLE batch.

    All per-token / per-request tensors become zero-length. They are placed
    on ``self.device``, except ``seq_lens_cpu``, which is created without an
    explicit device. The counters are zeroed and sampling info is rebuilt.
    """
    self.forward_mode = ForwardMode.IDLE

    device = self.device
    i64, i32 = torch.int64, torch.int32
    self.input_ids = torch.empty(0, dtype=i64, device=device)
    self.seq_lens = torch.empty(0, dtype=i64, device=device)
    self.seq_lens_cpu = torch.empty(0, dtype=i64)
    self.orig_seq_lens = torch.empty(0, dtype=i32, device=device)
    self.out_cache_loc = torch.empty(0, dtype=i64, device=device)
    self.req_pool_indices = torch.empty(0, dtype=i32, device=device)

    self.seq_lens_sum = self.extend_num_tokens = 0

    self.sampling_info = SamplingBatchInfo.from_schedule_batch(
        self, self.model_config.vocab_size
    )
@property
def is_v2_eagle(self):
    """True when the overlap scheduler is combined with EAGLE speculation."""
    # FIXME: finally deprecate is_v2_eagle
    if not self.enable_overlap:
        return self.enable_overlap
    return self.spec_algorithm.is_eagle()
def prepare_for_decode(self):
    """Advance the batch by one DECODE step.

    With speculative decoding enabled, this only sets the mode (and, for
    EAGLE v2, delegates to the draft input) and returns early. Otherwise it:
      * feeds the previous step's tokens to the penalizers when required;
      * makes ``output_ids`` the new ``input_ids``;
      * allocates one KV slot per request;
      * bumps every sequence length by one.
    """
    self.forward_mode = ForwardMode.DECODE
    num_reqs = len(self.reqs)

    if self.is_v2_eagle:
        # TODO(spec-v2): all v2 spec should go through this path
        self.spec_info.prepare_for_decode(self)

    if not self.spec_algorithm.is_none():
        # With speculative decoding, the decode batch is built inside
        # `forward_batch_speculative_generation` after running draft models.
        return

    orchestrator = self.sampling_info.penalizer_orchestrator
    if orchestrator.is_required:
        if self.enable_overlap:
            # Overlap mode: rebuild the last token of every request from its
            # Python-side output_ids (falling back to the last prompt token).
            # TODO: this can be slow, optimize this.
            last_tokens = [
                req.output_ids[-1] if len(req.output_ids) else req.origin_input_ids[-1]
                for req in self.reqs
            ]
            orchestrator.cumulate_output_tokens(
                torch.tensor(last_tokens, dtype=torch.int64, device=self.device)
            )
        else:
            orchestrator.cumulate_output_tokens(self.output_ids.to(torch.int64))

    # Update fields: last step's sampled tokens become this step's inputs.
    self.input_ids = self.output_ids
    self.output_ids = None

    if self.model_config.is_encoder_decoder:
        self.prepare_encoder_info_decode()

    # Allocate memory
    self.out_cache_loc = alloc_for_decode(self, token_per_req=1)

    # Update seq_lens after allocation
    if self.enable_overlap:
        # Do not use in-place operations in the overlap mode
        self.seq_lens = self.seq_lens + 1
        self.seq_lens_cpu = self.seq_lens_cpu + 1
        self.orig_seq_lens = self.orig_seq_lens + 1
    else:
        # A faster in-place version
        for lengths in (self.seq_lens, self.seq_lens_cpu, self.orig_seq_lens):
            lengths.add_(1)
    self.seq_lens_sum += num_reqs
def maybe_wait_verify_done(self):
    """In EAGLE v2 mode, synchronize on the pending verify event, if any."""
    if not self.is_v2_eagle:
        return
    verify_event = self.spec_info.verify_done
    if verify_event is not None:
        verify_event.synchronize()
def filter_batch(
    self,
    chunked_req_to_exclude: "Optional[Union[Req, List[Req]]]" = None,
    keep_indices: "Optional[List[int]]" = None,
):
    """Keep only a subset of the requests in this batch, in place.

    Args:
        chunked_req_to_exclude: Request(s) to drop in addition to finished
            ones. Only used when ``keep_indices`` is None.
        keep_indices: Explicit indices of the requests to keep. When None,
            unfinished requests not listed in ``chunked_req_to_exclude``
            are kept.

    Every per-request field (tensors, lists, sampling info, spec info) is
    re-indexed to the kept requests. Keeping nothing empties ``reqs`` and
    returns without touching the other fields. Keeping everything is a no-op.
    """
    # FIXME(lsyin): used here to get the correct seq_lens
    # The batch has been launched but we need it verified to get correct next batch info
    self.maybe_wait_verify_done()

    if keep_indices is None:
        if isinstance(chunked_req_to_exclude, Req):
            chunked_req_to_exclude = [chunked_req_to_exclude]
        elif chunked_req_to_exclude is None:
            chunked_req_to_exclude = []
        keep_indices = [
            i
            for i in range(len(self.reqs))
            if not self.reqs[i].finished()
            and self.reqs[i] not in chunked_req_to_exclude
        ]

    if keep_indices is None or len(keep_indices) == 0:
        # Filter out all requests
        self.reqs = []
        return

    if len(keep_indices) == len(self.reqs):
        # No need to filter
        return

    keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
        self.device, non_blocking=True
    )

    if self.model_config.is_encoder_decoder:
        self.encoder_lens = self.encoder_lens[keep_indices_device]
        self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

    self.reqs = [self.reqs[i] for i in keep_indices]
    if self.multimodal_inputs is not None:
        self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
    self.req_pool_indices = self.req_pool_indices[keep_indices_device]
    self.seq_lens = self.seq_lens[keep_indices_device]
    self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
    self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
    self.out_cache_loc = None
    self.seq_lens_sum = self.seq_lens.sum().item()
    # output_ids may not have been produced yet. Indexing None would raise
    # TypeError, so guard it the same way merge_batch does.
    if self.output_ids is not None:
        self.output_ids = self.output_ids[keep_indices_device]
    self.return_logprob = any(req.return_logprob for req in self.reqs)
    if self.return_logprob:
        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
    else:
        self.top_logprobs_nums = None
        self.token_ids_logprobs = None

    self.has_stream = any(req.stream for req in self.reqs)
    self.has_grammar = any(req.grammar for req in self.reqs)

    self.sampling_info.filter_batch(keep_indices, keep_indices_device)
    if self.spec_info:
        # The spec info is told the batch was "already filtered" unless a
        # chunked request was explicitly excluded.
        if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
            has_been_filtered = False
        else:
            has_been_filtered = True
        self.spec_info.filter_batch(
            new_indices=keep_indices_device,
            has_been_filtered=has_been_filtered,
        )
def merge_batch(self, other: "ScheduleBatch"):
    """Append every request (and its per-request state) of ``other`` to this batch.

    ``out_cache_loc`` is reset to None. Logprob lists are padded with
    neutral values (0 / None) on whichever side did not request logprobs.
    """
    # NOTE: in v2 eagle mode, we do not need wait verify here because
    # 1) current batch is always prefill, whose seq_lens and allocate_lens are not a future
    # 2) other batch is always decode, which is finished in previous step

    # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
    # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
    # needs to be called with pre-merged Batch.reqs.
    self.sampling_info.merge_batch(other.sampling_info)

    # Encoder-decoder infos
    if self.model_config.is_encoder_decoder:
        self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
        self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

    for field in ("req_pool_indices", "seq_lens", "seq_lens_cpu", "orig_seq_lens"):
        setattr(self, field, torch.cat([getattr(self, field), getattr(other, field)]))

    self.out_cache_loc = None
    self.seq_lens_sum += other.seq_lens_sum
    if self.output_ids is not None:
        self.output_ids = torch.cat([self.output_ids, other.output_ids])

    if self.return_logprob:
        if other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        else:
            num_other = len(other.reqs)
            self.top_logprobs_nums.extend([0] * num_other)
            self.token_ids_logprobs.extend([None] * num_other)
    elif other.return_logprob:
        num_self = len(self.reqs)
        self.top_logprobs_nums = [0] * num_self + other.top_logprobs_nums
        self.token_ids_logprobs = [None] * num_self + other.token_ids_logprobs

    self.reqs.extend(other.reqs)
    if self.multimodal_inputs is not None:
        self.multimodal_inputs.extend(other.multimodal_inputs)

    for flag in ("return_logprob", "has_stream", "has_grammar", "return_hidden_states"):
        setattr(self, flag, getattr(self, flag) | getattr(other, flag))

    if self.spec_info:
        self.spec_info.merge_batch(other.spec_info)
def get_model_worker_batch(
    self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
) -> ModelWorkerBatch:
    """Package the data the model worker needs into a ``ModelWorkerBatch``.

    Tensors, lists and ``sampling_info`` are handed over by reference, not
    copied.

    Args:
        seq_lens_cpu_cache: optional replacement for ``seq_lens_cpu``; when
            it is None, ``self.seq_lens_cpu`` is used.

    Returns:
        A new ``ModelWorkerBatch``.

    Side effects:
        When ``self.sampling_info`` is set, its ``grammars`` attribute is
        overwritten.
    """
    # Decode and idle batches carry no extend metadata.
    if not self.forward_mode.is_decode_or_idle():
        extend_seq_lens = self.extend_lens
        extend_prefix_lens = self.prefix_lens
        extend_logprob_start_lens = self.extend_logprob_start_lens
    else:
        extend_seq_lens = None
        extend_prefix_lens = None
        extend_logprob_start_lens = None

    sampling_info = self.sampling_info
    if sampling_info:
        # Per-request grammar objects are attached only when the batch's
        # has_grammar flag is set; otherwise grammars is cleared.
        sampling_info.grammars = (
            [req.grammar for req in self.reqs] if self.has_grammar else None
        )

    seq_lens_cpu = (
        self.seq_lens_cpu if seq_lens_cpu_cache is None else seq_lens_cpu_cache
    )

    # An explicit request for hidden states wins. Otherwise defer to the
    # speculative-decoding info (if any), defaulting to NULL.
    if self.return_hidden_states:
        capture_hidden_mode = CaptureHiddenMode.FULL
    elif self.spec_info:
        capture_hidden_mode = getattr(
            self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
        )
    else:
        capture_hidden_mode = CaptureHiddenMode.NULL

    lora_ids = [req.lora_id for req in self.reqs]

    return ModelWorkerBatch(
        # Core forward inputs
        forward_mode=self.forward_mode,
        input_ids=self.input_ids,
        req_pool_indices=self.req_pool_indices,
        seq_lens=self.seq_lens,
        orig_seq_lens=self.orig_seq_lens,
        out_cache_loc=self.out_cache_loc,
        seq_lens_cpu=seq_lens_cpu,
        seq_lens_sum=self.seq_lens_sum,
        # Logprobs
        return_logprob=self.return_logprob,
        top_logprobs_nums=self.top_logprobs_nums,
        token_ids_logprobs=self.token_ids_logprobs,
        # DP attention / two-batch overlap
        global_num_tokens=self.global_num_tokens,
        global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
        is_extend_in_batch=self.is_extend_in_batch,
        can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
        tbo_split_seq_index=self.tbo_split_seq_index,
        global_forward_mode=self.global_forward_mode,
        # Extend
        extend_num_tokens=self.extend_num_tokens,
        extend_seq_lens=extend_seq_lens,
        extend_prefix_lens=extend_prefix_lens,
        extend_logprob_start_lens=extend_logprob_start_lens,
        # Multimodal and encoder-decoder
        multimodal_inputs=self.multimodal_inputs,
        encoder_cached=self.encoder_cached,
        encoder_lens=self.encoder_lens,
        encoder_lens_cpu=self.encoder_lens_cpu,
        encoder_out_cache_loc=self.encoder_out_cache_loc,
        # LoRA and sampling
        lora_ids=lora_ids,
        sampling_info=sampling_info,
        # Embeddings / token types
        input_embeds=self.input_embeds,
        token_type_ids=self.token_type_ids,
        # Speculative decoding and hidden-state capture
        spec_algorithm=self.spec_algorithm,
        spec_info=self.spec_info,
        hicache_consumer_index=self.hicache_consumer_index,
        capture_hidden_mode=capture_hidden_mode,
        # Misc
        extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
        is_prefill_only=self.is_prefill_only,
        blend_info_list=self.blend_info_list,
    )
def copy(self):
    """Return a lightweight ``ScheduleBatch`` that shares a subset of this batch's state.

    Only the fields used by ``process_batch_result`` are forwarded. Each
    value is passed by reference, so no tensor or list is duplicated.
    """
    forwarded_fields = (
        "reqs",
        "req_to_token_pool",
        "req_pool_indices",
        "model_config",
        "forward_mode",
        "out_cache_loc",
        "return_logprob",
        "decoding_reqs",
        "spec_algorithm",
        "global_num_tokens",
        "global_num_tokens_for_logprob",
        "can_run_dp_cuda_graph",
        "is_extend_in_batch",
        "is_prefill_only",
        "seq_lens_cpu",
        "enable_overlap",
    )
    return ScheduleBatch(**{name: getattr(self, name) for name in forwarded_fields})
def _is_available_size_sufficient(self, num_tokens: int) -> bool:
if self.is_hybrid:
return (
self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
)
else:
return self.token_to_kv_pool_allocator.available_size() >= num_tokens
def __str__(self):
return (
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
f"#req={(len(self.reqs))})"
)
@dataclasses.dataclass
class ModelWorkerBatch:
    """Subset of a ScheduleBatch containing only the data for the model forward.

    Instances are built by ``ScheduleBatch.get_model_worker_batch``, which
    passes the ScheduleBatch's tensors and lists by reference. Fields without
    a default must be supplied explicitly.
    """

    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor
    # The sequence length tensor on CPU
    seq_lens_cpu: Optional[torch.Tensor]
    # Total of the sequence lengths in the batch (presumably sum(seq_lens))
    seq_lens_sum: int
    # For logprob
    return_logprob: bool
    # Per-request number of top logprobs to return
    top_logprobs_nums: Optional[List[int]]
    # Per-request token ids whose logprobs are requested; individual entries
    # may be None for requests that did not ask for any
    token_ids_logprobs: Optional[List[Optional[List[int]]]]
    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    is_extend_in_batch: bool
    can_run_dp_cuda_graph: bool
    # NOTE(review): presumably the request index at which two-batch overlap
    # (TBO) splits the batch; confirm
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]
    # For extend. get_model_worker_batch sets extend_seq_lens,
    # extend_prefix_lens and extend_logprob_start_lens to None for
    # decode/idle batches.
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]
    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]
    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]
    # For LoRA: one entry per request (taken from req.lora_id)
    lora_ids: Optional[List[str]]
    # Sampling info
    sampling_info: SamplingBatchInfo
    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: Optional[torch.Tensor] = None
    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None
    # For cross-encoder models
    token_type_ids: Optional[torch.Tensor] = None
    # Speculative decoding
    spec_algorithm: Optional[SpeculativeAlgorithm] = None
    spec_info: Optional[SpecInput] = None
    # If set, the output of the batch contains the hidden states of the run.
    # get_model_worker_batch uses FULL when the batch requests hidden states.
    capture_hidden_mode: Optional[CaptureHiddenMode] = None
    # NOTE(review): meaning of -1 is not visible here; presumably "no HiCache
    # consumer" — confirm
    hicache_consumer_index: int = -1
    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False
    # NOTE(review): named "_list" but annotated as a single BatchBlendInfo;
    # confirm the intended type
    blend_info_list: Optional[BatchBlendInfo] = None

Xet Storage Details

Size:
90.4 kB
·
Xet hash:
c495b634d71ccd28472e4cf8ae5613f8f22382fba2d95674753e645aeec2af87

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.