| # Copyright 2023-2024 SGLang Team | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """TokenizerManager is a process that tokenizes the text.""" | |
| import asyncio | |
| import copy | |
| import dataclasses | |
| import logging | |
| import math | |
| import os | |
| import pickle | |
| import signal | |
| import sys | |
| import threading | |
| import time | |
| from collections import deque | |
| from contextlib import nullcontext | |
| from datetime import datetime | |
| from enum import Enum | |
| from http import HTTPStatus | |
| from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union | |
| import fastapi | |
| import orjson | |
| import torch | |
| import uvloop | |
| import zmq | |
| import zmq.asyncio | |
| from fastapi import BackgroundTasks | |
| from sglang.srt.configs.model_config import ModelConfig | |
| from sglang.srt.disaggregation.utils import DisaggregationMode | |
| from sglang.srt.lora.lora_registry import LoRARegistry | |
| from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer | |
| from sglang.srt.managers.disagg_service import start_disagg_service | |
| from sglang.srt.managers.io_struct import ( | |
| AbortReq, | |
| BaseReq, | |
| BatchEmbeddingOutput, | |
| BatchMultimodalOutput, | |
| BatchStrOutput, | |
| BatchTokenIDOutput, | |
| BatchTokenizedEmbeddingReqInput, | |
| BatchTokenizedGenerateReqInput, | |
| ConfigureLoggingReq, | |
| EmbeddingReqInput, | |
| FreezeGCReq, | |
| GenerateReqInput, | |
| GetLoadReqInput, | |
| HealthCheckOutput, | |
| OpenSessionReqOutput, | |
| SessionParams, | |
| TokenizedEmbeddingReqInput, | |
| TokenizedGenerateReqInput, | |
| UpdateWeightFromDiskReqInput, | |
| UpdateWeightFromDiskReqOutput, | |
| WatchLoadUpdateReq, | |
| ) | |
| from sglang.srt.managers.mm_utils import TensorTransportMode | |
| from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors | |
| from sglang.srt.managers.scheduler import is_health_check_generate_req | |
| from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region | |
| from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin | |
| from sglang.srt.metrics.collector import TokenizerMetricsCollector | |
| from sglang.srt.sampling.sampling_params import SamplingParams | |
| from sglang.srt.server_args import PortArgs, ServerArgs | |
| from sglang.srt.speculative.spec_info import SpeculativeAlgorithm | |
| from sglang.srt.tracing.trace import ( | |
| trace_get_proc_propagate_context, | |
| trace_req_finish, | |
| trace_req_start, | |
| trace_slice_end, | |
| trace_slice_start, | |
| ) | |
| from sglang.srt.utils import ( | |
| configure_gc_warning, | |
| dataclass_to_string_truncated, | |
| freeze_gc, | |
| get_bool_env_var, | |
| get_zmq_socket, | |
| kill_process_tree, | |
| ) | |
| from sglang.srt.utils.aio_rwlock import RWLock | |
| from sglang.srt.utils.hf_transformers_utils import ( | |
| get_processor, | |
| get_tokenizer, | |
| get_tokenizer_from_processor, | |
| ) | |
| from sglang.utils import TypeBasedDispatcher, get_exception_traceback | |
from sglang.srt.utils.cache_blender_info import BlendStyle
from sglang.srt.utils.cache_blender import CacheBlender

# Blend styles for which a request is split on its `separator`: for these,
# `generate_request` pre-tokenizes the separator and `_create_tokenized_object`
# computes `blend_loc_list`. Only used for membership tests, so keep it immutable.
_BLEND_STYLES_WITH_SEP = frozenset(
    {
        BlendStyle.DO_BLEND,
        BlendStyle.DO_BLEND_FINISH,
        BlendStyle.KVCOMPUTE,
        BlendStyle.QCOMPUTE,
    }
)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class ReqState:
    """Store the state of a request.

    The first five fields (``out_list``, ``finished``, ``event``, ``obj``,
    ``created_time``) are required; all others have defaults. List-valued
    fields use ``default_factory`` so that instances never share a list object
    (this only works because the class is a dataclass).
    """

    out_list: List[Dict[Any, Any]]
    finished: bool
    event: asyncio.Event
    obj: Union[GenerateReqInput, EmbeddingReqInput]

    # For metrics
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0

    # For incremental state update.
    # TODO(lianmin): do not initialize some lists if not needed.
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
| class TokenizerManager(TokenizerCommunicatorMixin): | |
| """TokenizerManager is a process that tokenizes the text.""" | |
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ) -> None:
        """Set up tokenizer/processor, IPC sockets and per-request bookkeeping.

        Args:
            server_args: Server configuration; many fields are copied onto
                ``self`` and it is also passed to ModelConfig, the multimodal
                processor, the disaggregation service and the metrics collector.
            port_args: IPC endpoint names used to create the ZMQ sockets to the
                detokenizer (PULL) and to the scheduler (PUSH).
        """
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = server_args.log_requests_level
        # JSON string of server-wide default sampling params; request-level
        # params override them in `_create_tokenized_object`.
        self.preferred_sampling_params = (
            orjson.loads(server_args.preferred_sampling_params)
            if server_args.preferred_sampling_params
            else None
        )
        self.crash_dump_folder = server_args.crash_dump_folder
        self.enable_trace = server_args.enable_trace

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id
        self.max_req_input_len = None  # Will be set later in engine.py
        speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        # With speculative decoding, `_validate_one_request` counts the draft
        # tokens as extra input when checking the context length.
        self.reserve_input_token_num = (
            0
            if speculative_algorithm.is_none()
            else server_args.speculative_num_draft_tokens
        )
        # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
        self.multi_item_delimiter_text = None

        if self.model_config.is_multimodal:
            import_processors("sglang.srt.multimodal.processors")
            try:
                _processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                    use_fast=not server_args.disable_fast_image_processor,
                )
            except ValueError as e:
                error_message = str(e)
                # Retry with the fast processor when the model ships no slow
                # version; any other ValueError is re-raised.
                if "does not have a slow version" in error_message:
                    logger.info(
                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
                    )
                    _processor = get_processor(
                        server_args.tokenizer_path,
                        tokenizer_mode=server_args.tokenizer_mode,
                        trust_remote_code=server_args.trust_remote_code,
                        revision=server_args.revision,
                        use_fast=True,
                    )
                else:
                    raise e
            transport_mode = _determine_tensor_transport_mode(self.server_args)

            # We want to parallelize the image pre-processing so we create an executor for it
            # The mm_processor is created regardless of skip_tokenizer_init so that
            # images are still encoded even when skip_tokenizer_init=True.
            self.mm_processor = get_mm_processor(
                self.model_config.hf_config, server_args, _processor, transport_mode
            )

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = get_tokenizer_from_processor(self.processor)
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
                self._initialize_multi_item_delimiter_text()
        else:
            self.mm_processor = self.processor = None

            if server_args.skip_tokenizer_init:
                self.tokenizer = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
                self._initialize_multi_item_delimiter_text()

        # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
        # `_tokenize_texts` only routes single-string inputs through it.
        if (
            server_args.enable_dynamic_batch_tokenizer
            and not server_args.skip_tokenizer_init
        ):
            self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
                self.tokenizer,
                max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
                batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
            )
        else:
            self.async_dynamic_batch_tokenizer = None

        # Init inter-process communication
        # pyzmq Context(io_threads=2).
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        if self.server_args.tokenizer_worker_num > 1:
            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
            send_to_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
            )

            # Thin wrapper exposing `send_pyobj`: stamps every BaseReq with this
            # worker's tokenizer_ipc_name, then forwards all objects to the socket.
            class SenderWrapper:
                def send_pyobj(self, obj):
                    if isinstance(obj, BaseReq):
                        obj.http_worker_ipc = port_args.tokenizer_ipc_name
                    send_to_scheduler.send_pyobj(obj)

            # Make sure that each request carries the tokenizer_ipc_name for response routing
            self.send_to_scheduler = SenderWrapper()
        else:
            self.send_to_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
            )

        # Request states
        self._chosen_loop = None
        self.rid_to_state: Dict[str, ReqState] = {}
        self.asyncio_tasks = set()

        # Health check
        self.server_status = ServerStatus.Starting
        self.gracefully_exit = False
        self.last_receive_tstamp = 0

        # Initial weights status
        # NOTE(review): False when the server must wait for checkpoint-engine
        # weights; presumably flipped elsewhere once they arrive — confirm.
        self.initial_weights_loaded = True
        if server_args.checkpoint_engine_wait_weights_before_ready:
            self.initial_weights_loaded = False

        # Dumping
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.log_request_metadata = self.get_log_request_metadata()
        self.crash_dump_request_list: deque[Tuple] = deque()
        self.crash_dump_performed = False  # Flag to ensure dump is only called once

        # Session
        self.session_futures = {}  # session_id -> asyncio event

        # Weight updates
        # Reader/writer lock: `generate_request` holds the reader side while a
        # request is tokenized and sent; weight updates presumably take the
        # writer side (not visible here) — confirm.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        # `generate_request` waits on this condition until `is_pause` is False.
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

        # LoRA
        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
        # serves as the source of truth for available adapters and maps user-friendly LoRA names
        # to internally used unique LoRA IDs.
        self.lora_registry = LoRARegistry(self.server_args.lora_paths)
        # Lock to serialize LoRA update operations.
        # Please note that, unlike `model_update_lock`, this does not block inference, allowing
        # LoRA updates and inference to overlap.
        self.lora_update_lock = asyncio.Lock()

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.bootstrap_server = start_disagg_service(self.server_args)

        # For load balancing
        self.current_load = 0
        self.current_load_lock = asyncio.Lock()

        # Metrics
        if self.enable_metrics:
            labels = {
                "model_name": self.server_args.served_model_name,
                # TODO: Add lora name/path in the future,
            }
            # Custom labels are registered with an empty default value.
            if server_args.tokenizer_metrics_allowed_custom_labels:
                for label in server_args.tokenizer_metrics_allowed_custom_labels:
                    labels[label] = ""
            self.metrics_collector = TokenizerMetricsCollector(
                server_args=server_args,
                labels=labels,
                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
            )

        # Configure GC warning
        if self.server_args.gc_warning_threshold_secs > 0.0:
            configure_gc_warning(self.server_args.gc_warning_threshold_secs)

        # Type -> handler table for incoming result objects (presumably applied to
        # messages read from `recv_from_detokenizer` — dispatch site not visible here).
        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (
                        BatchStrOutput,
                        BatchEmbeddingOutput,
                        BatchTokenIDOutput,
                        BatchMultimodalOutput,
                    ),
                    self._handle_batch_output,
                ),
                (AbortReq, self._handle_abort_req),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    FreezeGCReq,
                    lambda x: None,
                ),
                # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
                (HealthCheckOutput, lambda x: None),
            ]
        )

        self.init_communicators(server_args)
