| # Copyright 2023-2024 SGLang Team | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """ | |
| The entry point of inference server. (SRT = SGLang Runtime) | |
| This file implements python APIs for the inference engine. | |
| """ | |
| import asyncio | |
| import atexit | |
| import dataclasses | |
| import logging | |
| import multiprocessing as mp | |
| import os | |
| import random | |
| import signal | |
| import threading | |
| import time | |
| from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union | |
| import zmq | |
| from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info | |
| # Fix a bug of Python threading | |
| setattr(threading, "_register_atexit", lambda *args, **kwargs: None) | |
| import torch | |
| import uvloop | |
| from sglang.srt.entrypoints.EngineBase import EngineBase | |
| from sglang.srt.managers.data_parallel_controller import ( | |
| run_data_parallel_controller_process, | |
| ) | |
| from sglang.srt.managers.detokenizer_manager import run_detokenizer_process | |
| from sglang.srt.managers.io_struct import ( | |
| DestroyWeightsUpdateGroupReqInput, | |
| EmbeddingReqInput, | |
| GenerateReqInput, | |
| GetWeightsByNameReqInput, | |
| InitWeightsUpdateGroupReqInput, | |
| LoadLoRAAdapterReqInput, | |
| MultimodalDataInputFormat, | |
| ReleaseMemoryOccupationReqInput, | |
| ResumeMemoryOccupationReqInput, | |
| RpcReqInput, | |
| RpcReqOutput, | |
| UnloadLoRAAdapterReqInput, | |
| UpdateWeightFromDiskReqInput, | |
| UpdateWeightsFromDistributedReqInput, | |
| UpdateWeightsFromIPCReqInput, | |
| UpdateWeightsFromTensorReqInput, | |
| ) | |
| from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter | |
| from sglang.srt.managers.scheduler import run_scheduler_process | |
| from sglang.srt.managers.template_manager import TemplateManager | |
| from sglang.srt.managers.tokenizer_manager import TokenizerManager | |
| from sglang.srt.server_args import PortArgs, ServerArgs | |
| from sglang.srt.utils import ( | |
| MultiprocessingSerializer, | |
| assert_pkg_version, | |
| configure_logger, | |
| get_bool_env_var, | |
| get_zmq_socket, | |
| is_cuda, | |
| kill_process_tree, | |
| launch_dummy_health_check_server, | |
| maybe_reindex_device_id, | |
| prepare_model_and_tokenizer, | |
| set_prometheus_multiproc_dir, | |
| set_ulimit, | |
| ) | |
| from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter | |
| from sglang.version import __version__ | |
| logger = logging.getLogger(__name__) | |
| asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) | |
| _is_cuda = is_cuda() | |
| class Engine(EngineBase): | |
| """ | |
| The entry point to the inference engine. | |
| - The engine consists of three components: | |
| 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. | |
| 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. | |
| 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. | |
| Note: | |
| 1. The HTTP server, Engine, and TokenizerManager all run in the main process. | |
| 2. Inter-process communication (IPC) is handled via the ZMQ library, with each process using a different port. | |
| """ | |
| def __init__(self, **kwargs): | |
| """ | |
| The arguments of this function is the same as `sglang/srt/server_args.py::ServerArgs`. | |
| Please refer to `ServerArgs` for the documentation. | |
| """ | |
| if "server_args" in kwargs: | |
| # Directly load server_args | |
| server_args = kwargs["server_args"] | |
| else: | |
| # Construct server_args from kwargs | |
| if "log_level" not in kwargs: | |
| # Do not print logs by default | |
| kwargs["log_level"] = "error" | |
| server_args = ServerArgs(**kwargs) | |
| # Shutdown the subprocesses automatically when the program exits | |
| atexit.register(self.shutdown) | |
| # Allocate ports for inter-process communications | |
| self.port_args = PortArgs.init_new(server_args) | |
| logger.info(f"{server_args=}") | |
| # Launch subprocesses | |
| tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses( | |
| server_args=server_args, | |
| port_args=self.port_args, | |
| ) | |
| self.server_args = server_args | |
| self.tokenizer_manager = tokenizer_manager | |
| self.template_manager = template_manager | |
| self.scheduler_info = scheduler_info | |
| context = zmq.Context(2) | |
| self.send_to_rpc = get_zmq_socket( | |
| context, zmq.DEALER, self.port_args.rpc_ipc_name, True | |
| ) | |
| if server_args.enable_trace: | |
| process_tracing_init(server_args.oltp_traces_endpoint, "sglang") | |
| if server_args.disaggregation_mode == "null": | |
| thread_label = "Tokenizer" | |
| trace_set_thread_info(thread_label) | |
| try: | |
| self.loop = asyncio.get_running_loop() | |
| except RuntimeError: | |
| self.loop = asyncio.new_event_loop() | |
| asyncio.set_event_loop(self.loop) | |
| def generate( | |
| self, | |
| # The input prompt. It can be a single prompt or a batch of prompts. | |
| prompt: Optional[Union[List[str], str]] = None, | |
| sampling_params: Optional[Union[List[Dict], Dict]] = None, | |
| # The token ids for text; one can either specify text or input_ids. | |
| input_ids: Optional[Union[List[List[int]], List[int]]] = None, | |
| # The image input. It can be an image instance, file name, URL, or base64 encoded string. | |
| # Can be formatted as: | |
| # - Single image for a single request | |
| # - List of images (one per request in a batch) | |
| # - List of lists of images (multiple images per request) | |
| # See also python/sglang/srt/utils.py:load_image for more details. | |
| image_data: Optional[MultimodalDataInputFormat] = None, | |
| audio_data: Optional[MultimodalDataInputFormat] = None, | |
| video_data: Optional[MultimodalDataInputFormat] = None, | |
| return_logprob: Optional[Union[List[bool], bool]] = False, | |
| logprob_start_len: Optional[Union[List[int], int]] = None, | |
| top_logprobs_num: Optional[Union[List[int], int]] = None, | |
| token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, | |
| lora_path: Optional[List[Optional[str]]] = None, | |
| custom_logit_processor: Optional[Union[List[str], str]] = None, | |
| return_hidden_states: bool = False, | |
| stream: bool = False, | |
| bootstrap_host: Optional[Union[List[str], str]] = None, | |
| bootstrap_port: Optional[Union[List[int], int]] = None, | |
| bootstrap_room: Optional[Union[List[int], int]] = None, | |
| data_parallel_rank: Optional[int] = None, | |
| # blend additional parameters can be added here | |
| blend_style: Optional[str] = None, | |
| method: Optional[str] = "attn", | |
| separator: Optional[str] = "<|blendsep|>", | |
| start: Optional[int] = 0, | |
| ratio: Optional[float] = 0.3, | |
| attn_start: Optional[int] = 0, | |
| attn_end: Optional[int] = -1, | |
| is_contextblend: Optional[bool] = False, | |
| context_cache_source: Optional[str] = "query", | |
| context_n_sink: Optional[int] = None, | |
| digest_ratio: Optional[float] = 0.3, | |
| digest_index_method: Optional[str] = "kvzip", | |
| critical_layers: Optional[List[int]] = None, | |
| # SSD pipeline parameters | |
| ssd_cache_path_chunk: Optional[str] = None, | |
| ssd_cache_path_query: Optional[str] = None, | |
| ) -> Union[Dict, Iterator[Dict]]: | |
| """ | |
| The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. | |
| Please refer to `GenerateReqInput` for the documentation. | |
| """ | |
| if self.server_args.enable_dp_attention: | |
| if data_parallel_rank is None: | |
| logger.debug("data_parallel_rank not provided, using default dispatch") | |
| elif data_parallel_rank < 0: | |
| raise ValueError("data_parallel_rank must be non-negative") | |
| elif data_parallel_rank >= self.server_args.dp_size: | |
| raise ValueError( | |
| f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}" | |
| ) | |
| obj = GenerateReqInput( | |
| text=prompt, | |
| input_ids=input_ids, | |
| sampling_params=sampling_params, | |
| image_data=image_data, | |
| audio_data=audio_data, | |
| video_data=video_data, | |
| return_logprob=return_logprob, | |
| logprob_start_len=logprob_start_len, | |
| top_logprobs_num=top_logprobs_num, | |
| token_ids_logprob=token_ids_logprob, | |
| lora_path=lora_path, | |
| custom_logit_processor=custom_logit_processor, | |
| return_hidden_states=return_hidden_states, | |
| stream=stream, | |
| bootstrap_host=bootstrap_host, | |
| bootstrap_port=bootstrap_port, | |
| bootstrap_room=bootstrap_room, | |
| data_parallel_rank=data_parallel_rank, | |
| blend_style=blend_style, | |
| method=method, | |
| separator=separator, | |
| start=start, | |
| ratio=ratio, | |
| attn_start=attn_start, | |
| attn_end=attn_end, | |
| is_contextblend=is_contextblend, | |
| context_cache_source=context_cache_source, | |
| context_n_sink=context_n_sink, | |
| digest_ratio=digest_ratio, | |
| digest_index_method=digest_index_method, | |
| critical_layers=critical_layers, | |
| ssd_cache_path_chunk=ssd_cache_path_chunk, | |
| ssd_cache_path_query=ssd_cache_path_query, | |
| ) | |
| generator = self.tokenizer_manager.generate_request(obj, None) | |
| if stream: | |
| def generator_wrapper(): | |
| while True: | |
| try: | |
| chunk = self.loop.run_until_complete(generator.__anext__()) | |
| yield chunk | |
| except StopAsyncIteration: | |
| break | |
| return generator_wrapper() | |
| else: | |
| ret = self.loop.run_until_complete(generator.__anext__()) | |
| return ret | |
| async def async_generate( | |
| self, | |
| # The input prompt. It can be a single prompt or a batch of prompts. | |
| prompt: Optional[Union[List[str], str]] = None, | |
| sampling_params: Optional[Union[List[Dict], Dict]] = None, | |
| # The token ids for text; one can either specify text or input_ids. | |
| input_ids: Optional[Union[List[List[int]], List[int]]] = None, | |
| # The image input. It can be an image instance, file name, URL, or base64 encoded string. | |
| # Can be formatted as: | |
| # - Single image for a single request | |
| # - List of images (one per request in a batch) | |
| # - List of lists of images (multiple images per request) | |
| # See also python/sglang/srt/utils.py:load_image for more details. | |
| image_data: Optional[MultimodalDataInputFormat] = None, | |
| audio_data: Optional[MultimodalDataInputFormat] = None, | |
| video_data: Optional[MultimodalDataInputFormat] = None, | |
| return_logprob: Optional[Union[List[bool], bool]] = False, | |
| logprob_start_len: Optional[Union[List[int], int]] = None, | |
| top_logprobs_num: Optional[Union[List[int], int]] = None, | |
| token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, | |
| lora_path: Optional[List[Optional[str]]] = None, | |
| custom_logit_processor: Optional[Union[List[str], str]] = None, | |
| return_hidden_states: bool = False, | |
| stream: bool = False, | |
| bootstrap_host: Optional[Union[List[str], str]] = None, | |
| bootstrap_port: Optional[Union[List[int], int]] = None, | |
| bootstrap_room: Optional[Union[List[int], int]] = None, | |
| data_parallel_rank: Optional[int] = None, | |
| ) -> Union[Dict, AsyncIterator[Dict]]: | |
| """ | |
| The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. | |
| Please refer to `GenerateReqInput` for the documentation. | |
| """ | |
| if self.server_args.enable_dp_attention: | |
| if data_parallel_rank is None: | |
| logger.debug("data_parallel_rank not provided, using default dispatch") | |
| elif data_parallel_rank < 0: | |
| raise ValueError("data_parallel_rank must be non-negative") | |
| elif data_parallel_rank >= self.server_args.dp_size: | |
| raise ValueError( | |
| f"data_parallel_rank must be in range [0, {self.server_args.dp_size - 1}]" | |
| ) | |
| logger.debug(f"data_parallel_rank: {data_parallel_rank}") | |
| obj = GenerateReqInput( | |
| text=prompt, | |
| input_ids=input_ids, | |
| sampling_params=sampling_params, | |
| image_data=image_data, | |
| audio_data=audio_data, | |
| video_data=video_data, | |
| return_logprob=return_logprob, | |
| logprob_start_len=logprob_start_len, | |
| top_logprobs_num=top_logprobs_num, | |
| token_ids_logprob=token_ids_logprob, | |
| lora_path=lora_path, | |
| return_hidden_states=return_hidden_states, | |
| stream=stream, | |
| custom_logit_processor=custom_logit_processor, | |
| bootstrap_host=bootstrap_host, | |
| bootstrap_port=bootstrap_port, | |
| bootstrap_room=bootstrap_room, | |
| data_parallel_rank=data_parallel_rank, | |
| ) | |
| generator = self.tokenizer_manager.generate_request(obj, None) | |
| if stream is True: | |
| return generator | |
| else: | |
| return await generator.__anext__() | |
| def encode( | |
| self, | |
| prompt: Union[str, List[str], List[Dict], List[List[Dict]]], | |
| image_data: Optional[MultimodalDataInputFormat] = None, | |
| audio_data: Optional[MultimodalDataInputFormat] = None, | |
| video_data: Optional[MultimodalDataInputFormat] = None, | |
| ) -> Dict: | |
| """ | |
| The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. | |
| Please refer to `EmbeddingReqInput` for the documentation. | |
| """ | |
| obj = EmbeddingReqInput( | |
| text=prompt, | |
| image_data=image_data, | |
| audio_data=audio_data, | |
| video_data=video_data, | |
| ) | |
| generator = self.tokenizer_manager.generate_request(obj, None) | |
| ret = self.loop.run_until_complete(generator.__anext__()) | |
| return ret | |
| async def async_encode( | |
| self, | |
| prompt: Union[str, List[str], List[Dict], List[List[Dict]]], | |
| image_data: Optional[MultimodalDataInputFormat] = None, | |
| audio_data: Optional[MultimodalDataInputFormat] = None, | |
| video_data: Optional[MultimodalDataInputFormat] = None, | |
| ) -> Dict: | |
| """ | |
| Asynchronous version of encode method. | |
| The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. | |
| Please refer to `EmbeddingReqInput` for the documentation. | |
| """ | |
| obj = EmbeddingReqInput( | |
| text=prompt, | |
| image_data=image_data, | |
| audio_data=audio_data, | |
| video_data=video_data, | |
| ) | |
| generator = self.tokenizer_manager.generate_request(obj, None) | |
| return await generator.__anext__() | |
| def rerank( | |
| self, | |
| prompt: Union[List[List[str]]], | |
| ) -> Dict: | |
| """ | |
| The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. | |
| Please refer to `EmbeddingReqInput` for the documentation. | |
| """ | |
| obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True) | |
| generator = self.tokenizer_manager.generate_request(obj, None) | |
| ret = self.loop.run_until_complete(generator.__anext__()) | |
| return ret | |
| def shutdown(self): | |
| """Shutdown the engine""" | |
| kill_process_tree(os.getpid(), include_parent=False) | |
| def __enter__(self): | |
| return self | |
| def __exit__(self, exc_type, exc_value, traceback): | |
| self.shutdown() | |
| return False | |
| def flush_cache(self): | |
| return self.loop.run_until_complete(self.tokenizer_manager.flush_cache()) | |
| def start_profile(self, **kwargs): | |
| self.loop.run_until_complete(self.tokenizer_manager.start_profile(**kwargs)) | |
| def stop_profile(self): | |
| self.loop.run_until_complete(self.tokenizer_manager.stop_profile()) | |
| def start_expert_distribution_record(self): | |
| self.loop.run_until_complete( | |
| self.tokenizer_manager.start_expert_distribution_record() | |
| ) | |
| def stop_expert_distribution_record(self): | |
| self.loop.run_until_complete( | |
| self.tokenizer_manager.stop_expert_distribution_record() | |
| ) | |
| def dump_expert_distribution_record(self): | |
| self.loop.run_until_complete( | |
| self.tokenizer_manager.dump_expert_distribution_record() | |
| ) | |
| def get_server_info(self): | |
| internal_states = self.loop.run_until_complete( | |
| self.tokenizer_manager.get_internal_state() | |
| ) | |
| return { | |
| **dataclasses.asdict(self.tokenizer_manager.server_args), | |
| **self.scheduler_info, | |
| "internal_states": internal_states, | |
| "version": __version__, | |
| } | |
| def init_weights_update_group( | |
| self, | |
| master_address: str, | |
| master_port: int, | |
| rank_offset: int, | |
| world_size: int, | |
| group_name: str, | |
| backend: str = "nccl", | |
| ): | |
| """Initialize parameter update group.""" | |
| obj = InitWeightsUpdateGroupReqInput( | |
| master_address=master_address, | |
| master_port=master_port, | |
| rank_offset=rank_offset, | |
| world_size=world_size, | |
| group_name=group_name, | |
| backend=backend, | |
| ) | |
| return self.loop.run_until_complete( | |
| self.tokenizer_manager.init_weights_update_group(obj, None) | |
| ) | |
| def destroy_weights_update_group( | |
| self, | |
| group_name: str, | |
| ): | |
| """Destroy parameter update group.""" | |
| obj = DestroyWeightsUpdateGroupReqInput( | |
| group_name=group_name, | |
| ) | |
| return self.loop.run_until_complete( | |
| self.tokenizer_manager.destroy_weights_update_group(obj, None) | |
| ) | |
| def update_weights_from_distributed( | |
| self, | |
| names: list[str], | |
| dtypes: list[str], | |
| shapes: list[list[int]], | |
| group_name: str = "weight_update_group", | |
| flush_cache: bool = True, | |
| ): | |
| """Update weights from distributed source.""" | |
| obj = UpdateWeightsFromDistributedReqInput( | |
| names=names, | |
| dtypes=dtypes, | |
| shapes=shapes, | |
| group_name=group_name, | |
| flush_cache=flush_cache, | |
| ) | |
| return self.loop.run_until_complete( | |
| self.tokenizer_manager.update_weights_from_distributed(obj, None) | |
| ) | |
| def update_weights_from_tensor( | |
| self, | |
| named_tensors: List[Tuple[str, torch.Tensor]], | |
| load_format: Optional[str] = None, | |
| flush_cache: bool = True, | |
| ): | |
| """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false | |
| to avoid duplicated cache cleaning operation.""" | |
| if load_format == "flattened_bucket": | |
| serialized_named_tensors = named_tensors | |
| else: | |
| serialized_named_tensors = [ | |
| MultiprocessingSerializer.serialize(named_tensors) | |
| for _ in range(self.server_args.tp_size) | |
| ] | |
| obj = UpdateWeightsFromTensorReqInput( | |
| serialized_named_tensors=serialized_named_tensors, | |
| load_format=load_format, | |
| flush_cache=flush_cache, | |
| ) | |
| return self.loop.run_until_complete( | |
| self.tokenizer_manager.update_weights_from_tensor(obj, None) | |
| ) | |
| def update_weights_from_disk( | |
| self, | |
| model_path: str, | |
| load_format: Optional[str] = None, | |
| ): | |
| """Update the weights from disk inplace without re-launching the engine. | |
| This method allows updating the model weights from disk without restarting | |
| the engine. It can be used to load a different model or update weights with | |
| new training. | |
| """ | |
| obj = UpdateWeightFromDiskReqInput( | |
| model_path=model_path, | |
| load_format=load_format, | |
| ) | |
| return self.loop.run_until_complete( | |
| self.tokenizer_manager.update_weights_from_disk(obj, None) | |
| ) | |
| def get_weights_by_name(self, name: str, truncate_size: int = 100): | |
| """Get weights by parameter name.""" | |
| obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) | |
| return self.loop.run_until_complete( | |
| self.tokenizer_manager.get_weights_by_name(obj, None) | |
| ) | |
| def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False): | |
| """Load a new LoRA adapter without re-launching the engine.""" | |
| obj = LoadLoRAAdapterReqInput( | |
| lora_name=lora_name, | |
| lora_path=lora_path, | |
| pinned=pinned, | |
| ) | |
| return self.loop.run_until_complete( | |
| self.tokenizer_manager.load_lora_adapter(obj, None) | |
| ) | |
| def unload_lora_adapter(self, lora_name: str): | |
| """Unload a LoRA adapter without re-launching the engine.""" | |
| obj = UnloadLoRAAdapterReqInput(lora_name=lora_name) | |
| return self.loop.run_until_complete( | |
| self.tokenizer_manager.unload_lora_adapter(obj, None) | |
| ) | |
| def release_memory_occupation(self, tags: Optional[List[str]] = None): | |
| obj = ReleaseMemoryOccupationReqInput(tags=tags) | |
| return self.loop.run_until_complete( | |
| self.tokenizer_manager.release_memory_occupation(obj, None) | |
| ) | |
| def resume_memory_occupation(self, tags: Optional[List[str]] = None): | |
| obj = ResumeMemoryOccupationReqInput(tags=tags) | |
| return self.loop.run_until_complete( | |
| self.tokenizer_manager.resume_memory_occupation(obj, None) | |
| ) | |
| def freeze_gc(self): | |
| """ | |
| To maintain a high performance server with low latency, we want to reduce the | |
| stalls caused by the garbage collector scanning through a large number of objects. | |
| It is usually helpful to start the server and warm it up with real requests to | |
| initialize many of the long-lived objects that do not need to be garbage collected. | |
| After sufficient warmup, we can call this function to freeze the garbage collector | |
| so that all objects created before this point are considered out of scope for garbage | |
| collection. | |
| """ | |
| self.loop.run_until_complete(self.tokenizer_manager.freeze_gc()) | |
| """ | |
| Execute an RPC call on all scheduler processes. | |
| """ | |
| def collective_rpc(self, method: str, **kwargs): | |
| obj = RpcReqInput(method=method, parameters=kwargs) | |
| self.send_to_rpc.send_pyobj(obj) | |
| recv_req = self.send_to_rpc.recv_pyobj(zmq.BLOCKY) | |
| assert isinstance(recv_req, RpcReqOutput) | |
| assert recv_req.success, recv_req.message | |
| def save_remote_model(self, **kwargs): | |
| self.collective_rpc("save_remote_model", **kwargs) | |
| def save_sharded_model(self, **kwargs): | |
| self.collective_rpc("save_sharded_model", **kwargs) | |
| def score( | |
| self, | |
| query: Optional[Union[str, List[int]]] = None, | |
| items: Optional[Union[str, List[str], List[List[int]]]] = None, | |
| label_token_ids: Optional[List[int]] = None, | |
| apply_softmax: bool = False, | |
| item_first: bool = False, | |
| ) -> List[List[float]]: | |
| """ | |
| Score the probability of specified token IDs appearing after the given (query + item) pair. For example: | |
| query = "<|user|>Is the following city the capital of France? " | |
| items = ["Paris <|assistant|>", "London <|assistant|>", "Berlin <|assistant|>"] | |
| label_token_ids = [2332, 1223] # Token IDs for "Yes" and "No" | |
| item_first = False | |
| This would pass the following prompts to the model: | |
| "<|user|>Is the following city the capital of France? Paris <|assistant|>" | |
| "<|user|>Is the following city the capital of France? London <|assistant|>" | |
| "<|user|>Is the following city the capital of France? Berlin <|assistant|>" | |
| The api would then return the probabilities of the model producing "Yes" and "No" as the next token. | |
| The output would look like: | |
| [[0.9, 0.1], [0.2, 0.8], [0.1, 0.9]] | |
| Args: | |
| query: The query text or pre-tokenized query token IDs. Must be provided. | |
| items: The item text(s) or pre-tokenized item token IDs. Must be provided. | |
| label_token_ids: List of token IDs to compute probabilities for. If None, no token probabilities will be computed. | |
| apply_softmax: Whether to normalize probabilities using softmax. | |
| item_first: If True, prepend items to query. Otherwise append items to query. | |
| Returns: | |
| List of dictionaries mapping token IDs to their probabilities for each item. | |
| Each dictionary in the list corresponds to one item input. | |
| Raises: | |
| ValueError: If query is not provided, or if items is not provided, | |
| or if token IDs are out of vocabulary, or if logprobs are not available for the specified tokens. | |
| """ | |
| return self.loop.run_until_complete( | |
| self.tokenizer_manager.score_request( | |
| query=query, | |
| items=items, | |
| label_token_ids=label_token_ids, | |
| apply_softmax=apply_softmax, | |
| item_first=item_first, | |
| request=None, | |
| ) | |
| ) | |
| async def async_score( | |
| self, | |
| query: Optional[Union[str, List[int]]] = None, | |
| items: Optional[Union[str, List[str], List[List[int]]]] = None, | |
| label_token_ids: Optional[List[int]] = None, | |
| apply_softmax: bool = False, | |
| item_first: bool = False, | |
| ) -> List[List[float]]: | |
| """ | |
| Asynchronous version of score method. | |
| See score() for detailed documentation. | |
| """ | |
| return await self.tokenizer_manager.score_request( | |
| query=query, | |
| items=items, | |
| label_token_ids=label_token_ids, | |
| apply_softmax=apply_softmax, | |
| item_first=item_first, | |
| request=None, | |
| ) | |
| def update_weights_from_ipc( | |
| self, | |
| zmq_handles: Dict[str, str], | |
| flush_cache: bool = True, | |
| ): | |
| """Update weights from IPC for checkpoint-engine integration.""" | |
| obj = UpdateWeightsFromIPCReqInput( | |
| zmq_handles=zmq_handles, | |
| flush_cache=flush_cache, | |
| ) | |
| loop = asyncio.get_event_loop() | |
| return loop.run_until_complete( | |
| self.tokenizer_manager.update_weights_from_ipc(obj, None) | |
| ) | |
| def _set_envs_and_config(server_args: ServerArgs): | |
| # Set global environments | |
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | |
| os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem)) | |
| if not server_args.enable_symm_mem: | |
| os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) | |
| os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" | |
| os.environ["CUDA_MODULE_LOADING"] = "AUTO" | |
| # flashinfer uses this environment variable for various kernels from MoE to quant kernels | |
| if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0": | |
| os.environ["TRTLLM_ENABLE_PDL"] = "1" | |
| if os.environ.get("CUTE_DSL_LOG_LEVEL") is None: | |
| # Default to warning level, to avoid too many logs | |
| os.environ["CUTE_DSL_LOG_LEVEL"] = "30" | |
| if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None: | |
| # Need to set log to console, otherwise the log level won't take effect | |
| os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1" | |
| # Can also be passed as argument | |
| os.environ["SGLANG_RUN_ID"] = ( | |
| f"sglang-run-{time.time()}-{random.randint(0, 100000000)}" | |
| ) | |
| # Set prometheus env vars | |
| if server_args.enable_metrics: | |
| set_prometheus_multiproc_dir() | |
| # Set ulimit | |
| set_ulimit() | |
| # Check flashinfer version | |
| if server_args.attention_backend == "flashinfer": | |
| assert_pkg_version( | |
| "flashinfer_python", | |
| "0.4.1", | |
| "Please uninstall the old version and " | |
| "reinstall the latest version by following the instructions " | |
| "at https://docs.flashinfer.ai/installation.html.", | |
| ) | |
| if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"): | |
| assert_pkg_version( | |
| "sgl-kernel", | |
| "0.3.16.post3", | |
| "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", | |
| ) | |
| if True: # Keep this check for internal code compatibility | |
| # Register the signal handler. | |
| # The child processes will send SIGQUIT to this process when any error happens | |
| # This process then clean up the whole process tree | |
| # Note: This sigquit handler is used in the launch phase, and may be replaced by | |
| # the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched. | |
| def launch_phase_sigquit_handler(signum, frame): | |
| logger.error( | |
| "Received sigquit from a child process. It usually means the child failed." | |
| ) | |
| kill_process_tree(os.getpid()) | |
| signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler) | |
| # Set mp start method | |
| mp.set_start_method("spawn", force=True) | |
| def _init_tokenizer_manager( | |
| server_args: ServerArgs, port_args: PortArgs | |
| ) -> TokenizerManager: | |
| # Launch tokenizer process | |
| tokenizer_manager = TokenizerManager(server_args, port_args) | |
| # Initialize templates | |
| template_manager = TemplateManager() | |
| template_manager.initialize_templates( | |
| tokenizer_manager=tokenizer_manager, | |
| model_path=server_args.model_path, | |
| chat_template=server_args.chat_template, | |
| completion_template=server_args.completion_template, | |
| ) | |
| return tokenizer_manager, template_manager | |
| def _launch_subprocesses( | |
| server_args: ServerArgs, port_args: Optional[PortArgs] = None | |
| ) -> Tuple[TokenizerManager, TemplateManager, Dict]: | |
| """ | |
| Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. | |
| """ | |
| # Configure global environment | |
| configure_logger(server_args) | |
| server_args.check_server_args() | |
| _set_envs_and_config(server_args) | |
| # Allocate ports for inter-process communications | |
| if port_args is None: | |
| port_args = PortArgs.init_new(server_args) | |
| logger.info(f"{server_args=}") | |
| # If using model from www.modelscope.cn, first download the model. | |
| server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( | |
| server_args.model_path, server_args.tokenizer_path | |
| ) | |
| scheduler_procs = [] | |
| if server_args.dp_size == 1: | |
| memory_saver_adapter = TorchMemorySaverAdapter.create( | |
| enable=server_args.enable_memory_saver | |
| ) | |
| scheduler_pipe_readers = [] | |
| nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1) | |
| tp_size_per_node = server_args.tp_size // nnodes_per_tp_group | |
| tp_rank_range = range( | |
| tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group), | |
| tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1), | |
| ) | |
| pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1) | |
| pp_rank_range = range( | |
| pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group), | |
| pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1), | |
| ) | |
| for pp_rank in pp_rank_range: | |
| for tp_rank in tp_rank_range: | |
| reader, writer = mp.Pipe(duplex=False) | |
| gpu_id = ( | |
| server_args.base_gpu_id | |
| + ((pp_rank % pp_size_per_node) * tp_size_per_node) | |
| + (tp_rank % tp_size_per_node) * server_args.gpu_id_step | |
| ) | |
| moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) | |
| with maybe_reindex_device_id(gpu_id) as gpu_id: | |
| proc = mp.Process( | |
| target=run_scheduler_process, | |
| args=( | |
| server_args, | |
| port_args, | |
| gpu_id, | |
| tp_rank, | |
| moe_ep_rank, | |
| pp_rank, | |
| None, | |
| writer, | |
| ), | |
| ) | |
| with memory_saver_adapter.configure_subprocess(): | |
| proc.start() | |
| scheduler_procs.append(proc) | |
| scheduler_pipe_readers.append(reader) | |
| else: | |
| # Launch the data parallel controller | |
| reader, writer = mp.Pipe(duplex=False) | |
| scheduler_pipe_readers = [reader] | |
| proc = mp.Process( | |
| target=run_data_parallel_controller_process, | |
| args=(server_args, port_args, writer), | |
| ) | |
| proc.start() | |
| scheduler_procs.append(proc) | |
| if server_args.node_rank >= 1: | |
| # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer, | |
| # so they can just wait here. | |
| for reader in scheduler_pipe_readers: | |
| data = reader.recv() | |
| assert data["status"] == "ready" | |
| if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": | |
| # When using `Engine` as a Python API, we don't want to block here. | |
| return None, None, None | |
| launch_dummy_health_check_server( | |
| server_args.host, server_args.port, server_args.enable_metrics | |
| ) | |
| for proc in scheduler_procs: | |
| proc.join() | |
| logger.error( | |
| f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" | |
| ) | |
| return None, None, None | |
| # Launch detokenizer process | |
| detoken_proc = mp.Process( | |
| target=run_detokenizer_process, | |
| args=( | |
| server_args, | |
| port_args, | |
| ), | |
| ) | |
| detoken_proc.start() | |
| # Init tokenizer manager first, as the bootstrap server is initialized here | |
| if server_args.tokenizer_worker_num > 1: | |
| # Launch multi-tokenizer router | |
| tokenizer_manager = MultiTokenizerRouter(server_args, port_args) | |
| template_manager = None | |
| else: | |
| tokenizer_manager, template_manager = _init_tokenizer_manager( | |
| server_args, port_args | |
| ) | |
| # Wait for the model to finish loading | |
| scheduler_infos = [] | |
| for i in range(len(scheduler_pipe_readers)): | |
| try: | |
| data = scheduler_pipe_readers[i].recv() | |
| except EOFError: | |
| logger.error( | |
| f"Rank {i} scheduler is dead. Please check if there are relevant logs." | |
| ) | |
| scheduler_procs[i].join() | |
| logger.error(f"Exit code: {scheduler_procs[i].exitcode}") | |
| raise | |
| if data["status"] != "ready": | |
| raise RuntimeError( | |
| "Initialization failed. Please see the error messages above." | |
| ) | |
| scheduler_infos.append(data) | |
| # Assume all schedulers have the same scheduler_info | |
| scheduler_info = scheduler_infos[0] | |
| tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] | |
| return tokenizer_manager, template_manager, scheduler_info | |
Xet Storage Details
- Size:
- 36.4 kB
- Xet hash:
- 61957c85b80a49ef74bd0048669ae1e0c946760c8e8ebae1637b3d6af574fd69
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.