| async def generate_request( | |
| self, | |
| obj: Union[GenerateReqInput, EmbeddingReqInput], | |
| request: Optional[fastapi.Request] = None, | |
| ): | |
| created_time = time.time() | |
| self.auto_create_handle_loop() | |
| obj.normalize_batch_and_arguments() | |
| if self.server_args.tokenizer_worker_num > 1: | |
| from sglang.srt.managers.multi_tokenizer_mixin import TokenizerWorker | |
| assert isinstance(self, TokenizerWorker) | |
| self._attach_multi_http_worker_info(obj) | |
| if self.enable_trace: | |
| self._trace_request_start(obj, created_time) | |
| if self.log_requests: | |
| max_length, skip_names, _ = self.log_request_metadata | |
| logger.info( | |
| f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}" | |
| ) | |
| # cache blend | |
| if BlendStyle.parse(obj.blend_style) in _BLEND_STYLES_WITH_SEP: | |
| obj.blend_sep_token, _ = await self._tokenize_texts(obj.separator) | |
| async with self.is_pause_cond: | |
| await self.is_pause_cond.wait_for(lambda: not self.is_pause) | |
| async with self.model_update_lock.reader_lock: | |
| if self.server_args.enable_lora and obj.lora_path: | |
| # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests. | |
| obj.lora_id = await self.lora_registry.acquire(obj.lora_path) | |
| if obj.is_single: | |
| tokenized_obj = await self._tokenize_one_request(obj) | |
| state = self._send_one_request(obj, tokenized_obj, created_time) | |
| async for response in self._wait_one_response(obj, state, request): | |
| yield response | |
| else: | |
| async for response in self._handle_batch_request( | |
| obj, request, created_time | |
| ): | |
| yield response | |
| def _detect_input_format( | |
| self, texts: Union[str, List[str]], is_cross_encoder: bool | |
| ) -> str: | |
| """Detect the format of input texts for proper tokenization handling. | |
| Returns: | |
| - "single_string": Regular single text like "Hello world" | |
| - "batch_strings": Regular batch like ["Hello", "World"] | |
| - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]] | |
| """ | |
| if isinstance(texts, str): | |
| return "single_string" | |
| if ( | |
| is_cross_encoder | |
| and len(texts) > 0 | |
| and isinstance(texts[0], list) | |
| and len(texts[0]) == 2 | |
| ): | |
| return "cross_encoder_pairs" | |
| return "batch_strings" | |
| def _prepare_tokenizer_input( | |
| self, texts: Union[str, List[str]], input_format: str | |
| ) -> Union[List[str], List[List[str]]]: | |
| """Prepare input for the tokenizer based on detected format.""" | |
| if input_format == "single_string": | |
| return [texts] # Wrap single string for batch processing | |
| elif input_format == "cross_encoder_pairs": | |
| return texts # Already in correct format: [["query", "doc"]] | |
| else: # batch_strings | |
| return texts # Already in correct format: ["text1", "text2"] | |
| def _extract_tokenizer_results( | |
| self, | |
| input_ids: List[List[int]], | |
| token_type_ids: Optional[List[List[int]]], | |
| input_format: str, | |
| original_batch_size: int, | |
| ) -> Union[ | |
| Tuple[List[int], Optional[List[int]]], | |
| Tuple[List[List[int]], Optional[List[List[int]]]], | |
| ]: | |
| """Extract results from tokenizer output based on input format.""" | |
| # For single inputs (string or single cross-encoder pair), extract first element | |
| if ( | |
| input_format in ["single_string", "cross_encoder_pairs"] | |
| and original_batch_size == 1 | |
| ): | |
| single_input_ids = input_ids[0] if input_ids else [] | |
| single_token_type_ids = token_type_ids[0] if token_type_ids else None | |
| return single_input_ids, single_token_type_ids | |
| # For true batches, return as-is | |
| return input_ids, token_type_ids | |
| async def _tokenize_texts( | |
| self, texts: Union[str, List[str]], is_cross_encoder: bool = False | |
| ) -> Union[ | |
| Tuple[List[int], Optional[List[int]]], | |
| Tuple[List[List[int]], Optional[List[List[int]]]], | |
| ]: | |
| """ | |
| Tokenize text(s) using the appropriate tokenizer strategy. | |
| This method handles multiple input formats and chooses between async dynamic | |
| batch tokenizer (for single texts only) and regular tokenizer. | |
| Args: | |
| texts: Text input in various formats: | |
| Regular cases: | |
| - Single string: "How are you?" | |
| - Batch of strings: ["Hello", "World", "How are you?"] | |
| Cross-encoder cases (sentence pairs for similarity/ranking): | |
| - Single pair: [["query text", "document text"]] | |
| - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]] | |
| is_cross_encoder: Whether to return token_type_ids for cross-encoder models. | |
| Enables proper handling of sentence pairs with segment IDs. | |
| Returns: | |
| Single input cases: | |
| Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids) | |
| Example: ([101, 2129, 102], [0, 0, 0]) for single text | |
| Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair | |
| Batch input cases: | |
| Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids) | |
| Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch | |
| Note: token_type_ids is None unless is_cross_encoder=True. | |
| """ | |
| if not texts or self.tokenizer is None: | |
| raise ValueError("texts cannot be empty and tokenizer must be initialized") | |
| # Step 1: Detect input format and prepare for tokenization | |
| input_format = self._detect_input_format(texts, is_cross_encoder) | |
| tokenizer_input = self._prepare_tokenizer_input(texts, input_format) | |
| original_batch_size = len(texts) if not isinstance(texts, str) else 1 | |
| # Step 2: Set up tokenizer arguments | |
| # Disable automatic special tokens (BOS/EOS) - they are already in the prompt template | |
| tokenizer_kwargs = {"add_special_tokens": False} | |
| if is_cross_encoder: | |
| tokenizer_kwargs["return_token_type_ids"] = True | |
| # Step 3: Choose tokenization strategy | |
| use_async_tokenizer = ( | |
| self.async_dynamic_batch_tokenizer is not None | |
| and input_format == "single_string" | |
| ) | |
| if use_async_tokenizer: | |
| logger.debug("Using async dynamic batch tokenizer for single text") | |
| result = await self.async_dynamic_batch_tokenizer.encode( | |
| tokenizer_input[0], **tokenizer_kwargs | |
| ) | |
| # Convert to batch format for consistency | |
| input_ids = [result["input_ids"]] | |
| token_type_ids = ( | |
| [result["token_type_ids"]] | |
| if is_cross_encoder and result.get("token_type_ids") | |
| else None | |
| ) | |
| else: | |
| logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs") | |
| encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs) | |
| input_ids = encoded["input_ids"] | |
| token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None | |
| # Step 4: Extract results based on input format | |
| return self._extract_tokenizer_results( | |
| input_ids, token_type_ids, input_format, original_batch_size | |
| ) | |
| async def _tokenize_one_request( | |
| self, | |
| obj: Union[GenerateReqInput, EmbeddingReqInput], | |
| ): | |
| """Tokenize one request.""" | |
| # Tokenize | |
| input_embeds = None | |
| input_text = obj.text | |
| token_type_ids = None | |
| is_cross_encoder_request = ( | |
| isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request | |
| ) | |
| if obj.input_embeds is not None: | |
| if not self.server_args.disable_radix_cache: | |
| raise ValueError( | |
| "input_embeds is provided while disable_radix_cache is False. " | |
| "Please add `--disable-radix-cache` when you launch the server " | |
| "if you want to use input_embeds as inputs." | |
| ) | |
| input_embeds = obj.input_embeds | |
| input_ids = obj.input_ids | |
| elif obj.input_ids is not None: | |
| input_ids = obj.input_ids | |
| else: | |
| if self.tokenizer is None: | |
| raise ValueError( | |
| "The engine initialized with skip_tokenizer_init=True cannot " | |
| "accept text prompts. Please provide input_ids or re-initialize " | |
| "the engine with skip_tokenizer_init=False." | |
| ) | |
| input_ids, token_type_ids = await self._tokenize_texts( | |
| input_text, is_cross_encoder_request | |
| ) | |
| if self.mm_processor and obj.contains_mm_input(): | |
| if obj.image_data is not None and not isinstance(obj.image_data, list): | |
| obj.image_data = [obj.image_data] | |
| if obj.audio_data is not None and not isinstance(obj.audio_data, list): | |
| obj.audio_data = [obj.audio_data] | |
| mm_inputs: Dict = await self.mm_processor.process_mm_data_async( | |
| image_data=obj.image_data, | |
| audio_data=obj.audio_data, | |
| input_text=input_text or input_ids, | |
| request_obj=obj, | |
| max_req_input_len=self.max_req_input_len, | |
| ) | |
| if mm_inputs and "input_ids" in mm_inputs: | |
| input_ids = mm_inputs["input_ids"] | |
| else: | |
| mm_inputs = None | |
| self._validate_one_request(obj, input_ids) | |
| trace_slice_end("tokenize", obj.rid) | |
| return self._create_tokenized_object( | |
| obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids | |
| ) | |
    def _validate_one_request(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
    ) -> None:
        """Validate one request's token budget and feature flags against server limits.

        Checks, in order:
            1. input length (plus any speculative-decoding reserve) is below the
               context length — truncating ``input_ids`` in place when
               ``allow_auto_truncate`` is set;
            2. embedding requests are not sent to a generation model;
            3. input + ``max_new_tokens`` fits the context length — lowering
               ``obj.sampling_params["max_new_tokens"]`` in place when
               ``allow_auto_truncate`` is set;
            4. for generate requests, hidden-state return and custom logit
               processors are enabled on the server if requested.

        Args:
            obj: The request being validated; its ``sampling_params`` may be mutated.
            input_ids: Token ids of the request; may be truncated in place.

        Raises:
            ValueError: when any of the checks above fails and cannot be
                auto-truncated.
        """
        # FIXME: unify the length validation logic with the one in the scheduler.
        _max_req_len = self.context_len
        input_token_num = len(input_ids) if input_ids is not None else 0
        # With speculative decoding, the draft-token reserve set in __init__ counts as input.
        input_token_num += self.reserve_input_token_num
        if input_token_num >= self.context_len:
            if self.server_args.allow_auto_truncate:
                logger.warning(
                    f"The input ({input_token_num} tokens) is longer than the "
                    f"model's context length ({self.context_len} tokens). "
                    "Truncating the input."
                )
                # Mutates the caller's list (e.g. obj.input_ids when it was passed directly).
                # NOTE(review): this keeps up to context_len tokens, which still fails the
                # `>=` check above, and the recount below drops the reserve; combined with
                # the max_new_tokens check this forces max_new_tokens to 0 — confirm the
                # scheduler truncates further.
                del input_ids[_max_req_len:]
                input_token_num = len(input_ids)
            else:
                raise ValueError(
                    f"The input ({input_token_num} tokens) is longer than the "
                    f"model's context length ({self.context_len} tokens)."
                )

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        # Check total tokens (input + max_new_tokens)
        max_new_tokens = obj.sampling_params.get("max_new_tokens")
        if (
            max_new_tokens is not None
            and (max_new_tokens + input_token_num) >= _max_req_len
        ):
            if self.server_args.allow_auto_truncate:
                logger.warning(
                    f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
                    f"exceeds the model's context length ({self.context_len} tokens). "
                    "Truncating max_new_tokens."
                )
                # Clamp at 0 so the remaining budget is never negative.
                obj.sampling_params["max_new_tokens"] = max(
                    0, _max_req_len - input_token_num
                )
            else:
                total_tokens = max_new_tokens + input_token_num
                error_msg = (
                    f"Requested token count exceeds the model's maximum context length "
                    f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                    f"tokens: {input_token_num} tokens from the input messages and "
                    f"{max_new_tokens} tokens for the completion. Please reduce the number "
                    f"of tokens in the input messages or the completion to fit within the limit."
                )
                raise ValueError(error_msg)

        if isinstance(obj, GenerateReqInput):
            if (
                obj.return_hidden_states
                and not self.server_args.enable_return_hidden_states
            ):
                raise ValueError(
                    "The server is not configured to return the hidden states. "
                    "Please set `--enable-return-hidden-states` to enable this feature."
                )
            if (
                obj.custom_logit_processor
                and not self.server_args.enable_custom_logit_processor
            ):
                raise ValueError(
                    "The server is not configured to enable custom logit processor. "
                    "Please set `--enable-custom-logit-processor` to enable this feature."
                )
| def _validate_input_ids_in_vocab( | |
| self, input_ids: List[int], vocab_size: int | |
| ) -> None: | |
| if any(id >= vocab_size for id in input_ids): | |
| raise ValueError( | |
| f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})." | |
| ) | |
| def _create_tokenized_object( | |
| self, | |
| obj: Union[GenerateReqInput, EmbeddingReqInput], | |
| input_text: str, | |
| input_ids: List[int], | |
| input_embeds: Optional[Union[List[float], None]] = None, | |
| mm_inputs: Optional[Dict] = None, | |
| token_type_ids: Optional[List[int]] = None, | |
| ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]: | |
| """Create a tokenized request object from common parameters.""" | |
| # Parse sampling parameters | |
| # Note: if there are preferred sampling params, we use them if they are not | |
| # explicitly passed in sampling_params | |
| if self.preferred_sampling_params: | |
| sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params} | |
| else: | |
| sampling_kwargs = obj.sampling_params | |
| sampling_params = SamplingParams(**sampling_kwargs) | |
| sampling_params.normalize(self.tokenizer) | |
| sampling_params.verify(self.model_config.vocab_size) | |
| # Build return object | |
| if isinstance(obj, GenerateReqInput): | |
| session_params = ( | |
| SessionParams(**obj.session_params) if obj.session_params else None | |
| ) | |
| # cache blend | |
| if BlendStyle.parse(obj.blend_style) in _BLEND_STYLES_WITH_SEP: | |
| if ( | |
| mm_inputs is None | |
| and input_embeds is None | |
| and isinstance(input_text, str) | |
| and obj.separator | |
| and obj.separator in input_text | |
| and self.tokenizer is not None | |
| ): | |
| input_text, input_ids, blend_loc_list = ( | |
| CacheBlender.split_text_tokens( | |
| input_text, obj.separator, self.tokenizer | |
| ) | |
| ) | |
| else: | |
| input_text, input_ids, blend_loc_list = CacheBlender.split_tokens( | |
| input_text, input_ids, obj.separator, obj.blend_sep_token | |
| ) | |
| else: | |
| blend_loc_list = None | |
| tokenized_obj = TokenizedGenerateReqInput( | |
| input_text, | |
| input_ids, | |
| mm_inputs, | |
| sampling_params, | |
| obj.return_logprob, | |
| obj.logprob_start_len, | |
| obj.top_logprobs_num, | |
| obj.token_ids_logprob, | |
| obj.stream, | |
| rid=obj.rid, | |
| http_worker_ipc=obj.http_worker_ipc, | |
| bootstrap_host=obj.bootstrap_host, | |
| bootstrap_port=obj.bootstrap_port, | |
| bootstrap_room=obj.bootstrap_room, | |
| lora_id=obj.lora_id, | |
| input_embeds=input_embeds, | |
| session_params=session_params, | |
| custom_logit_processor=obj.custom_logit_processor, | |
| return_hidden_states=obj.return_hidden_states, | |
| data_parallel_rank=obj.data_parallel_rank, | |
| priority=obj.priority, | |
| extra_key=obj.extra_key, | |
| ratio=obj.ratio, | |
| start=obj.start, | |
| method=obj.method, | |
| blend_loc_list=blend_loc_list, | |
| blend_style=obj.blend_style, | |
| attn_start=obj.attn_start, | |
| attn_end=obj.attn_end, | |
| is_contextblend=obj.is_contextblend, | |
| context_cache_source=obj.context_cache_source, | |
| context_n_sink=obj.context_n_sink, | |
| digest_ratio=obj.digest_ratio, | |
| digest_index_method=obj.digest_index_method, | |
| critical_layers=obj.critical_layers, | |
| ssd_cache_path_chunk=obj.ssd_cache_path_chunk, | |
| ssd_cache_path_query=obj.ssd_cache_path_query, | |
| ) | |
| elif isinstance(obj, EmbeddingReqInput): | |
| tokenized_obj = TokenizedEmbeddingReqInput( | |
| input_text, | |
| input_ids, | |
| mm_inputs, | |
| token_type_ids, | |
| sampling_params, | |
| rid=obj.rid, | |
| priority=obj.priority, | |
| http_worker_ipc=obj.http_worker_ipc, | |
| ) | |
| return tokenized_obj | |
| async def _batch_tokenize_and_process( | |
| self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput] | |
| ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]: | |
| """Handle batch tokenization for text inputs only.""" | |
| logger.debug(f"Starting batch tokenization for {batch_size} text requests") | |
| # If batch does not have text nothing to tokenize | |
| # so lets construct the return object | |
| if not self._batch_has_text(batch_size, obj): | |
| # All requests already have input_ids, no need to tokenize | |
| return [await self._tokenize_one_request(obj[i]) for i in range(batch_size)] | |
| self._validate_batch_tokenization_constraints(batch_size, obj) | |
| # Collect requests and texts | |
| requests = [obj[i] for i in range(batch_size)] | |
| texts = [req.text for req in requests] | |
| # Check if any request is a cross-encoder request | |
| is_cross_encoder_request = any( | |
| isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request | |
| for req in requests | |
| ) | |
| # Batch tokenize all texts using unified method | |
| input_ids_list, token_type_ids_list = await self._tokenize_texts( | |
| texts, is_cross_encoder_request | |
| ) | |
| # Process all requests | |
| tokenized_objs = [] | |
| for i, req in enumerate(requests): | |
| self._validate_one_request(obj[i], input_ids_list[i]) | |
| token_type_ids = ( | |
| token_type_ids_list[i] if token_type_ids_list is not None else None | |
| ) | |
| tokenized_objs.append( | |
| self._create_tokenized_object( | |
| req, req.text, input_ids_list[i], None, None, token_type_ids | |
| ) | |
| ) | |
| trace_slice_end("tokenize", req.rid) | |
| logger.debug(f"Completed batch processing for {batch_size} requests") | |
| return tokenized_objs | |
| def _validate_batch_tokenization_constraints( | |
| self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput] | |
| ) -> None: | |
| """Validate constraints for batch tokenization processing.""" | |
| for i in range(batch_size): | |
| if self.is_generation and obj[i].contains_mm_input(): | |
| raise ValueError( | |
| "For multimodal input processing do not set `enable_tokenizer_batch_encode`." | |
| ) | |
| if obj[i].input_ids is not None: | |
| raise ValueError( | |
| "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`." | |
| ) | |
| if obj[i].input_embeds is not None: | |
| raise ValueError( | |
| "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`." | |
| ) | |
| def _batch_has_text( | |
| self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput] | |
| ) -> bool: | |
| """Check if any request in the batch contains text input.""" | |
| for i in range(batch_size): | |
| if obj[i].text: | |
| return True | |
| elif self.is_generation and obj[i].contains_mm_input(): | |
| return True | |
| return False | |
| def _should_use_batch_tokenization(self, batch_size, requests) -> bool: | |
| """Return True if we should run the tokenizer in batch mode. | |
| Current policy: | |
| - Respect explicit server flag `enable_tokenizer_batch_encode`. | |
| - Or, if no request has text or multimodal input (all use pre-tokenized input_ids or input_embeds), batch the requests without tokenization. | |
| """ | |
| return batch_size > 0 and ( | |
| self.server_args.enable_tokenizer_batch_encode | |
| or not self._batch_has_text(batch_size, requests) | |
| ) | |
| def _send_one_request( | |
| self, | |
| obj: Union[GenerateReqInput, EmbeddingReqInput], | |
| tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], | |
| created_time: Optional[float] = None, | |
| ): | |
| trace_slice_start("dispatch", obj.rid) | |
| tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid) | |
| self.send_to_scheduler.send_pyobj(tokenized_obj) | |
| state = ReqState([], False, asyncio.Event(), obj, created_time=created_time) | |
| self.rid_to_state[obj.rid] = state | |
| trace_slice_end("dispatch", obj.rid, thread_finish_flag=True) | |
| return state | |
| def _send_batch_request( | |
| self, | |
| obj: Union[GenerateReqInput, EmbeddingReqInput], | |
| tokenized_objs: List[ | |
| Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput] | |
| ], | |
| created_time: Optional[float] = None, | |
| ): | |
| """Send a batch of tokenized requests as a single batched request to the scheduler.""" | |
| if isinstance(tokenized_objs[0], TokenizedGenerateReqInput): | |
| batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs) | |
| else: | |
| batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs) | |
| self.send_to_scheduler.send_pyobj(batch_req) | |
| # Create states for each individual request in the batch | |
| for i, tokenized_obj in enumerate(tokenized_objs): | |
| tmp_obj = obj[i] | |
| state = ReqState( | |
| [], False, asyncio.Event(), tmp_obj, created_time=created_time | |
| ) | |
| self.rid_to_state[tmp_obj.rid] = state | |
    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        state: ReqState,
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request.

        Async generator over the output dicts that the receive path appends to
        ``state.out_list`` and signals through ``state.event``.

        - Streaming requests (``obj.stream``) yield every intermediate output
          and then the final one.
        - Non-streaming requests yield only the final output.

        On each wake-up only the newest entry of ``state.out_list`` is used.
        Older entries queued since the previous wake-up are discarded.

        Disconnect checks run only when ``request`` is given and
        ``obj.background`` is false. They happen on every 4 s wait timeout and,
        for non-streaming requests, after each intermediate output. On
        disconnect the request is aborted before raising.

        Raises:
            ValueError: when the client disconnected, or when the final output
                is an abort with status 400 (BAD_REQUEST).
            fastapi.HTTPException: when the final output is an abort with status
                503 or 500. The state is removed from ``rid_to_state`` and any
                LoRA reference is released first.
        """
        while True:
            try:
                # Time out periodically so a disconnected client is noticed even
                # while no output arrives.
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if (
                    request is not None
                    and not obj.background
                    and await request.is_disconnected()
                ):
                    # Abort the request for a disconnected client. This path is
                    # reached on a wait timeout regardless of obj.stream
                    # (e.g. while still in the waiting queue).
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                    )
                continue
            # Only the newest output is used; older queued entries are dropped.
            out = state.out_list[-1]
            state.out_list = []
            if state.finished:
                if self.log_requests:
                    max_length, skip_names, out_skip_names = self.log_request_metadata
                    if self.model_config.is_multimodal_gen:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
                    else:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                    logger.info(msg)
                # Check if this was an abort/error created by scheduler
                if isinstance(out["meta_info"].get("finish_reason"), dict):
                    finish_reason = out["meta_info"]["finish_reason"]
                    if (
                        finish_reason.get("type") == "abort"
                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
                    ):
                        raise ValueError(finish_reason["message"])
                    if finish_reason.get("type") == "abort" and finish_reason.get(
                        "status_code"
                    ) in (
                        HTTPStatus.SERVICE_UNAVAILABLE,
                        HTTPStatus.INTERNAL_SERVER_ERROR,
                    ):
                        # This is an abort request initiated by scheduler.
                        # Delete the key to prevent resending abort request to the scheduler and
                        # to ensure aborted request state is cleaned up.
                        if state.obj.rid in self.rid_to_state:
                            del self.rid_to_state[state.obj.rid]
                        # Mark ongoing LoRA request as finished.
                        if self.server_args.enable_lora and state.obj.lora_path:
                            await self.lora_registry.release(state.obj.lora_id)
                        raise fastapi.HTTPException(
                            status_code=finish_reason["status_code"],
                            detail=finish_reason["message"],
                        )
                # Normal completion, or an abort without a 400/500/503 status:
                # hand the final output to the caller and stop.
                yield out
                break
            # Not finished yet: re-arm the event for the next output.
            state.event.clear()
            if obj.stream:
                yield out
            else:
                # Non-streaming intermediate outputs are not yielded; use the
                # wake-up as a chance to check for a disconnected client.
                if (
                    request is not None
                    and not obj.background
                    and await request.is_disconnected()
                ):
                    # Abort the request for disconnected requests (non-streaming, running)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                    )
    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Dispatch a batched request and yield its outputs.

        Dispatch modes:

        - ``parallel_sample_num == 1`` and ``_should_use_batch_tokenization``:
          all sub-requests are tokenized together and sent as one batched
          scheduler message.
        - ``parallel_sample_num == 1`` otherwise: sub-requests are tokenized and
          sent one by one. This runs inside ``input_blocker_guard_region`` when
          the ``SGLANG_ENABLE_COLOCATED_BATCH_GEN`` env var is set.
        - Any other ``parallel_sample_num`` (parallel sampling): each
          sub-request is first sent once under a fresh rid with
          ``max_new_tokens = 0`` and ``stream = False``, and awaited, to cache
          the common prefix. Then ``parallel_sample_num`` copies of each
          sub-request are sent, each under a fresh rid.

        Yields:
            Non-streaming: a single list holding the final output of every
            (expanded) sub-request, in dispatch order.
            Streaming: each output as it arrives, with ``"index"`` set to the
            position of its rid in dispatch order.
        """
        batch_size = obj.batch_size
        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            if self._should_use_batch_tokenization(batch_size, obj):
                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
                # Registers a ReqState per sub-request in rid_to_state.
                self._send_batch_request(obj, tokenized_objs, created_time)
                # Set up generators for each request in the batch
                for i in range(batch_size):
                    tmp_obj = obj[i]
                    generators.append(
                        self._wait_one_response(
                            tmp_obj, self.rid_to_state[tmp_obj.rid], request
                        )
                    )
                    rids.append(tmp_obj.rid)
            else:
                # Sequential tokenization and processing
                with (
                    input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
                    if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
                    else nullcontext()
                ):
                    for i in range(batch_size):
                        tmp_obj = obj[i]
                        tokenized_obj = await self._tokenize_one_request(tmp_obj)
                        state = self._send_one_request(
                            tmp_obj, tokenized_obj, created_time
                        )
                        generators.append(
                            self._wait_one_response(tmp_obj, state, request)
                        )
                        rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )
            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            # NOTE: `obj` inside this generator expression is the genexpr's own
            # loop variable (one element of `objs`), not the method parameter.
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )
            # Cache the common prefix for parallel sampling
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                # Copy sampling_params before mutating so the expanded requests
                # below keep the caller's max_new_tokens.
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, state, request).__anext__()
            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, state, request))
                    rids.append(tmp_obj.rid)
        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            # Each non-streaming generator yields exactly one (final) output.
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            # Multiplex all streams: keep one pending __anext__ task per
            # generator and yield whichever output finishes first.
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )
                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        # This generator is exhausted; drop it from the map.
                        pass
| def abort_request(self, rid: str = "", abort_all: bool = False): | |
| if not abort_all and rid not in self.rid_to_state: | |
| return | |
| req = AbortReq(rid=rid, abort_all=abort_all) | |
| self.send_to_scheduler.send_pyobj(req) | |
| if self.enable_metrics: | |
| # TODO: also use custom_labels from the request | |
| self.metrics_collector.observe_one_aborted_request( | |
| self.metrics_collector.labels | |
| ) | |
| async def pause_generation(self): | |
| async with self.is_pause_cond: | |
| self.is_pause = True | |
| self.abort_request(abort_all=True) | |
| async def continue_generation(self): | |
| async with self.is_pause_cond: | |
| self.is_pause = False | |
| self.is_pause_cond.notify_all() | |
| async def update_weights_from_disk( | |
| self, | |
| obj: UpdateWeightFromDiskReqInput, | |
| request: Optional[fastapi.Request] = None, | |
| ) -> Tuple[bool, str]: | |
| self.auto_create_handle_loop() | |
| # default the load format to the server_args | |
| if obj.load_format is None: | |
| obj.load_format = self.server_args.load_format | |
| logger.info("Start update_weights. Load format=%s", obj.load_format) | |
| if obj.abort_all_requests: | |
| self.abort_request(abort_all=True) | |
| if True: # Keep this redundant check to simplify some internal code sync | |
| # Hold the lock if it is not async. This means that weight sync | |
| # cannot run while requests are in progress. | |
| async with self.model_update_lock.writer_lock: | |
| return await self._wait_for_model_update_from_disk(obj) | |
| async def _wait_for_model_update_from_disk( | |
| self, obj: UpdateWeightFromDiskReqInput | |
| ) -> Tuple[bool, str]: | |
| self.send_to_scheduler.send_pyobj(obj) | |
| self.model_update_result = asyncio.Future() | |
| if self.server_args.dp_size == 1: | |
| result = await self.model_update_result | |
| if result.success: | |
| self.served_model_name = obj.model_path | |
| self.server_args.model_path = obj.model_path | |
| self.server_args.load_format = obj.load_format | |
| self.model_path = obj.model_path | |
| return result.success, result.message, result.num_paused_requests | |
| else: # self.server_args.dp_size > 1 | |
| self.model_update_tmp = [] | |
| result = await self.model_update_result | |
| all_success = all([r.success for r in result]) | |
| if all_success is True: | |
| self.server_args.model_path = obj.model_path | |
| self.server_args.load_format = obj.load_format | |
| self.model_path = obj.model_path | |
| all_message = [r.message for r in result] | |
| all_message = " | ".join(all_message) | |
| all_paused_requests = [r.num_paused_requests for r in result] | |
| return all_success, all_message, all_paused_requests | |
| def configure_logging(self, obj: ConfigureLoggingReq): | |
| if obj.log_requests is not None: | |
| self.log_requests = obj.log_requests | |
| if obj.log_requests_level is not None: | |
| self.log_requests_level = obj.log_requests_level | |
| if obj.dump_requests_folder is not None: | |
| self.dump_requests_folder = obj.dump_requests_folder | |
| if obj.dump_requests_threshold is not None: | |
| self.dump_requests_threshold = obj.dump_requests_threshold | |
| if obj.crash_dump_folder is not None: | |
| self.crash_dump_folder = obj.crash_dump_folder | |
| logging.info(f"Config logging: {obj=}") | |
| self.log_request_metadata = self.get_log_request_metadata() | |
| async def freeze_gc(self): | |
| """Send a freeze_gc message to the scheduler first, then freeze locally.""" | |
| self.send_to_scheduler.send_pyobj(FreezeGCReq()) | |
| freeze_gc("Tokenizer Manager") | |
| return None | |
| def create_abort_task(self, obj: GenerateReqInput): | |
| # Abort the request if the client is disconnected. | |
| async def abort_request(): | |
| await asyncio.sleep(2) | |
| if obj.is_single: | |
| self.abort_request(obj.rid) | |
| else: | |
| for rid in obj.rid: | |
| self.abort_request(rid) | |
| background_tasks = BackgroundTasks() | |
| background_tasks.add_task(abort_request) | |
| return background_tasks | |
| def auto_create_handle_loop(self): | |
| if self._chosen_loop is not None: | |
| assert asyncio.get_event_loop() == self._chosen_loop, ( | |
| f"Please ensure only one event loop is ever used with SGLang. Previous loop: {self._chosen_loop}, current loop: {asyncio.get_event_loop()}" | |
| ) | |
| return | |
| loop = asyncio.get_event_loop() | |
| self._chosen_loop = loop | |
| self.asyncio_tasks.add( | |
| loop.create_task(print_exception_wrapper(self.handle_loop)) | |
| ) | |
| self.event_loop = loop | |
| # We cannot add signal handler when the tokenizer manager is not in | |
| # the main thread due to the CPython limitation. | |
| if threading.current_thread() is threading.main_thread(): | |
| signal_handler = SignalHandler(self) | |
| loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler) | |
| # Update the signal handler for the process. It overrides the sigquit handler in the launch phase. | |
| loop.add_signal_handler( | |
| signal.SIGQUIT, signal_handler.running_phase_sigquit_handler | |
| ) | |
| else: | |
| logger.warning( | |
| "Signal handler is not added because the tokenizer manager is " | |
| "not in the main thread. This disables graceful shutdown of the " | |
| "tokenizer manager when SIGTERM is received." | |
| ) | |
| self.asyncio_tasks.add( | |
| loop.create_task(print_exception_wrapper(self.sigterm_watchdog)) | |
| ) | |
| self.asyncio_tasks.add( | |
| loop.create_task(print_exception_wrapper(self.watch_load_thread)) | |
| ) | |
| def dump_requests_before_crash(self): | |
| if self.crash_dump_performed: | |
| logger.info( | |
| "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping." | |
| ) | |
| return | |
| if not self.crash_dump_folder: | |
| return | |
| logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}") | |
| self.crash_dump_performed = True | |
| # Check if NFS directory is available | |
| # expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0] | |
| # use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access( | |
| # expected_nfs_dir, os.W_OK | |
| # ) | |
| use_nfs_dir = False | |
| if not use_nfs_dir: | |
| logger.error( | |
| f"Expected NFS directory is not available or writable. Uploading to GCS." | |
| ) | |
| data_to_dump = [] | |
| if self.crash_dump_request_list: | |
| data_to_dump.extend(self.crash_dump_request_list) | |
| # Add unfinished requests from rid_to_state | |
| unfinished_requests = [] | |
| for rid, state in self.rid_to_state.items(): | |
| if not state.finished: | |
| unfinished_requests.append( | |
| ( | |
| state.obj, | |
| state.out_list[-1] if state.out_list else {}, | |
| state.created_time, | |
| time.time(), | |
| ) | |
| ) | |
| if unfinished_requests: | |
| data_to_dump.extend(unfinished_requests) | |
| if not data_to_dump: | |
| return | |
| object_name = f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl" | |
| filename = os.path.join( | |
| self.crash_dump_folder, | |
| os.getenv("HOSTNAME", None), | |
| object_name, | |
| ) | |
| os.makedirs(os.path.dirname(filename), exist_ok=True) | |
| # Include server_args in the dump | |
| data_to_dump_with_server_args = { | |
| "server_args": self.server_args, | |
| "requests": data_to_dump, | |
| } | |
| with open(filename, "wb") as f: | |
| pickle.dump(data_to_dump_with_server_args, f) | |
| logger.error( | |
| f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}" | |
| ) | |
| def _upload_file_to_gcs(bucket_name, source_file_path, object_name): | |
| from google.cloud import storage | |
| client = storage.Client() | |
| bucket = client.bucket(bucket_name) | |
| blob = bucket.blob(object_name) | |
| blob.upload_from_filename(source_file_path, if_generation_match=0) | |
| logger.error( | |
| f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}" | |
| ) | |
| if not use_nfs_dir: | |
| _upload_file_to_gcs( | |
| "sglang_crash_dump", | |
| filename, | |
| os.getenv("HOSTNAME", None) + "/" + object_name, | |
| ) | |
| async def sigterm_watchdog(self): | |
| while not self.gracefully_exit: | |
| await asyncio.sleep(5) | |
| # Drain requests | |
| while True: | |
| remain_num_req = len(self.rid_to_state) | |
| remaining_rids = list(self.rid_to_state.keys()) | |
| if self.server_status == ServerStatus.UnHealthy: | |
| # if health check failed, we should exit immediately | |
| logger.error( | |
| "Signal SIGTERM received while health check failed. Force exiting." | |
| ) | |
| self.dump_requests_before_crash() | |
| break | |
| elif get_bool_env_var("SGL_FORCE_SHUTDOWN"): | |
| # if force shutdown flag set, exit immediately | |
| logger.error( | |
| "Signal SIGTERM received while force shutdown flag set. Force exiting." | |
| ) | |
| break | |
| logger.info( | |
| f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}." | |
| ) | |
| if remain_num_req > 0: | |
| await asyncio.sleep(5) | |
| else: | |
| self.dump_requests_before_crash() | |
| break | |
| kill_process_tree(os.getpid(), include_parent=True) | |
| sys.exit(0) | |
| async def handle_loop(self): | |
| """The event loop that handles requests""" | |
| while True: | |
| recv_obj = await self.recv_from_detokenizer.recv_pyobj() | |
| self._result_dispatcher(recv_obj) | |
| self.last_receive_tstamp = time.time() | |
    def _handle_batch_output(
        self,
        recv_obj: Union[
            BatchStrOutput,
            BatchEmbeddingOutput,
            BatchMultimodalOutput,
            BatchTokenIDOutput,
        ],
    ):
        """Route one batch of outputs to the per-request states.

        For each ``rid`` in ``recv_obj.rids``, this builds an output dict with
        ``meta_info`` plus text / output_ids / embedding depending on the batch
        type. The dict is appended to ``state.out_list`` and ``state.event`` is
        set, waking the waiter in ``_wait_one_response``. Outputs for rids that
        are no longer in ``rid_to_state`` are logged and skipped.

        When an entry carries a finish reason, the request is finalized here:
        - ``finished_time`` and ``e2e_latency`` are recorded;
        - the rid is removed from ``rid_to_state``;
        - a LoRA reference, if any, is released in a background task.
        Metrics collection, request dumping and crash-dump recording follow
        when enabled and ``state.obj.log_metrics`` is set.

        Raises:
            NotImplementedError: for ``BatchMultimodalOutput`` batches.
        """
        for i, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                logger.error(
                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
                )
                continue
            # Build meta_info and return value
            meta_info = {
                "id": rid,
                "finish_reason": recv_obj.finished_reasons[i],
                "prompt_tokens": recv_obj.prompt_tokens[i],
                "weight_version": self.server_args.weight_version,
            }
            if getattr(state.obj, "return_logprob", False):
                # Logprob token ids are decoded to text only when the request
                # asks for it and skip_tokenizer_init is off.
                self.convert_logprob_style(
                    meta_info,
                    state,
                    state.obj.top_logprobs_num,
                    state.obj.token_ids_logprob,
                    state.obj.return_text_in_logprobs
                    and not self.server_args.skip_tokenizer_init,
                    recv_obj,
                    i,
                )
            if not isinstance(recv_obj, BatchEmbeddingOutput):
                meta_info.update(
                    {
                        "completion_tokens": recv_obj.completion_tokens[i],
                        "cached_tokens": recv_obj.cached_tokens[i],
                    }
                )
            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
            if isinstance(recv_obj, BatchStrOutput):
                # `state.text` always holds the full text so far. output_ids are
                # incremental (only ids since the last output) for streaming
                # requests, and the full list (copied) otherwise.
                state.text += recv_obj.output_strs[i]
                if state.obj.stream:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids.copy()
                out_dict = {
                    "text": state.text,
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOutput):
                # Incremental output_ids only when both the server-level
                # stream_output flag and the per-request stream flag are set.
                if self.server_args.stream_output and state.obj.stream:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids.copy()
                out_dict = {
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchMultimodalOutput):
                raise NotImplementedError("BatchMultimodalOut not implemented")
            else:
                assert isinstance(recv_obj, BatchEmbeddingOutput)
                out_dict = {
                    "embedding": recv_obj.embeddings[i],
                    "meta_info": meta_info,
                }
            state.finished = recv_obj.finished_reasons[i] is not None
            if state.finished:
                if self.server_args.speculative_algorithm:
                    self._calculate_spec_decoding_metrics(meta_info, recv_obj, i)
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time
                trace_req_finish(rid, ts=int(state.finished_time * 1e9))
                # The waiter still holds `state`, so it can read the final output
                # after the rid is dropped from the routing table.
                del self.rid_to_state[rid]
                # Mark ongoing LoRA request as finished.
                # NOTE(review): the task returned by create_task is not kept;
                # the asyncio docs recommend holding a reference to background
                # tasks -- confirm this is acceptable here.
                if self.server_args.enable_lora and state.obj.lora_path:
                    asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
            state.out_list.append(out_dict)
            state.event.set()
            # Log metrics and dump
            if self.enable_metrics and state.obj.log_metrics:
                self.collect_metrics(state, recv_obj, i)
            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
                self.dump_requests(state, out_dict)
            if self.crash_dump_folder and state.finished and state.obj.log_metrics:
                self.record_request_for_crash_dump(state, out_dict)
| def convert_logprob_style( | |
| self, | |
| meta_info: dict, | |
| state: ReqState, | |
| top_logprobs_num: int, | |
| token_ids_logprob: List[int], | |
| return_text_in_logprobs: bool, | |
| recv_obj: BatchStrOutput, | |
| recv_obj_index: int, | |
| ): | |
| if recv_obj.input_token_logprobs_val is None: | |
| return | |
| if len(recv_obj.input_token_logprobs_val) > 0: | |
| state.input_token_logprobs_val.extend( | |
| recv_obj.input_token_logprobs_val[recv_obj_index] | |
| ) | |
| state.input_token_logprobs_idx.extend( | |
| recv_obj.input_token_logprobs_idx[recv_obj_index] | |
| ) | |
| state.output_token_logprobs_val.extend( | |
| recv_obj.output_token_logprobs_val[recv_obj_index] | |
| ) | |
| state.output_token_logprobs_idx.extend( | |
| recv_obj.output_token_logprobs_idx[recv_obj_index] | |
| ) | |
| meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens( | |
| state.input_token_logprobs_val, | |
| state.input_token_logprobs_idx, | |
| return_text_in_logprobs, | |
| ) | |
| meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens( | |
| state.output_token_logprobs_val, | |
| state.output_token_logprobs_idx, | |
| return_text_in_logprobs, | |
| ) | |
| if top_logprobs_num > 0: | |
| if len(recv_obj.input_top_logprobs_val) > 0: | |
| state.input_top_logprobs_val.extend( | |
| recv_obj.input_top_logprobs_val[recv_obj_index] | |
| ) | |
| state.input_top_logprobs_idx.extend( | |
| recv_obj.input_top_logprobs_idx[recv_obj_index] | |
| ) | |
| state.output_top_logprobs_val.extend( | |
| recv_obj.output_top_logprobs_val[recv_obj_index] | |
| ) | |
| state.output_top_logprobs_idx.extend( | |
| recv_obj.output_top_logprobs_idx[recv_obj_index] | |
| ) | |
| meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( | |
| state.input_top_logprobs_val, | |
| state.input_top_logprobs_idx, | |
| return_text_in_logprobs, | |
| ) | |
| meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens( | |
| state.output_top_logprobs_val, | |
| state.output_top_logprobs_idx, | |
| return_text_in_logprobs, | |
| ) | |
| if token_ids_logprob is not None: | |
| if len(recv_obj.input_token_ids_logprobs_val) > 0: | |
| state.input_token_ids_logprobs_val.extend( | |
| recv_obj.input_token_ids_logprobs_val[recv_obj_index] | |
| ) | |
| state.input_token_ids_logprobs_idx.extend( | |
| recv_obj.input_token_ids_logprobs_idx[recv_obj_index] | |
| ) | |
| state.output_token_ids_logprobs_val.extend( | |
| recv_obj.output_token_ids_logprobs_val[recv_obj_index] | |
| ) | |
| state.output_token_ids_logprobs_idx.extend( | |
| recv_obj.output_token_ids_logprobs_idx[recv_obj_index] | |
| ) | |
| meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens( | |
| state.input_token_ids_logprobs_val, | |
| state.input_token_ids_logprobs_idx, | |
| return_text_in_logprobs, | |
| ) | |
| meta_info["output_token_ids_logprobs"] = ( | |
| self.detokenize_top_logprobs_tokens( | |
| state.output_token_ids_logprobs_val, | |
| state.output_token_ids_logprobs_idx, | |
| return_text_in_logprobs, | |
| ) | |
| ) | |
| def detokenize_logprob_tokens( | |
| self, | |
| token_logprobs_val: List[float], | |
| token_logprobs_idx: List[int], | |
| decode_to_text: bool, | |
| ): | |
| if not decode_to_text: | |
| return [ | |
| (logprob, token_id, None) | |
| for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx) | |
| ] | |
| else: | |
| assert self.tokenizer is not None | |
| token_texts = self.tokenizer.batch_decode(token_logprobs_idx) | |
| return list(zip(token_logprobs_val, token_logprobs_idx, token_texts)) | |
| def detokenize_top_logprobs_tokens( | |
| self, | |
| token_logprobs_val: List[float], | |
| token_logprobs_idx: List[int], | |
| decode_to_text: bool, | |
| ): | |
| # TODO: The current implementation only batches the detokenization for top-k tokens per single position. | |
| # We should batch all top-k tokens in all positions. | |
| ret = [] | |
| for i in range(len(token_logprobs_val)): | |
| if token_logprobs_val[i]: | |
| ret.append( | |
| self.detokenize_logprob_tokens( | |
| token_logprobs_val[i], token_logprobs_idx[i], decode_to_text | |
| ) | |
| ) | |
| else: | |
| ret.append(None) | |
| return ret | |
| def _calculate_spec_decoding_metrics( | |
| self, | |
| meta_info: Dict[str, Any], | |
| recv_obj: Union[ | |
| BatchStrOutput, | |
| BatchEmbeddingOutput, | |
| BatchMultimodalOutput, | |
| BatchTokenIDOutput, | |
| ], | |
| i: int, | |
| ) -> None: | |
| """Calculate speculative decoding metrics, such as acceptance rate and acceptance length metrics.""" | |
| meta_info["spec_accept_rate"] = 0.0 | |
| meta_info["spec_accept_length"] = 0 | |
| meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i] | |
| if ( | |
| recv_obj.spec_verify_ct[i] > 0 | |
| and self.server_args.speculative_num_steps is not None | |
| and not isinstance(recv_obj, BatchEmbeddingOutput) | |
| and hasattr(recv_obj, "spec_accepted_tokens") | |
| # Checks that `spec_accepted_tokens[i]` will exist. | |
| and len(recv_obj.spec_accepted_tokens) > i | |
| ): | |
| total_draft_tokens = ( | |
| recv_obj.spec_verify_ct[i] * self.server_args.speculative_num_steps | |
| ) | |
| accepted_tokens = recv_obj.spec_accepted_tokens[i] | |
| # Calculate per-request acceptance rate and average acceptance length. | |
| if total_draft_tokens > 0: | |
| # Calculate acceptance rate: accepted / (steps * lookahead) | |
| meta_info["spec_accept_rate"] = accepted_tokens / total_draft_tokens | |
| meta_info["spec_accept_length"] = ( | |
| recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i] | |
| ) | |
| def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int): | |
| completion_tokens = ( | |
| recv_obj.completion_tokens[i] | |
| if getattr(recv_obj, "completion_tokens", None) | |
| else 0 | |
| ) | |
| custom_labels = getattr(state.obj, "custom_labels", None) | |
| labels = ( | |
| {**self.metrics_collector.labels, **custom_labels} | |
| if custom_labels | |
| else self.metrics_collector.labels | |
| ) | |
| if ( | |
| state.first_token_time == 0.0 | |
| and self.disaggregation_mode != DisaggregationMode.PREFILL | |
| ): | |
| state.first_token_time = state.last_time = time.time() | |
| state.last_completion_tokens = completion_tokens | |
| self.metrics_collector.observe_time_to_first_token( | |
| labels, state.first_token_time - state.created_time | |
| ) | |
| else: | |
| num_new_tokens = completion_tokens - state.last_completion_tokens | |
| if num_new_tokens: | |
| new_time = time.time() | |
| interval = new_time - state.last_time | |
| self.metrics_collector.observe_inter_token_latency( | |
| labels, | |
| interval, | |
| num_new_tokens, | |
| ) | |
| state.last_time = new_time | |
| state.last_completion_tokens = completion_tokens | |
| if state.finished: | |
| has_grammar = ( | |
| state.obj.sampling_params.get("json_schema", None) | |
| or state.obj.sampling_params.get("regex", None) | |
| or state.obj.sampling_params.get("ebnf", None) | |
| or state.obj.sampling_params.get("structural_tag", None) | |
| ) | |
| self.metrics_collector.observe_one_finished_request( | |
| labels, | |
| recv_obj.prompt_tokens[i], | |
| completion_tokens, | |
| recv_obj.cached_tokens[i], | |
| state.finished_time - state.created_time, | |
| has_grammar, | |
| ) | |
| def dump_requests(self, state: ReqState, out_dict: dict): | |
| self.dump_request_list.append( | |
| (state.obj, out_dict, state.created_time, time.time()) | |
| ) | |
| if len(self.dump_request_list) >= self.dump_requests_threshold: | |
| filename = os.path.join( | |
| self.dump_requests_folder, | |
| datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl", | |
| ) | |
| self._dump_data_to_file( | |
| data_list=self.dump_request_list, | |
| filename=filename, | |
| log_message=f"Dump {len(self.dump_request_list)} requests to {filename}", | |
| ) | |
| self.dump_request_list = [] | |
| def record_request_for_crash_dump(self, state: ReqState, out_dict: dict): | |
| current_time = time.time() | |
| self.crash_dump_request_list.append( | |
| (state.obj, out_dict, state.created_time, current_time) | |
| ) | |
| # Remove requests older than 5 minutes based on finish time | |
| while ( | |
| self.crash_dump_request_list | |
| and current_time - self.crash_dump_request_list[0][3] >= 300 | |
| ): | |
| self.crash_dump_request_list.popleft() | |
def _dump_data_to_file(
    self, data_list: List[Tuple], filename: str, log_message: str
):
    """Pickle ``data_list`` together with the server args to ``filename`` off-loop.

    The payload is ``{"server_args": self.server_args, "requests": <copy of
    data_list>}``. The actual file write runs in a worker thread through
    ``asyncio.to_thread``, so this must be called while an event loop is
    running.

    Args:
        data_list: Records to dump. It is copied immediately, so the caller
            may reset or mutate it right after this returns.
        filename: Destination path. Missing parent directories are created.
        log_message: Message logged at INFO level before scheduling the write.
    """
    logger.info(log_message)
    to_dump_with_server_args = {
        "server_args": self.server_args,
        "requests": data_list.copy(),
    }

    def background_task():
        dirname = os.path.dirname(filename)
        # os.makedirs("") raises, so only create a directory when the path has one.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(filename, "wb") as f:
            pickle.dump(to_dump_with_server_args, f)

    task = asyncio.create_task(asyncio.to_thread(background_task))

    # The event loop keeps only weak references to tasks. Hold a strong
    # reference until the write finishes so the task cannot be
    # garbage-collected mid-execution.
    pending = getattr(self, "_pending_dump_tasks", None)
    if pending is None:
        pending = set()
        self._pending_dump_tasks = pending
    pending.add(task)

    def _on_done(done_task):
        pending.discard(done_task)
        # Surface write failures instead of "Task exception was never retrieved".
        if not done_task.cancelled() and done_task.exception() is not None:
            logger.error(
                f"Failed to dump data to {filename}: {done_task.exception()!r}"
            )

    task.add_done_callback(_on_done)
def _handle_abort_req(self, recv_obj: AbortReq):
    """Deliver an abort as the final output of the matching request.

    Marks the request state as finished, appends a terminal output dict to its
    ``out_list``, and sets its ``event`` so whoever awaits it is woken.
    Health-check requests and rids with no registered state are ignored.

    Args:
        recv_obj: The abort message. ``rid`` identifies the request;
            ``finished_reason`` (optional) is forwarded as-is when present.
    """
    if is_health_check_generate_req(recv_obj):
        return
    state = self.rid_to_state.get(recv_obj.rid)
    if state is None:
        # The state may already be gone (presumably the request completed or
        # was cleaned up before this abort arrived). Indexing would raise
        # KeyError out of the output-handling path, so treat the late abort
        # as a no-op.
        return
    state.finished = True
    if recv_obj.finished_reason:
        # The sender supplied a finish reason: forward it unchanged.
        out = {
            "meta_info": {
                "id": recv_obj.rid,
                "finish_reason": recv_obj.finished_reason,
            },
        }
    else:
        # No reason given: report a generic abort with zero token counts.
        out = {
            "text": "",
            "meta_info": {
                "id": recv_obj.rid,
                "finish_reason": {
                    "type": "abort",
                    "message": "Abort before prefill",
                },
                "prompt_tokens": 0,
                "completion_tokens": 0,
            },
        }
    state.out_list.append(out)
    state.event.set()
def _handle_open_session_req_output(self, recv_obj):
    """Resolve the pending open-session future for ``recv_obj.session_id``.

    The future receives the session id on success and ``None`` on failure.
    """
    pending_future = self.session_futures[recv_obj.session_id]
    if recv_obj.success:
        outcome = recv_obj.session_id
    else:
        outcome = None
    pending_future.set_result(outcome)
def _handle_update_weights_from_disk_req_output(self, recv_obj):
    """Resolve ``self.model_update_result`` with the weight-update reply.

    With ``dp_size == 1`` the single reply is the result. With several
    data-parallel ranks, replies accumulate in ``self.model_update_tmp``, and
    the future is resolved with that list once ``dp_size`` replies are in.
    """
    dp_size = self.server_args.dp_size
    if dp_size != 1:
        collected = self.model_update_tmp
        collected.append(recv_obj)
        if len(collected) == dp_size:
            self.model_update_result.set_result(collected)
        return
    self.model_update_result.set_result(recv_obj)
def _initialize_multi_item_delimiter_text(self):
    """Decode the multi-item scoring delimiter token into its text form.

    Runs only when ``server_args.multi_item_scoring_delimiter`` exists and is
    not None and a tokenizer is loaded. Otherwise nothing is assigned. On
    success, ``self.multi_item_delimiter_text`` holds the decoded text
    (special tokens kept). If decoding raises, a warning is logged and the
    attribute is set to None.
    """
    delimiter_id = getattr(self.server_args, "multi_item_scoring_delimiter", None)
    if delimiter_id is None or self.tokenizer is None:
        return
    try:
        decoded = self.tokenizer.decode([delimiter_id], skip_special_tokens=False)
    except Exception as e:
        logger.warning(
            f"Failed to decode delimiter token {delimiter_id}: {e}"
        )
        self.multi_item_delimiter_text = None
    else:
        self.multi_item_delimiter_text = decoded
def _build_multi_item_token_sequence(
    self, query: List[int], items: List[List[int]], delimiter_token_id: int
) -> List[int]:
    """
    Join a query and several items into one token sequence for multi-item scoring.

    Layout: query<delimiter>item1<delimiter>item2<delimiter>...itemN<delimiter>
    A delimiter precedes every item and one more closes the sequence, so the
    result holds len(items) + 1 delimiters. The inputs are not modified.

    Args:
        query: Query token IDs.
        items: Token ID sequences, one per item.
        delimiter_token_id: Token ID inserted between segments.

    Returns:
        A new list with the combined token sequence.
    """
    sequence = list(query)
    for item_tokens in items:
        sequence += [delimiter_token_id, *item_tokens]
    # Closing delimiter: its position's logprobs are what score the last item.
    sequence.append(delimiter_token_id)
    return sequence
def _extract_logprobs_for_tokens(
    self, logprobs_data: List, label_token_ids: List[int]
) -> Dict[int, float]:
    """
    Pick out the logprobs of the requested token IDs.

    Args:
        logprobs_data: Sequence of (logprob, token_id, text) triples; may be
            empty or None.
        label_token_ids: Token IDs whose logprobs should be kept.

    Returns:
        Mapping token_id -> logprob for the requested IDs present in the data.
        If a token ID appears more than once, the last occurrence wins.
    """
    if not logprobs_data:
        return {}
    wanted = set(label_token_ids)
    return {
        token_id: logprob
        for logprob, token_id, _ in logprobs_data
        if token_id in wanted
    }
def _convert_logprobs_to_scores(
    self,
    logprobs: Dict[int, float],
    label_token_ids: List[int],
    apply_softmax: bool,
) -> List[float]:
    """
    Turn a token_id -> logprob mapping into scores ordered like label_token_ids.

    Token IDs missing from ``logprobs`` are treated as logprob -inf.

    Args:
        logprobs: Mapping token_id -> logprob.
        label_token_ids: Output order of the scores.
        apply_softmax: If True, softmax-normalize the logprobs across the
            labels. Otherwise each logprob is exponentiated independently,
            and a missing or -inf entry becomes 0.0.

    Returns:
        One float per entry of label_token_ids.
    """
    ordered = [logprobs.get(tid, -math.inf) for tid in label_token_ids]
    if not apply_softmax:
        return [0.0 if lp == -math.inf else math.exp(lp) for lp in ordered]
    return torch.softmax(torch.tensor(ordered), dim=0).tolist()
def _process_multi_item_scoring_results(
    self,
    results: Any,
    items: List,
    label_token_ids: List[int],
    apply_softmax: bool,
    batch_request=None,
) -> List[List[float]]:
    """
    Convert the result of a multi-item scoring request into per-item scores.

    The logprobs come from ``meta_info["input_token_ids_logprobs"]``, which
    must hold exactly one entry per delimiter (number of items + 1). The first
    entry (the query/first-item boundary) is skipped; the next ones, in order,
    score each item.

    Args:
        results: generate_request output, either a list (first element used)
            or a single result dict.
        items: The scored items; anything that is not a list counts as one item.
        label_token_ids: Token IDs to score.
        apply_softmax: Passed through to _convert_logprobs_to_scores.
        batch_request: Unused; kept for interface compatibility.

    Returns:
        One score list per item.

    Raises:
        RuntimeError: If the logprobs are missing or their count is wrong.
    """
    result = results[0] if isinstance(results, list) else results
    meta_info = result["meta_info"]
    request_id = meta_info.get("id", "<unknown>")
    delimiter_logprobs = meta_info.get("input_token_ids_logprobs", [])

    if not delimiter_logprobs:
        raise RuntimeError(
            f"input_token_ids_logprobs is empty for multi-item scoring request {request_id}. "
            "This indicates token_ids_logprobs were not computed properly for Mutil Item Scoring."
        )

    item_count = len(items) if isinstance(items, list) else 1
    if len(delimiter_logprobs) != item_count + 1:
        raise RuntimeError(
            f"Expected {item_count + 1} input_token_ids_logprobs for multi-item scoring "
            f"with {item_count} items, but got {len(delimiter_logprobs)}. "
            f"Request ID: {request_id}"
        )

    # Skip the query/first-item boundary unless it is the only entry.
    first = 1 if len(delimiter_logprobs) > 1 else 0
    return [
        self._convert_logprobs_to_scores(
            self._extract_logprobs_for_tokens(entry, label_token_ids),
            label_token_ids,
            apply_softmax,
        )
        for entry in delimiter_logprobs[first : first + item_count]
    ]
def _process_single_item_scoring_results(
    self, results: Any, label_token_ids: List[int], apply_softmax: bool
) -> List[List[float]]:
    """
    Convert per-pair generation results into score lists.

    Every result corresponds to one query/item pair. Its scores are read from
    the first entry of ``meta_info["output_token_ids_logprobs"]``.

    Args:
        results: Iterable of result dicts from generate_request.
        label_token_ids: Token IDs to score, in output order.
        apply_softmax: Passed through to _convert_logprobs_to_scores.

    Returns:
        One score list per result, in input order.

    Raises:
        RuntimeError: If a result has no output_token_ids_logprobs.
    """
    scores: List[List[float]] = []
    for result in results:
        meta_info = result["meta_info"]
        output_logprobs = meta_info.get("output_token_ids_logprobs", [])
        if not output_logprobs:
            raise RuntimeError(
                f"output_logprobs is empty for request {meta_info.get('id', '<unknown>')}."
            )
        # Only the first (and only) generated position carries the label scores.
        per_token = self._extract_logprobs_for_tokens(
            output_logprobs[0], label_token_ids
        )
        scores.append(
            self._convert_logprobs_to_scores(
                per_token, label_token_ids, apply_softmax
            )
        )
    return scores
async def score_request(
    self,
    query: Optional[Union[str, List[int]]] = None,
    items: Optional[Union[str, List[str], List[List[int]]]] = None,
    label_token_ids: Optional[List[int]] = None,
    apply_softmax: bool = False,
    item_first: bool = False,
    request: Optional[Any] = None,
) -> List[List[float]]:
    """
    Score the probability of specified token IDs appearing after the given (query + item) pair.

    This method supports two scoring approaches:
    1. Single-Item scoring (default): each query+item pair is processed independently.
    2. Multi-Item scoring: when multi_item_scoring_delimiter is set (and its text
       form was decoded), the query and all items are combined into a single
       sequence separated by the delimiter. item_first is ignored in this mode,
       which always uses: query<delimiter>item1<delimiter>item2<delimiter>...<delimiter>

    Multi-item scoring works with both text and pre-tokenized inputs:
    - Text: query<delimiter_text>item1<delimiter_text>...itemN<delimiter_text>
    - Tokens: query<delimiter_token_id>item1<delimiter_token_id>...itemN<delimiter_token_id>

    Args:
        query: The query text or pre-tokenized query token IDs.
        items: The item text(s) or pre-tokenized item token IDs.
        label_token_ids: Token IDs to compute probabilities for.
        apply_softmax: Whether to normalize the probabilities with softmax.
        item_first: If True, prepend items to the query (single-item mode only).
        request: Optional FastAPI request object.

    Returns:
        A list with one list of label-token probabilities per item. An empty
        ``items`` list returns ``[]`` without issuing a generation request.

    Raises:
        ValueError: If label_token_ids is None or contains a negative ID or an
            ID >= the tokenizer's vocab_size, or if query/items are an
            unsupported combination of types.
        RuntimeError: Propagated from result processing when the expected
            logprobs are missing.
    """
    if label_token_ids is None:
        raise ValueError("label_token_ids must be provided")

    # Negative IDs are never valid vocabulary indices; reject them even when
    # no tokenizer is loaded (the vocab-size check below is skipped then).
    for token_id in label_token_ids:
        if token_id < 0:
            raise ValueError(f"Token ID {token_id} must be non-negative")

    if self.tokenizer is not None:
        # NOTE(review): for HF tokenizers `vocab_size` excludes added tokens,
        # so added special-token IDs are rejected here — confirm intended.
        vocab_size = self.tokenizer.vocab_size
        for token_id in label_token_ids:
            if token_id >= vocab_size:
                raise ValueError(
                    f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                )

    # Zero items means zero score lists. Without this, text input sends an
    # empty batch (or, in multi-item mode, a prompt with a dangling extra
    # delimiter), and token input falls through to the "Invalid combination" error.
    if isinstance(items, list) and not items:
        return []

    # Multi-item scoring is enabled by the presence of a (decodable) delimiter.
    use_multi_item_scoring = (
        self.server_args.multi_item_scoring_delimiter is not None
        and self.multi_item_delimiter_text is not None
    )

    batch_request = GenerateReqInput(
        token_ids_logprob=label_token_ids,
        return_logprob=True,
        # Multi-item scoring needs logprobs at every delimiter position of the
        # input, so logprobs start at position 0.
        logprob_start_len=0 if use_multi_item_scoring else -1,
        stream=False,
        sampling_params={"max_new_tokens": 0},
    )

    if isinstance(query, str) and (
        isinstance(items, str) or (isinstance(items, list) and isinstance(items[0], str))
    ):
        # Both query and items are text.
        items_list = [items] if isinstance(items, str) else items
        if use_multi_item_scoring:
            # One prompt: query<d>item1<d>...itemN<d> (item_first ignored);
            # the final delimiter is needed to extract the last item's logprobs.
            delimiter = self.multi_item_delimiter_text
            combined_items = delimiter.join(items_list)
            single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
            batch_request.text = [single_prompt]
        else:
            # One prompt per item.
            if item_first:
                prompts = [f"{item}{query}" for item in items_list]
            else:
                prompts = [f"{query}{item}" for item in items_list]
            batch_request.text = prompts
    elif (
        isinstance(query, list)
        and isinstance(items, list)
        and isinstance(items[0], list)
    ):
        # Both query and items are token IDs.
        if use_multi_item_scoring:
            delimiter_token_id = self.server_args.multi_item_scoring_delimiter
            combined_input_ids = self._build_multi_item_token_sequence(
                query, items, delimiter_token_id
            )
            batch_request.input_ids = [combined_input_ids]
        else:
            if item_first:
                input_ids_list = [item + query for item in items]
            else:
                input_ids_list = [query + item for item in items]
            batch_request.input_ids = input_ids_list
    else:
        raise ValueError(
            "Invalid combination of query/items types for score_request."
        )

    results = await self.generate_request(batch_request, request).__anext__()

    if use_multi_item_scoring:
        # Scores come from input_token_ids_logprobs at the delimiter positions.
        return self._process_multi_item_scoring_results(
            results, items, label_token_ids, apply_softmax, batch_request
        )
    # Scores come from each result's first output position.
    return self._process_single_item_scoring_results(
        results, label_token_ids, apply_softmax
    )
async def watch_load_thread(self):
    """Periodically push the current loads to the scheduler.

    Only needed by the data-parallel controller: it returns immediately when
    ``dp_size == 1`` or the load-balance method is ``"round_robin"``.
    Otherwise it loops forever. Every ``load_watch_interval`` seconds it
    fetches the loads through ``get_load_communicator`` and sends them as a
    ``WatchLoadUpdateReq`` over ``send_to_scheduler``.
    """
    args = self.server_args
    if args.dp_size == 1 or args.load_balance_method == "round_robin":
        return

    while True:
        # Re-read the interval each round rather than caching it.
        await asyncio.sleep(self.server_args.load_watch_interval)
        current_loads = await self.get_load_communicator(GetLoadReqInput())
        self.send_to_scheduler.send_pyobj(WatchLoadUpdateReq(loads=current_loads))
def _trace_request_start(
    self,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    created_time: Optional[float] = None,
):
    """Open the trace for each rid in ``obj`` and start its first anonymous slice.

    Args:
        obj: A single or batched request. ``obj.rid`` is a single rid when
            ``obj.is_single`` is true and a list of rids otherwise. When
            ``bootstrap_room`` exists (and, for batches, is non-empty), it is
            passed along per rid.
        created_time: Request creation time in seconds, converted to integer
            nanoseconds for the trace. ``None`` (the default) means "now";
            without this fallback, ``None * 1e9`` would raise TypeError.
    """
    if created_time is None:
        created_time = time.time()
    ts = int(created_time * 1e9)

    if obj.is_single:
        bootstrap_room = (
            obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
        )
        trace_req_start(obj.rid, bootstrap_room, ts=ts)
        trace_slice_start("", obj.rid, ts=ts, anonymous=True)
        return

    # Batched request: a present, non-empty bootstrap_room is indexed per rid.
    bootstrap_rooms = (
        obj.bootstrap_room
        if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
        else None
    )
    for i in range(len(obj.rid)):
        bootstrap_room = bootstrap_rooms[i] if bootstrap_rooms else None
        trace_req_start(obj.rid[i], bootstrap_room, ts=ts)
        trace_slice_start("", obj.rid[i], ts=ts, anonymous=True)
class ServerStatus(Enum):
    """Server status; every member's value is the same string as its name."""

    # Definition order is kept: Enum iteration follows it.
    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
    """Choose how tensors are transported between processes.

    A truthy ``server_args.dist_init_addr`` marks a multi-node deployment,
    which falls back to the default CPU transport (``"default"``). Otherwise
    ``"cuda_ipc"`` is used.
    """
    multi_node = bool(server_args.dist_init_addr)
    return "default" if multi_node else "cuda_ipc"
async def print_exception_wrapper(func):
    """
    Await ``func()`` and turn any uncaught exception into a logged, fatal shutdown.

    Some asyncio coroutines fail without printing their exception, so this
    wrapper logs the traceback explicitly. If ``func`` is a bound method of a
    TokenizerManager, that manager's request dump is written first. Then the
    whole process tree (including this process) is killed and ``sys.exit(1)``
    is called.
    """
    try:
        await func()
    except Exception:
        tb_text = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {tb_text}")
        owner = getattr(func, "__self__", None)
        if isinstance(owner, TokenizerManager):
            owner.dump_requests_before_crash()
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
class SignalHandler:
    """Signal callbacks that act on a single TokenizerManager."""

    def __init__(self, tokenizer_manager: TokenizerManager):
        self.tokenizer_manager = tokenizer_manager

    def sigterm_handler(self, signum=None, frame=None):
        """On SIGTERM, log a warning and set the manager's ``gracefully_exit`` flag."""
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        manager = self.tokenizer_manager
        manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """On SIGQUIT, log an error, dump the manager's requests, and kill the process tree."""
        logger.error(
            f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
        )
        manager = self.tokenizer_manager
        manager.dump_requests_before_crash()
        kill_process_tree(os.getpid())
| # Note: request abort handling logic | |
| # We should handle all of the following cases correctly. | |
| # | |
| # | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state | | |
| # | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- | | |
| # | http | yes | validation | background task | fast api | del in _handle_abort_req | | |
| # | http | yes | waiting queue | background task | fast api | del in _handle_abort_req | | |
| # | http | yes | running | background task | fast api | del in _handle_batch_output | | |
| # | http | no | validation | http exception | http exception | del in _handle_abort_req | | |
| # | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req | | |
| # | http | no | running | type 3 | type 3 exception | del in _handle_batch_output | | |
| # | |
Xet Storage Details
- Size:
- 93 kB
- Xet hash:
- 18e1984094cfca22f7ba85cafae7d0d7fefbdaea3d2074ebf69b4b06530bbba4
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.