diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/__init__.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/api_server.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/api_server.py new file mode 100644 index 0000000000000000000000000000000000000000..96818507d589fca1f4d602df8a25111d82a4e5d2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/api_server.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +NOTE: This API server is used only for demonstrating usage of AsyncEngine +and simple performance benchmarks. It is not intended for production use. +For production use, we recommend using our OpenAI compatible server. +We are also not going to accept PRs modifying this file, please +change `vllm/entrypoints/openai/api_server.py` instead. +""" +import asyncio +import json +import ssl +from argparse import Namespace +from typing import Any, AsyncGenerator, Optional + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.utils import with_cancellation +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger("vllm.entrypoints.api_server") + +TIMEOUT_KEEP_ALIVE = 5 # seconds. 
+app = FastAPI() +engine = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). + """ + request_dict = await request.json() + return await _generate(request_dict, raw_request=request) + + +@with_cancellation +async def _generate(request_dict: dict, raw_request: Request) -> Response: + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + + assert engine is not None + results_generator = engine.generate(prompt, sampling_params, request_id) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + assert prompt is not None + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\n").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + try: + async for request_output in results_generator: + final_output = request_output + except asyncio.CancelledError: + return Response(status_code=499) + + assert final_output is not None + prompt = final_output.prompt + assert prompt is not None + text_outputs = [prompt + output.text for output in final_output.outputs] + ret = {"text": text_outputs} + return JSONResponse(ret) + + +def build_app(args: Namespace) -> FastAPI: + global app + + app.root_path = args.root_path + return app + + +async def init_app( + args: Namespace, + llm_engine: 
Optional[AsyncLLMEngine] = None, +) -> FastAPI: + app = build_app(args) + + global engine + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = (llm_engine + if llm_engine is not None else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER)) + + return app + + +async def run_server(args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, + **uvicorn_kwargs: Any) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + set_ulimit() + + app = await init_app(args, llm_engine) + assert engine is not None + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + await shutdown_task + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument("--ssl-ca-certs", + type=str, + default=None, + help="The CA certificates file") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)" + ) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument("--log-level", type=str, default="debug") + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + asyncio.run(run_server(args)) diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/chat_utils.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/chat_utils.py new file mode 100644 
index 0000000000000000000000000000000000000000..f04902ae1c7678c736bcb49450339eeafcbb6b75 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/chat_utils.py @@ -0,0 +1,1007 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import codecs +import json +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from functools import cache, lru_cache, partial +from pathlib import Path +from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, + Literal, Optional, Tuple, TypeVar, Union, cast) + +import jinja2.nodes +import transformers.utils.chat_template_utils as hf_chat_utils +# yapf conflicts with isort for this block +# yapf: disable +from openai.types.chat import (ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartInputAudioParam) +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) +from openai.types.chat import (ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) +from openai.types.chat import (ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam) +from openai.types.chat.chat_completion_content_part_input_audio_param import ( + InputAudio) +# yapf: enable +# pydantic needs the TypedDict from typing_extensions +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from typing_extensions import Required, TypeAlias, TypedDict + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.utils import MediaConnector +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + + +class AudioURL(TypedDict, total=False): + url: Required[str] + """ + Either a URL of the audio or a data URL with base64 encoded 
audio data. + """ + + +class ChatCompletionContentPartAudioParam(TypedDict, total=False): + audio_url: Required[AudioURL] + + type: Required[Literal["audio_url"]] + """The type of the content part.""" + + +class VideoURL(TypedDict, total=False): + url: Required[str] + """ + Either a URL of the video or a data URL with base64 encoded video data. + """ + + +class ChatCompletionContentPartVideoParam(TypedDict, total=False): + video_url: Required[VideoURL] + + type: Required[Literal["video_url"]] + """The type of the content part.""" + + +class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain image_url. + This is supported by OpenAI API, although it is not documented. + + Example: + { + "image_url": "https://example.com/image.jpg" + } + """ + image_url: Required[str] + + +class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain audio_url. + + Example: + { + "audio_url": "https://example.com/audio.mp3" + } + """ + audio_url: Required[str] + + +class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain audio_url. 
+ + Example: + { + "video_url": "https://example.com/video.mp4" + } + """ + video_url: Required[str] + + +ChatCompletionContentPartParam: TypeAlias = Union[ + OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + ChatCompletionContentPartInputAudioParam, + ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, + CustomChatCompletionContentSimpleImageParam, + CustomChatCompletionContentSimpleAudioParam, + CustomChatCompletionContentSimpleVideoParam, str] + + +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + + +ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, + CustomChatCompletionMessageParam] + + +# TODO: Make fields ReadOnly once mypy supports it +class ConversationMessage(TypedDict, total=False): + role: Required[str] + """The role of the message's author.""" + + content: Union[Optional[str], List[Dict[str, str]]] + """The contents of the message""" + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + name: Optional[str] + """The name of the function to call""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + + +# Passed in by user +ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] + +# Used internally +_ChatTemplateContentFormat = 
Literal["string", "openai"] + + +def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: + if isinstance(node, jinja2.nodes.Name): + return node.ctx == "load" and node.name == varname + + return False + + +def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: + if isinstance(node, jinja2.nodes.Getitem): + return (_is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key) + + if isinstance(node, jinja2.nodes.Getattr): + return _is_var_access(node.node, varname) and node.attr == key + + return False + + +def _is_var_or_elems_access( + node: jinja2.nodes.Node, + varname: str, + key: Optional[str] = None, +) -> bool: + if isinstance(node, jinja2.nodes.Filter): + return (node.node is not None + and _is_var_or_elems_access(node.node, varname, key)) + if isinstance(node, jinja2.nodes.Test): + return _is_var_or_elems_access(node.node, varname, key) + + if (isinstance(node, jinja2.nodes.Getitem) + and isinstance(node.arg, jinja2.nodes.Slice)): + return _is_var_or_elems_access(node.node, varname, key) + + # yapf: disable + return ( + _is_attr_access(node, varname, key) if key + else _is_var_access(node, varname) + ) # yapf: enable + + +def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): + # Global variable that is implicitly defined at the root + yield root, varname + + # Iterative BFS + related_varnames = deque([varname]) + while related_varnames: + related_varname = related_varnames.popleft() + + for assign_ast in root.find_all(jinja2.nodes.Assign): + lhs = assign_ast.target + rhs = assign_ast.node + + if _is_var_or_elems_access(rhs, related_varname): + assert isinstance(lhs, jinja2.nodes.Name) + yield assign_ast, lhs.name + + # Avoid infinite looping for self-assignment + if lhs.name != related_varname: + related_varnames.append(lhs.name) + + +# NOTE: The proper way to handle this is to build a CFG so that we can handle +# the scope in which each variable is 
defined, but that is too complicated +def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): + messages_varnames = [ + varname + for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + ] + + # Search for {%- for message in messages -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in messages_varnames: + if _is_var_or_elems_access(loop_iter, varname): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): + message_varnames = [ + varname for _, varname in _iter_nodes_assign_messages_item(root) + ] + + # Search for {%- for content in message['content'] -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in message_varnames: + if _is_var_or_elems_access(loop_iter, varname, "content"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]: + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception: + logger.exception("Error when compiling Jinja template") + return None + + +def _detect_content_format( + chat_template: str, + *, + default: _ChatTemplateContentFormat, +) -> _ChatTemplateContentFormat: + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return default + + try: + next(_iter_nodes_assign_content_item(jinja_ast)) + except StopIteration: + return "string" + except Exception: + logger.exception("Error when parsing AST of Jinja template") + return default + else: + return "openai" + + +def _resolve_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + tokenizer: 
AnyTokenizer, +) -> _ChatTemplateContentFormat: + if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + tokenizer_chat_template = tokenizer.chat_template + else: + tokenizer_chat_template = None + + jinja_text: Optional[str] + if isinstance(tokenizer_chat_template, str) and chat_template is None: + jinja_text = tokenizer_chat_template + elif (isinstance(tokenizer_chat_template, dict) + and chat_template in tokenizer_chat_template): + jinja_text = tokenizer_chat_template[chat_template] + else: + jinja_text = load_chat_template(chat_template, is_literal=True) + + detected_format = ("string" if jinja_text is None else + _detect_content_format(jinja_text, default="string")) + + return detected_format if given_format == "auto" else given_format + + +@lru_cache +def resolve_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, +) -> _ChatTemplateContentFormat: + detected_format = _resolve_chat_template_content_format( + chat_template, + given_format, + tokenizer, + ) + + logger.info( + "Detected the chat template content format to be '%s'. " + "You can set `--chat-template-content-format` to override this.", + detected_format, + ) + + if given_format != "auto" and given_format != detected_format: + logger.warning( + "You specified `--chat-template-content-format %s` " + "which is different from the detected format '%s'. " + "If our automatic detection is incorrect, please consider " + "opening a GitHub issue so that we can improve it: " + "https://github.com/vllm-project/vllm/issues/new/choose", + given_format, + detected_format, + ) + + return detected_format + + +ModalityStr = Literal["image", "audio", "video"] +_T = TypeVar("_T") + + +class BaseMultiModalItemTracker(ABC, Generic[_T]): + """ + Tracks multi-modal items in a given request and ensures that the number + of multi-modal items in a given request does not exceed the configured + maximum per prompt. 
+ """ + + def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): + super().__init__() + + self._model_config = model_config + self._tokenizer = tokenizer + self._allowed_items = (model_config.multimodal_config.limit_per_prompt + if model_config.multimodal_config else {}) + + self._items_by_modality = defaultdict[str, list[_T]](list) + + @property + def model_config(self) -> ModelConfig: + return self._model_config + + @property + def allowed_local_media_path(self): + return self._model_config.allowed_local_media_path + + @staticmethod + @cache + def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: + return tokenizer.decode(token_index) + + def _placeholder_str(self, modality: ModalityStr, + current_count: int) -> Optional[str]: + # TODO: Let user specify how to insert image tokens into prompt + # (similar to chat template) + hf_config = self._model_config.hf_config + model_type = hf_config.model_type + + if modality == "image": + if model_type == "phi3_v": + # Workaround since this token is not defined in the tokenizer + return f"<|image_{current_count}|>" + if model_type in ("minicpmo", "minicpmv"): + return "(./)" + if model_type in ("blip-2", "chatglm", "fuyu", "paligemma", + "pixtral"): + # These models do not use image tokens in the prompt + return None + if model_type == "qwen": + return f"Picture {current_count}: " + if model_type.startswith("llava"): + return self._cached_token_str(self._tokenizer, + hf_config.image_token_index) + if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat", + "NVLM_D", "h2ovl_chat"): + return "" + if model_type == "mllama": + return "<|image|>" + if model_type in ("qwen2_vl", "qwen2_5_vl"): + return "<|vision_start|><|image_pad|><|vision_end|>" + if model_type == "molmo": + return "" + if model_type == "idefics3": + return "" + if model_type == "aria": + return "<|fim_prefix|><|img|><|fim_suffix|>" + + raise TypeError(f"Unknown {modality} model type: {model_type}") + elif modality 
== "audio": + if model_type == "ultravox": + return "<|audio|>" + if model_type == "qwen2_audio": + return (f"Audio {current_count}: " + f"<|audio_bos|><|AUDIO|><|audio_eos|>") + if model_type == "minicpmo": + return "()" + raise TypeError(f"Unknown model type: {model_type}") + elif modality == "video": + if model_type in ("qwen2_vl", "qwen2_5_vl"): + return "<|vision_start|><|video_pad|><|vision_end|>" + if model_type in ("minicpmo", "minicpmv"): + return "()" + if model_type.startswith("llava"): + return self._cached_token_str(self._tokenizer, + hf_config.video_token_index) + raise TypeError(f"Unknown {modality} model type: {model_type}") + else: + raise TypeError(f"Unknown modality: {modality}") + + def add(self, modality: ModalityStr, item: _T) -> Optional[str]: + """ + Add a multi-modal item to the current prompt and returns the + placeholder string to use, if any. + """ + allowed_count = self._allowed_items.get(modality, 1) + current_count = len(self._items_by_modality[modality]) + 1 + if current_count > allowed_count: + raise ValueError( + f"At most {allowed_count} {modality}(s) may be provided in " + "one request.") + + self._items_by_modality[modality].append(item) + + return self._placeholder_str(modality, current_count) + + @abstractmethod + def create_parser(self) -> "BaseMultiModalContentParser": + raise NotImplementedError + + +class MultiModalItemTracker(BaseMultiModalItemTracker[object]): + + def all_mm_data(self) -> Optional[MultiModalDataDict]: + if self._items_by_modality: + return dict(self._items_by_modality) + + return None + + def create_parser(self) -> "BaseMultiModalContentParser": + return MultiModalContentParser(self) + + +class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): + + async def all_mm_data(self) -> Optional[MultiModalDataDict]: + if self._items_by_modality: + return { + modality: await asyncio.gather(*items) + for modality, items in self._items_by_modality.items() + } + + return None + + def 
create_parser(self) -> "BaseMultiModalContentParser": + return AsyncMultiModalContentParser(self) + + +class BaseMultiModalContentParser(ABC): + + def __init__(self) -> None: + super().__init__() + + # multimodal placeholder_string : count + self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0) + + def _add_placeholder(self, placeholder: Optional[str]): + if placeholder: + self._placeholder_counts[placeholder] += 1 + + def mm_placeholder_counts(self) -> Dict[str, int]: + return dict(self._placeholder_counts) + + @abstractmethod + def parse_image(self, image_url: str) -> None: + raise NotImplementedError + + @abstractmethod + def parse_audio(self, audio_url: str) -> None: + raise NotImplementedError + + @abstractmethod + def parse_input_audio(self, input_audio: InputAudio) -> None: + raise NotImplementedError + + @abstractmethod + def parse_video(self, video_url: str) -> None: + raise NotImplementedError + + +class MultiModalContentParser(BaseMultiModalContentParser): + + def __init__(self, tracker: MultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + + self._connector = MediaConnector( + allowed_local_media_path=tracker.allowed_local_media_path, + ) + + def parse_image(self, image_url: str) -> None: + image = self._connector.fetch_image(image_url) + + placeholder = self._tracker.add("image", image) + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio = self._connector.fetch_audio(audio_url) + + placeholder = self._tracker.add("audio", audio) + self._add_placeholder(placeholder) + + def parse_input_audio(self, input_audio: InputAudio) -> None: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + + return self.parse_audio(audio_url) + + def parse_video(self, video_url: str) -> None: + video = self._connector.fetch_video(video_url) + + placeholder = self._tracker.add("video", video) + 
self._add_placeholder(placeholder) + + +class AsyncMultiModalContentParser(BaseMultiModalContentParser): + + def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + self._connector = MediaConnector( + allowed_local_media_path=tracker.allowed_local_media_path, + ) + + def parse_image(self, image_url: str) -> None: + image_coro = self._connector.fetch_image_async(image_url) + + placeholder = self._tracker.add("image", image_coro) + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio_coro = self._connector.fetch_audio_async(audio_url) + + placeholder = self._tracker.add("audio", audio_coro) + self._add_placeholder(placeholder) + + def parse_input_audio(self, input_audio: InputAudio) -> None: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + + return self.parse_audio(audio_url) + + def parse_video(self, video_url: str) -> None: + video = self._connector.fetch_video_async(video_url) + + placeholder = self._tracker.add("video", video) + self._add_placeholder(placeholder) + + +def validate_chat_template(chat_template: Optional[Union[Path, str]]): + """Raises if the provided chat template appears invalid.""" + if chat_template is None: + return + + elif isinstance(chat_template, Path) and not chat_template.exists(): + raise FileNotFoundError( + "the supplied chat template path doesn't exist") + + elif isinstance(chat_template, str): + JINJA_CHARS = "{}\n" + if not any(c in chat_template + for c in JINJA_CHARS) and not Path(chat_template).exists(): + raise ValueError( + f"The supplied chat template string ({chat_template}) " + f"appears path-like, but doesn't exist!") + + else: + raise TypeError( + f"{type(chat_template)} is not a valid chat template type") + + +def load_chat_template( + chat_template: Optional[Union[Path, str]], + *, + is_literal: bool = False, +) -> 
Optional[str]: + if chat_template is None: + return None + + if is_literal: + if isinstance(chat_template, Path): + raise TypeError("chat_template is expected to be read directly " + "from its value") + + return codecs.decode(chat_template, "unicode_escape") + + try: + with open(chat_template) as f: + return f.read() + except OSError as e: + if isinstance(chat_template, Path): + raise + + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + return load_chat_template(chat_template, is_literal=True) + + +# TODO: Let user specify how to insert multimodal tokens into prompt +# (similar to chat template) +def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], + text_prompt: str) -> str: + """Combine multimodal prompts for a multimodal language model.""" + + # Look through the text prompt to check for missing placeholders + missing_placeholders: List[str] = [] + for placeholder in placeholder_counts: + + # For any existing placeholder in the text prompt, we leave it as is + placeholder_counts[placeholder] -= text_prompt.count(placeholder) + + if placeholder_counts[placeholder] < 0: + raise ValueError( + f"Found more '{placeholder}' placeholders in input prompt than " + "actual multimodal data items.") + + missing_placeholders.extend([placeholder] * + placeholder_counts[placeholder]) + + # NOTE: For now we always add missing placeholders at the front of + # the prompt. This may change to be customizable in the future. 
+ return "\n".join(missing_placeholders + [text_prompt]) + + +# No need to validate using Pydantic again +_TextParser = partial(cast, ChatCompletionContentPartTextParam) +_ImageParser = partial(cast, ChatCompletionContentPartImageParam) +_AudioParser = partial(cast, ChatCompletionContentPartAudioParam) +_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) +_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) +_VideoParser = partial(cast, ChatCompletionContentPartVideoParam) + +_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio] + +# Define a mapping from part types to their corresponding parsing functions. +MM_PARSER_MAP: Dict[ + str, + Callable[[ChatCompletionContentPartParam], _ContentPart], +] = { + "text": + lambda part: _TextParser(part).get("text", ""), + "image_url": + lambda part: _ImageParser(part).get("image_url", {}).get("url", ""), + "audio_url": + lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""), + "input_audio": + lambda part: _InputAudioParser(part).get("input_audio", {}), + "refusal": + lambda part: _RefusalParser(part).get("refusal", ""), + "video_url": + lambda part: _VideoParser(part).get("video_url", {}).get("url", ""), +} + + +def _parse_chat_message_content_mm_part( + part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]: + """ + Parses a given multi-modal content part based on its type. + + Args: + part: A dict containing the content part, with a potential 'type' field. + + Returns: + A tuple (part_type, content) where: + - part_type: Type of the part (e.g., 'text', 'image_url'). + - content: Parsed content (e.g., text, image URL). + + Raises: + ValueError: If the 'type' field is missing and no direct URL is found. 
+ """ + assert isinstance( + part, dict) # This is needed to avoid mypy errors: part.get() from str + part_type = part.get("type", None) + + if isinstance(part_type, str) and part_type in MM_PARSER_MAP: + content = MM_PARSER_MAP[part_type](part) + + # Special case for 'image_url.detail' + # We only support 'auto', which is the default + if part_type == "image_url" and part.get("detail", "auto") != "auto": + logger.warning("'image_url.detail' is currently not supported " + "and will be ignored.") + + return part_type, content + + # Handle missing 'type' but provided direct URL fields. + # 'type' is required field by pydantic + if part_type is None: + if part.get("image_url") is not None: + image_params = cast(CustomChatCompletionContentSimpleImageParam, + part) + return "image_url", image_params.get("image_url", "") + if part.get("audio_url") is not None: + audio_params = cast(CustomChatCompletionContentSimpleAudioParam, + part) + return "audio_url", audio_params.get("audio_url", "") + if part.get("input_audio") is not None: + input_audio_params = cast(Dict[str, str], part) + return "input_audio", input_audio_params + if part.get("video_url") is not None: + video_params = cast(CustomChatCompletionContentSimpleVideoParam, + part) + return "video_url", video_params.get("video_url", "") + # Raise an error if no 'type' or direct URL is found. 
# Part types that may legally appear in a structured message content list.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "audio_url", "input_audio", "video_url")


def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
) -> List[ConversationMessage]:
    """Parse all content parts of one message into one ConversationMessage.

    Args:
        role: The message role (e.g. "user", "assistant").
        parts: Structured content parts of the message.
        mm_tracker: Tracker collecting multimodal items found in the parts.
        wrap_dicts: If True, keep parts as interleaved dicts (OpenAI-style
            content); otherwise join text parts into one string, prepending
            multimodal placeholder text when multimodal parts were seen.
    """
    content = list[_ContentPart]()

    mm_parser = mm_tracker.create_parser()

    for part in parts:
        parse_res = _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
        )
        # Skipped/empty parts yield a falsy result and are dropped here.
        if parse_res:
            content.append(parse_res)

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role,
                                    content=content)]  # type: ignore
    texts = cast(List[str], content)
    text_prompt = "\n".join(texts)
    mm_placeholder_counts = mm_parser.mm_placeholder_counts()
    if mm_placeholder_counts:
        # Prepend one placeholder per multimodal item so the template sees
        # where the media belongs.
        text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
                                                       text_prompt)
    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
) -> Optional[_ContentPart]:
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.
    """
    if isinstance(part, str):  # Handle plain text parts
        return part

    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)

    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio
    # but content is empty, log a warning and skip
    if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
        # BUGFIX: the two adjacent string literals previously had no
        # separating space, logging "...(type: 'x')with empty...".
        logger.warning(
            "Skipping multimodal part (type: '%s') "
            "with empty / unparsable content.", part_type)
        return None

    if part_type in ("text", "refusal"):
        str_content = cast(str, content)
        if wrap_dicts:
            return {'type': 'text', 'text': str_content}
        else:
            return str_content

    if part_type == "image_url":
        str_content = cast(str, content)
        mm_parser.parse_image(str_content)
        return {'type': 'image'} if wrap_dicts else None

    if part_type == "audio_url":
        str_content = cast(str, content)
        mm_parser.parse_audio(str_content)
        return {'type': 'audio'} if wrap_dicts else None

    if part_type == "input_audio":
        dict_content = cast(InputAudio, content)
        mm_parser.parse_input_audio(dict_content)
        return {'type': 'audio'} if wrap_dicts else None

    if part_type == "video_url":
        str_content = cast(str, content)
        mm_parser.parse_video(str_content)
        return {'type': 'video'} if wrap_dicts else None

    raise NotImplementedError(f"Unknown part type: {part_type}")


# No need to validate using Pydantic again
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
) -> List[ConversationMessage]:
    """Parse one chat message into ConversationMessage(s).

    Plain-string content is normalized into a single text part; tool-call
    related fields ("tool_calls", "tool_call_id", "name") are carried over
    onto the parsed result where present.
    """
    role = message["role"]
    content = message.get("content")

    if content is None:
        content = []
    elif isinstance(content, str):
        # Normalize "content": "..." into a one-element structured list.
        content = [
            ChatCompletionContentPartTextParam(type="text", text=content)
        ]
    result = _parse_chat_message_content_parts(
        role,
        content,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
    )

    for result_msg in result:
        if role == 'assistant':
            parsed_msg = _AssistantParser(message)

            if "tool_calls" in parsed_msg:
                result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
        elif role == "tool":
            parsed_msg = _ToolParser(message)
            if "tool_call_id" in parsed_msg:
                result_msg["tool_call_id"] = parsed_msg["tool_call_id"]

        if "name" in message and isinstance(message["name"], str):
            result_msg["name"] = message["name"]

    return result


def _postprocess_messages(messages: List[ConversationMessage]) -> None:
    # per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls need to be dicts not JSON str -
    # this is how tool-use chat templates will expect them moving forwards
    # so, for messages that have tool_calls, parse the string (which we get
    # from openAI format) to dict
    for message in messages:
        if (message["role"] == "assistant" and "tool_calls" in message
                and isinstance(message["tool_calls"], list)):

            for item in message["tool_calls"]:
                item["function"]["arguments"] = json.loads(
                    item["function"]["arguments"])


def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
    """Parse a full conversation, returning the parsed messages plus any
    multimodal data collected synchronously from them."""
    conversation: List[ConversationMessage] = []
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    for msg in messages:
        sub_messages = _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
        )

        conversation.extend(sub_messages)

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()


def parse_chat_messages_futures(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
    """Async variant of parse_chat_messages: multimodal data is returned as
    an awaitable so media fetching can proceed concurrently."""
    conversation: List[ConversationMessage] = []
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

    for msg in messages:
        sub_messages = _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
        )

        conversation.extend(sub_messages)

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()


def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: List[ConversationMessage],
    chat_template: Optional[str],
    *,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Render a conversation with a HuggingFace tokenizer's chat template.

    Raises:
        ValueError: If neither an explicit template nor a tokenizer-defined
            one is available (transformers >= 4.44 removed the default).
    """
    if chat_template is None and tokenizer.chat_template is None:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one.")

    return tokenizer.apply_chat_template(
        conversation=conversation,  # type: ignore[arg-type]
        chat_template=chat_template,
        tokenize=tokenize,
        **kwargs,
    )


def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: List[ChatCompletionMessageParam],
    chat_template: Optional[str] = None,
    **kwargs: Any,
) -> List[int]:
    """Render a conversation with the Mistral tokenizer, returning token ids.

    Unsupported kwargs are warned about (once) and dropped rather than
    raising, so callers can share a code path with the HF variant.
    """
    if chat_template is not None:
        logger.warning_once(
            "'chat_template' cannot be overridden for mistral tokenizer.")
    if "add_generation_prompt" in kwargs:
        logger.warning_once(
            "'add_generation_prompt' is not supported for mistral tokenizer, "
            "so it will be ignored.")
    if "continue_final_message" in kwargs:
        logger.warning_once(
            "'continue_final_message' is not supported for mistral tokenizer, "
            "so it will be ignored.")

    return tokenizer.apply_chat_template(
        messages=messages,
        **kwargs,
    )
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
    """Run *app* under uvicorn until it exits or is cancelled by a signal.

    Installs SIGINT/SIGTERM handlers that cancel the server task, registers
    fatal-error handlers that crash the server on engine death, and returns
    an awaitable performing (or skipping) the final shutdown step.
    """
    logger.info("Available routes are:")
    for route in app.routes:
        # Not every route object exposes these attributes; skip those that
        # don't (e.g. mounts).
        methods = getattr(route, "methods", None)
        path = getattr(route, "path", None)

        if methods is None or path is None:
            continue

        logger.info("Route: %s, Methods: %s", path, ', '.join(methods))

    config = uvicorn.Config(app, **uvicorn_kwargs)
    server = uvicorn.Server(config)
    _add_shutdown_handlers(app, server)

    loop = asyncio.get_running_loop()

    server_task = loop.create_task(server.serve())

    def signal_handler() -> None:
        # Prevent uvicorn's own signal handler from exiting early; cancel
        # the server task so shutdown goes through the except branch below.
        server_task.cancel()

    async def dummy_shutdown() -> None:
        pass

    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)

    try:
        await server_task
        # Server exited on its own; nothing further to shut down.
        return dummy_shutdown()
    except asyncio.CancelledError:
        # NOTE(review): assumes the caller passed "port" in uvicorn_kwargs —
        # a missing key would raise KeyError here; confirm against callers.
        port = uvicorn_kwargs["port"]
        process = find_process_using_port(port)
        if process is not None:
            # Debug aid for "address already in use" investigations.
            logger.debug(
                "port %s is used by process %s launched with command:\n%s",
                port, process, " ".join(process.cmdline()))
        logger.info("Shutting down FastAPI HTTP server.")
        return server.shutdown()


def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    """Adds handlers for fatal errors that should crash the server"""

    @app.exception_handler(RuntimeError)
    async def runtime_error_handler(request: Request, __):
        """On generic runtime error, check to see if the engine has died.
        It probably has, in which case the server will no longer be able to
        handle requests. Trigger a graceful shutdown with a SIGTERM."""
        engine = request.app.state.engine_client
        if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
                and not engine.is_running):
            logger.fatal("AsyncLLMEngine has failed, terminating server "
                         "process")
            # See discussions here on shutting down a uvicorn server
            # https://github.com/encode/uvicorn/discussions/1103
            # In this case we cannot await the server shutdown here because
            # this handler must first return to close the connection for
            # this request.
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    @app.exception_handler(AsyncEngineDeadError)
    async def async_engine_dead_handler(_, __):
        """Kill the server if the async engine is already dead. It will
        not handle any further requests."""
        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
            logger.fatal("AsyncLLMEngine is already dead, terminating server "
                         "process")
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    @app.exception_handler(MQEngineDeadError)
    async def mq_engine_dead_handler(_, __):
        """Kill the server if the mq engine is already dead. It will
        not handle any further requests."""
        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
            logger.fatal("MQLLMEngine is already dead, terminating server "
                         "process")
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) +from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, + get_cached_tokenizer) +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of + +logger = init_logger(__name__) + +_R = TypeVar("_R", default=Any) + + +class LLM: + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. + + Args: + model: The name or path of a HuggingFace Transformers model. + tokenizer: The name or path of a HuggingFace Transformers tokenizer. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. Expect valid prompt_token_ids and None for prompt + from the input. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + allowed_local_media_path: Allowing API requests to read local images + or videos from directories specified by the server file system. + This is a security risk. Should only be enabled in trusted + environments. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. 
If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq", "gptq", and "fp8" (experimental). + If None, we first check the `quantization_config` attribute in the + model config file. If that is None, we assume the model weights are + not quantized and use `dtype` to determine the data type of + the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. If all + requests will have `best_of=1`, you can safely set this to 0. + Otherwise, too small values may cause out-of-memory (OOM) errors. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading + the model weights. This virtually increases the GPU memory space + you can use to hold the model weights, at the cost of CPU-GPU data + transfer for every forward pass. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. 
+ When a sequence has context length larger than this, we fall back + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. + disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig` + disable_async_output_proc: Disable async output processing. + This may result in lower performance. + hf_overrides: If a dictionary, contains arguments to be forwarded to the + HuggingFace config. If a callable, it is called to update the + HuggingFace config. + compilation_config: Either an integer or a dictionary. If it is an + integer, it is used as the level of compilation optimization. If it + is a dictionary, it can specify the full compilation configuration. + **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See + :ref:`engine-args`) + + Note: + This class is intended to be used for offline inference. For online + serving, use the :class:`~vllm.AsyncLLMEngine` class instead. + """ + + DEPRECATE_LEGACY: ClassVar[bool] = True + """A flag to toggle whether to deprecate the legacy generate/encode API.""" + + DEPRECATE_INIT_POSARGS: ClassVar[bool] = True + """ + A flag to toggle whether to deprecate positional arguments in + :meth:`LLM.__init__`. 
+ """ + + @classmethod + @contextmanager + def deprecate_legacy_api(cls): + cls.DEPRECATE_LEGACY = True + + yield + + cls.DEPRECATE_LEGACY = False + + @deprecate_args( + start_index=2, # Ignore self and model + is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS, + additional_message=( + "All positional arguments other than `model` will be " + "replaced with keyword arguments in an upcoming version."), + ) + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + allowed_local_media_path: str = "", + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + # After positional args are removed, move this right below `model` + task: TaskOption = "auto", + override_pooler_config: Optional[PoolerConfig] = None, + compilation_config: Optional[Union[int, Dict[str, Any]]] = None, + **kwargs, + ) -> None: + ''' + LLM constructor. + + Note: if enforce_eager is unset (enforce_eager is None) + it defaults to False. 
+ ''' + + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + if "worker_cls" in kwargs: + worker_cls = kwargs["worker_cls"] + # if the worker_cls is not qualified string name, + # we serialize it using cloudpickle to avoid pickling issues + if isinstance(worker_cls, type): + kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) + + if compilation_config is not None: + if isinstance(compilation_config, (int, dict)): + compilation_config_instance = CompilationConfig.from_cli( + str(compilation_config)) + else: + compilation_config_instance = compilation_config + else: + compilation_config_instance = None + + engine_args = EngineArgs( + model=model, + task=task, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + allowed_local_media_path=allowed_local_media_path, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, + hf_overrides=hf_overrides, + mm_processor_kwargs=mm_processor_kwargs, + override_pooler_config=override_pooler_config, + compilation_config=compilation_config_instance, + **kwargs, + ) + # Logic to switch between engines is done at runtime instead of import + # to avoid import order issues + self.engine_class = self.get_engine_class() + self.llm_engine = self.engine_class.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS) + + self.request_counter = Counter() + + @staticmethod + def get_engine_class() -> Type[LLMEngine]: + if envs.VLLM_USE_V1: + # Lazy import: the v1 package isn't distributed + from 
vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + return V1LLMEngine # type: ignore + return LLMEngine + + def get_tokenizer(self) -> AnyTokenizer: + return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer + + def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: + tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup) + + # While CachedTokenizer is dynamic, have no choice but + # compare class name. Misjudgment will arise from + # user-defined tokenizer started with 'Cached' + if tokenizer.__class__.__name__.startswith("Cached"): + tokenizer_group.tokenizer = tokenizer + else: + tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) + + def get_default_sampling_params(self) -> SamplingParams: + diff_sampling_param = ( + self.llm_engine.model_config.get_diff_sampling_param()) + if diff_sampling_param: + return SamplingParams.from_optional(**diff_sampling_param) + return SamplingParams() + + @overload + def generate( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + *, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> List[RequestOutput]: + ... 
    # The following overloads describe the deprecated call shapes in which
    # token ids are passed via 'prompt_token_ids' rather than 'prompts'.
    # They only affect static typing; dispatch happens in the real generate().
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...
    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        # Reject models initialized for a non-generative runner, suggesting
        # the fix when the model could support generation.
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        # Legacy path: fold prompt_token_ids into the unified prompt format.
        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            # Exactly one guided-decoding backend option may be given.
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        executor = self.llm_engine.model_executor
        return executor.collective_rpc(method, timeout, args, kwargs)
    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.
        """
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)

    def beam_search(
        self,
        prompts: List[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        def sort_beams_key(x: BeamSearchSequence) -> float:
            # Higher score = better beam; used to rank candidates below.
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten all live beams across instances into one batch, and
            # remember each instance's slice via (start, end) offsets.
            all_beams: List[BeamSearchSequence] = list(
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            # Fan out: one candidate beam per top token.
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                # EOS ends this beam; park it as completed.
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                # Keep only the best beam_width candidates per instance.
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Unfinished beams also compete for the final ranking.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the :meth:`generate` method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: ``"Who are you?"``
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is List[...]
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            # Mistral tokenizers produce token ids directly; HF templates
            # produce a text prompt.
            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...
+ + @overload # LEGACY: single (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: str, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[PoolingRequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: List[str], + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[PoolingRequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: Optional[str] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[PoolingRequestOutput]: + ... 
+ + @overload # LEGACY: multi (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: Optional[List[str]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[PoolingRequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: None, + pooling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[PoolingRequestOutput]: + ... + + @deprecate_kwargs( + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'prompts' parameter instead.", + ) + def encode( + self, + prompts: Union[Union[PromptType, Sequence[PromptType]], + Optional[Union[str, List[str]]]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[PoolingRequestOutput]: + """Apply pooling to the hidden states corresponding to the input + prompts. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. 
See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``PoolingRequestOutput`` objects containing the + pooled hidden states in the same order as the input prompts. + + Note: + Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the ``inputs`` parameter. + """ + runner_type = self.llm_engine.model_config.runner_type + if runner_type != "pooling": + messages = ["LLM.encode() is only supported for pooling models."] + + supported_runner_types = self.llm_engine.model_config \ + .supported_runner_types + if "pooling" in supported_runner_types: + messages.append( + "Your model supports the 'pooling' runner, but is " + f"currently initialized for the '{runner_type}' runner. " + "Please initialize vLLM using `--task embed`, " + "`--task classify`, `--task score` etc.") + + raise ValueError(" ".join(messages)) + + if prompt_token_ids is not None: + parsed_prompts = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, List[str]]], prompts), + prompt_token_ids=prompt_token_ids, + ) + else: + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) + + if pooling_params is None: + # Use default pooling params. 
+ pooling_params = PoolingParams() + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return self.engine_class.validate_outputs(outputs, + PoolingRequestOutput) + + def embed( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + *, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[EmbeddingRequestOutput]: + """ + Generate an embedding vector for each prompt. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``EmbeddingRequestOutput`` objects containing the + embedding vectors in the same order as the input prompts. 
+ """ + if self.llm_engine.model_config.task != "embed": + raise ValueError( + "Embedding API is only enabled for `--task embed`") + + items = self.encode(prompts, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + return [EmbeddingRequestOutput.from_base(item) for item in items] + + def classify( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + *, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[ClassificationRequestOutput]: + """ + Generate class logits for each prompt. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``ClassificationRequestOutput`` objects containing the + embedding vectors in the same order as the input prompts. 
+ """ + if self.llm_engine.model_config.task != "classify": + raise ValueError( + "Classification API is only enabled for `--task classify`") + + items = self.encode(prompts, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + return [ClassificationRequestOutput.from_base(item) for item in items] + + def _embedding_score( + self, + tokenizer: AnyTokenizer, + text_1: List[Union[str, TextPrompt, TokensPrompt]], + text_2: List[Union[str, TextPrompt, TokensPrompt]], + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[ScoringRequestOutput]: + + encoded_output = self.encode( + text_1 + text_2, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + encoded_output_1 = encoded_output[0:len(text_1)] + encoded_output_2 = encoded_output[len(text_1):] + + if len(encoded_output_1) == 1: + encoded_output_1 = encoded_output_1 * len(encoded_output_2) + + output_pairs = [(t1, t2) + for t1, t2 in zip(encoded_output_1, encoded_output_2)] + + scores = [] + scorer = torch.nn.CosineSimilarity(0) + + for embed_1, embed_2 in output_pairs: + pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data) + + if (pad_token_id := getattr(tokenizer, "pad_token_id", + None)) is not None: + tokens = embed_1.prompt_token_ids + [ + pad_token_id + ] + embed_2.prompt_token_ids + else: + tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids + + scores.append( + PoolingRequestOutput( + request_id=f"{embed_1.request_id}_{embed_2.request_id}", + outputs=pair_score, + prompt_token_ids=tokens, + finished=True)) + + items = self.engine_class.validate_outputs(scores, + PoolingRequestOutput) + return [ScoringRequestOutput.from_base(item) for item in items] + + def _cross_encoding_score( + self, + tokenizer: Union[AnyTokenizer], + text_1: 
List[Union[str, TextPrompt, TokensPrompt]], + text_2: List[Union[str, TextPrompt, TokensPrompt]], + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[ScoringRequestOutput]: + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "Score API is only enabled for `--task embed or score`") + + if len(text_1) == 1: + text_1 = text_1 * len(text_2) + + input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] + + pooling_params = PoolingParams() + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + parsed_prompts = [] + + for q, t in input_pairs: + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + parsed_prompts.append(engine_prompt) + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + items = self.engine_class.validate_outputs(outputs, + PoolingRequestOutput) + + return [ScoringRequestOutput.from_base(item) for item in items] + + def score( + self, + text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]], + text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]], + /, + *, + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[ScoringRequestOutput]: + """Generate similarity scores for all pairs ````. + + The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``. 
+ In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N`` + times to pair with the ``text_2`` sentences. + The input pairs are used to build a list of prompts for the + cross encoder model. This class automatically batches the prompts, + considering the memory constraint. For the best performance, put all + of your texts into a single list and pass it to this method. + + Args: + text_1: can be a single prompt or a list of prompts, in which + case it has to have the same length as the ``text_2`` list + text_2: The texts to pair with the query to form the input + to the LLM. See :class:`~vllm.inputs.PromptType` for + more details about the format of each prompts. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``ScoringRequestOutput`` objects containing the + generated scores in the same order as the input prompts. + """ + runner_type = self.llm_engine.model_config.runner_type + if runner_type != "pooling": + messages = ["LLM.score() is only supported for pooling models."] + + supported_runner_types = self.llm_engine.model_config \ + .supported_runner_types + if "pooling" in supported_runner_types: + messages.append( + "Your model supports the 'pooling' runner, but is " + f"currently initialized for the '{runner_type}' runner. 
" + "Please initialize vLLM using `--task embed`, " + "`--task classify`, `--task score` etc.") + + raise ValueError(" ".join(messages)) + + if self.llm_engine.model_config.task not in ("embed", "score"): + raise ValueError( + "Score API is only enabled for `--task embed or --task score`") + + # the tokenizer for models such as + # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing + # lists of tokens to the `text` and `text_pair` kwargs + tokenizer = self.llm_engine.get_tokenizer() + + def ensure_str(prompt: SingletonPrompt): + if isinstance(prompt, dict): + if "multi_modal_data" in prompt: + raise ValueError("Multi-modal prompt is not " + "supported for scoring") + elif "prompt_token_ids" in prompt: + prompt = tokenizer.decode( + cast(TokensPrompt, prompt)["prompt_token_ids"]) + elif "prompt" in prompt: + prompt = cast(TextPrompt, prompt)["prompt"] + assert type(prompt) is str + return prompt + + if isinstance(text_1, (str, dict)): + # Convert a single prompt to a list. + text_1 = [text_1] + text_1 = [ensure_str(t) for t in text_1] + + if isinstance(text_2, (str, dict)): + # Convert a single prompt to a list. 
+ text_2 = [text_2] + text_2 = [ensure_str(t) for t in text_2] + + if len(text_1) > 1 and len(text_1) != len(text_2): + raise ValueError("Input lengths must be either 1:1, 1:N or N:N") + if len(text_1) == 0: + raise ValueError("At least one text element must be given") + if len(text_2) == 0: + raise ValueError("At least one text_pair element must be given") + + if self.llm_engine.model_config.is_cross_encoder: + return self._cross_encoding_score(tokenizer, text_1, text_2, + truncate_prompt_tokens, use_tqdm, + lora_request, + prompt_adapter_request) + else: + return self._embedding_score(tokenizer, text_1, text_2, + truncate_prompt_tokens, use_tqdm, + lora_request, prompt_adapter_request) + + def start_profile(self) -> None: + self.llm_engine.start_profile() + + def stop_profile(self) -> None: + self.llm_engine.stop_profile() + + def reset_prefix_cache(self) -> bool: + return self.llm_engine.reset_prefix_cache() + + def sleep(self, level: int = 1): + """ + Put the engine to sleep. The engine should not process any requests. + The caller should guarantee that no requests are being processed + during the sleep period, before `wake_up` is called. + + :param level: The sleep level. Level 1 sleep will offload the model + weights and discard the kv cache. The content of kv cache is + forgotten. Level 1 sleep is good for sleeping and waking up the + engine to run the same model again. The model weights are backed + up in CPU memory. Please make sure there's enough CPU memory to + store the model weights. Level 2 sleep will discard both the model + weights and the kv cache. The content of both the model weights + and kv cache is forgotten. Level 2 sleep is good for sleeping and + waking up the engine to run a different model or update the model, + where previous model weights are not needed. It reduces CPU memory + pressure. + """ + self.reset_prefix_cache() + self.llm_engine.sleep(level=level) + + def wake_up(self): + """ + Wake up the engine from sleep mode. 
See the :meth:`sleep` method + for more details.""" + self.llm_engine.wake_up() + + # LEGACY + def _convert_v1_inputs( + self, + prompts: Optional[Union[str, List[str]]], + prompt_token_ids: Optional[Union[List[int], List[List[int]]]], + ): + # skip_tokenizer_init is now checked in engine + + if prompts is not None: + prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] + if prompt_token_ids is not None: + prompt_token_ids = [ + p["content"] for p in parse_and_batch_prompt(prompt_token_ids) + ] + + num_requests = None + if prompts is not None: + num_requests = len(prompts) + if prompt_token_ids is not None: + if (num_requests is not None + and num_requests != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + + num_requests = len(prompt_token_ids) + if num_requests is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + + parsed_prompts: List[PromptType] = [] + for i in range(num_requests): + item: PromptType + + if prompts is not None: + item = TextPrompt(prompt=prompts[i]) + elif prompt_token_ids is not None: + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) + else: + raise AssertionError + + parsed_prompts.append(item) + + return parsed_prompts + + def _validate_and_add_requests( + self, + prompts: Union[PromptType, Sequence[PromptType]], + params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, + Sequence[PoolingParams]], + lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], + prompt_adapter_request: Optional[PromptAdapterRequest], + guided_options: Optional[GuidedDecodingRequest] = None, + priority: Optional[List[int]] = None, + ) -> None: + if guided_options is not None: + warnings.warn( + "guided_options_request is deprecated, use " + "SamplingParams.guided_decoding instead", + DeprecationWarning, + stacklevel=2, + ) + + if isinstance(prompts, (str, dict)): + # Convert a single prompt to a list. 
+ prompts = [prompts] + + num_requests = len(prompts) + if isinstance(params, list) and len(params) != num_requests: + raise ValueError("The lengths of prompts and params " + "must be the same.") + if isinstance(lora_request, + list) and len(lora_request) != num_requests: + raise ValueError("The lengths of prompts and lora_request " + "must be the same.") + + for sp in params if isinstance(params, list) else (params, ): + if isinstance(sp, SamplingParams): + self._add_guided_params(sp, guided_options) + + # We only care about the final output + sp.output_kind = RequestOutputKind.FINAL_ONLY + + # Add requests to the engine. + for i, prompt in enumerate(prompts): + self._add_request( + prompt, + params[i] if isinstance(params, Sequence) else params, + lora_request=lora_request[i] if isinstance( + lora_request, Sequence) else lora_request, + prompt_adapter_request=prompt_adapter_request, + priority=priority[i] if priority else 0, + ) + + def _add_request( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + request_id = str(next(self.request_counter)) + self.llm_engine.add_request( + request_id, + prompt, + params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + + def _add_guided_params( + self, + params: SamplingParams, + guided_options: Optional[GuidedDecodingRequest] = None): + if guided_options is None: + return params + + if params.guided_decoding is not None: + raise ValueError("Cannot set both guided_options_request and" + "params.guided_decoding.") + + params.guided_decoding = GuidedDecodingParams( + json=guided_options.guided_json, + regex=guided_options.guided_regex, + choice=guided_options.guided_choice, + grammar=guided_options.guided_grammar, + json_object=guided_options.guided_json_object, + 
backend=guided_options.guided_decoding_backend, + whitespace_pattern=guided_options.guided_whitespace_pattern) + return params + + def _run_engine( + self, *, use_tqdm: bool + ) -> List[Union[RequestOutput, PoolingRequestOutput]]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + pbar = tqdm( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, " + f"output: {0:.2f} toks/s"), + ) + + # Run the engine. + outputs: List[Union[RequestOutput, PoolingRequestOutput]] = [] + total_in_toks = 0 + total_out_toks = 0 + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + assert output.prompt_token_ids is not None + total_in_toks += len(output.prompt_token_ids) + in_spd = total_in_toks / pbar.format_dict["elapsed"] + total_out_toks += sum( + len(stp.token_ids) for stp in output.outputs) + out_spd = (total_out_toks / + pbar.format_dict["elapsed"]) + pbar.postfix = ( + f"est. speed input: {in_spd:.2f} toks/s, " + f"output: {out_spd:.2f} toks/s") + pbar.update(1) + + if use_tqdm: + pbar.close() + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. 
+ return sorted(outputs, key=lambda x: int(x.request_id)) diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/logger.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..e82b6ba6c7bae3c0496f395324b4a405dc34c435 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/logger.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Union + +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams + +logger = init_logger(__name__) + + +class RequestLogger: + + def __init__(self, *, max_log_len: Optional[int]) -> None: + super().__init__() + + self.max_log_len = max_log_len + + def log_inputs( + self, + request_id: str, + prompt: Optional[str], + prompt_token_ids: Optional[List[int]], + params: Optional[Union[SamplingParams, PoolingParams, + BeamSearchParams]], + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> None: + max_log_len = self.max_log_len + if max_log_len is not None: + if prompt is not None: + prompt = prompt[:max_log_len] + + if prompt_token_ids is not None: + prompt_token_ids = prompt_token_ids[:max_log_len] + + logger.info( + "Received request %s: prompt: %r, " + "params: %s, prompt_token_ids: %s, " + "lora_request: %s, prompt_adapter_request: %s.", request_id, + prompt, params, prompt_token_ids, lora_request, + prompt_adapter_request) diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__init__.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..330e65ab7feb7a2db9dcc3f2366ed207683ad9c7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/api_server.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/api_server.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7310b71eb9a9c9b0ff15bc85ec7b115c541aa57 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/api_server.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/cli_args.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/cli_args.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d6937ebc725460735b31517653dfb3e2eb47f6f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/cli_args.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/logits_processors.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/logits_processors.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..babb3d685ce86bd72d7050fa1e121b4b98d87822 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/logits_processors.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/protocol.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/protocol.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86c6e91c74ca0b2036ecafd74e24e84b464ad9ff Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/protocol.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/run_batch.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/run_batch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..690ed28b842ebb3c176133d30fe6c68820ae134a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/run_batch.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_chat.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_chat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3aaec581e7664e9c51a05cd41fdc879b3e8552c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_chat.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_completion.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_completion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39e1468d7d795a72d041e35cde4d263e9644bbe2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_completion.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_embedding.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_embedding.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..19f5b2291b1df5efe239aaecb32ff8b5be6eb592 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_embedding.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_engine.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_engine.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..752568888dc2ceb5a4efebed1b24d59a264ae57f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_engine.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_models.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_models.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c34c10a91572ae85f625681d18bd15958c5af810 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_models.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_pooling.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_pooling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1316222cb6225825dc5fe78f8e50f22643ad396 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_pooling.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_rerank.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_rerank.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79ba6c4a759b007ccf8216c54332bf6c9269b2a7 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_rerank.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_score.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_score.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c42fd11416eda0363a2ddf6904ca5be5e45498f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_score.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_tokenization.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_tokenization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..302849af071e2523826c59f33850bab7ba10c55f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_tokenization.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/api_server.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/api_server.py new file mode 100644 index 0000000000000000000000000000000000000000..b8f54d6c78042dd944bd5410dd3a12ded13fa877 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/api_server.py @@ -0,0 +1,911 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import atexit +import gc +import importlib +import inspect +import multiprocessing +import os +import re +import signal +import socket +import sys +import tempfile +import uuid +from argparse import Namespace +from contextlib import asynccontextmanager +from functools import partial +from http import HTTPStatus +from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union + +import uvloop +from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi.exceptions import 
RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.datastructures import State +from starlette.routing import Mount +from typing_extensions import assert_never + +import vllm.envs as envs +from vllm.config import ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import run_mp_engine +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import load_chat_template +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + DetokenizeResponse, + EmbeddingChatRequest, + EmbeddingCompletionRequest, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, + ErrorResponse, + LoadLoraAdapterRequest, + PoolingChatRequest, + PoolingCompletionRequest, + PoolingRequest, PoolingResponse, + RerankRequest, RerankResponse, + ScoreRequest, ScoreResponse, + TokenizeRequest, + TokenizeResponse, + UnloadLoraAdapterRequest) +from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager +# yapf: enable +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) +from 
vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling +from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank +from vllm.entrypoints.openai.serving_score import OpenAIServingScores +from vllm.entrypoints.openai.serving_tokenization import ( + OpenAIServingTokenization) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import with_cancellation +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, + is_valid_ipv6_address, set_ulimit) +from vllm.version import __version__ as VLLM_VERSION + +TIMEOUT_KEEP_ALIVE = 5 # seconds + +prometheus_multiproc_dir: tempfile.TemporaryDirectory + +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) +logger = init_logger('vllm.entrypoints.openai.api_server') + +_running_tasks: Set[asyncio.Task] = set() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + try: + if app.state.log_stats: + engine_client: EngineClient = app.state.engine_client + + async def _force_log(): + while True: + await asyncio.sleep(10.) + await engine_client.do_log_stats() + + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) + else: + task = None + + # Mark the startup heap as static so that it's ignored by GC. + # Reduces pause times of oldest generation collections. 
+ gc.collect() + gc.freeze() + try: + yield + finally: + if task is not None: + task.cancel() + finally: + # Ensure app state including engine ref is gc'd + del app.state + + +@asynccontextmanager +async def build_async_engine_client( + args: Namespace) -> AsyncIterator[EngineClient]: + + # Context manager to handle engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + engine_args = AsyncEngineArgs.from_cli_args(args) + + async with build_async_engine_client_from_engine_args( + engine_args, args.disable_frontend_multiprocessing) as engine: + yield engine + + +@asynccontextmanager +async def build_async_engine_client_from_engine_args( + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, +) -> AsyncIterator[EngineClient]: + """ + Create EngineClient, either: + - in-process using the AsyncLLMEngine Directly + - multiprocess using AsyncLLMEngine RPC + + Returns the Client or None if the creation failed. + """ + + # AsyncLLMEngine. + if (MQLLMEngineClient.is_unsupported_config(engine_args) + or envs.VLLM_USE_V1 or disable_frontend_multiprocessing): + + engine_client: Optional[EngineClient] = None + try: + engine_client = AsyncLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.OPENAI_API_SERVER) + yield engine_client + finally: + if engine_client and hasattr(engine_client, "shutdown"): + engine_client.shutdown() + + # MQLLMEngine. + else: + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + global prometheus_multiproc_dir + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ[ + "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + else: + logger.warning( + "Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. 
Unset the variable " + "and vLLM will properly handle cleanup.") + + # Select random path for IPC. + ipc_path = get_open_zmq_ipc_path() + logger.debug("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) + + # Start RPCServer in separate process (holds the LLMEngine). + # the current process might have CUDA context, + # so we need to spawn a new process + context = multiprocessing.get_context("spawn") + + # The Process can raise an exception during startup, which may + # not actually result in an exitcode being reported. As a result + # we use a shared variable to communicate the information. + engine_alive = multiprocessing.Value('b', True, lock=False) + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path, engine_alive)) + engine_process.start() + engine_pid = engine_process.pid + assert engine_pid is not None, "Engine process failed to start." + logger.info("Started engine process with PID %d", engine_pid) + + def _cleanup_ipc_path(): + socket_path = ipc_path.replace("ipc://", "") + if os.path.exists(socket_path): + os.remove(socket_path) + + # Ensure we clean up the local IPC socket file on exit. + atexit.register(_cleanup_ipc_path) + + # Build RPCClient, which conforms to EngineClient Protocol. + engine_config = engine_args.create_engine_config() + build_client = partial(MQLLMEngineClient, ipc_path, engine_config, + engine_pid) + mq_engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_client) + try: + while True: + try: + await mq_engine_client.setup() + break + except TimeoutError: + if (not engine_process.is_alive() + or not engine_alive.value): + raise RuntimeError( + "Engine process failed to start. 
See stack " + "trace for the root cause.") from None + + yield mq_engine_client # type: ignore[misc] + finally: + # Ensure rpc server process was terminated + engine_process.terminate() + + # Close all open connections to the backend + mq_engine_client.close() + + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() + + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import multiprocess + multiprocess.mark_process_dead(engine_process.pid) + + +router = APIRouter() + + +def mount_metrics(app: FastAPI): + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import (CollectorRegistry, make_asgi_app, + multiprocess) + + prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) + if prometheus_multiproc_dir_path is not None: + logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", + prometheus_multiproc_dir_path) + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + else: + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app()) + + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") + app.routes.append(metrics_route) + + +def base(request: Request) -> OpenAIServing: + # Reuse the existing instance + return tokenization(request) + + +def models(request: Request) -> OpenAIServingModels: + 
return request.app.state.openai_serving_models + + +def chat(request: Request) -> Optional[OpenAIServingChat]: + return request.app.state.openai_serving_chat + + +def completion(request: Request) -> Optional[OpenAIServingCompletion]: + return request.app.state.openai_serving_completion + + +def pooling(request: Request) -> Optional[OpenAIServingPooling]: + return request.app.state.openai_serving_pooling + + +def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: + return request.app.state.openai_serving_embedding + + +def score(request: Request) -> Optional[OpenAIServingScores]: + return request.app.state.openai_serving_scores + + +def rerank(request: Request) -> Optional[JinaAIServingRerank]: + return request.app.state.jinaai_serving_reranking + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health") +async def health(raw_request: Request) -> Response: + """Health check.""" + await engine_client(raw_request).check_health() + return Response(status_code=200) + + +@router.api_route("/ping", methods=["GET", "POST"]) +async def ping(raw_request: Request) -> Response: + """Ping check. 
Endpoint required for SageMaker""" + return await health(raw_request) + + +@router.post("/tokenize") +@with_cancellation +async def tokenize(request: TokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + generator = await handler.create_tokenize(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, TokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/detokenize") +@with_cancellation +async def detokenize(request: DetokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + generator = await handler.create_detokenize(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, DetokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.get("/v1/models") +async def show_available_models(raw_request: Request): + handler = models(raw_request) + + models_ = await handler.show_available_models() + return JSONResponse(content=models_.model_dump()) + + +@router.get("/version") +async def show_version(): + ver = {"version": VLLM_VERSION} + return JSONResponse(content=ver) + + +@router.post("/v1/chat/completions") +@with_cancellation +async def create_chat_completion(request: ChatCompletionRequest, + raw_request: Request): + handler = chat(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Chat Completions API") + + generator = await handler.create_chat_completion(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, ChatCompletionResponse): + return 
JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/completions") +@with_cancellation +async def create_completion(request: CompletionRequest, raw_request: Request): + handler = completion(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Completions API") + + generator = await handler.create_completion(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, CompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/embeddings") +@with_cancellation +async def create_embedding(request: EmbeddingRequest, raw_request: Request): + handler = embedding(raw_request) + if handler is None: + fallback_handler = pooling(raw_request) + if fallback_handler is None: + return base(raw_request).create_error_response( + message="The model does not support Embeddings API") + + logger.warning( + "Embeddings API will become exclusive to embedding models " + "in a future release. 
To return the hidden states directly, " + "use the Pooling API (`/pooling`) instead.") + + res = await fallback_handler.create_pooling(request, raw_request) + + generator: Union[ErrorResponse, EmbeddingResponse] + if isinstance(res, PoolingResponse): + generator = EmbeddingResponse( + id=res.id, + object=res.object, + created=res.created, + model=res.model, + data=[ + EmbeddingResponseData( + index=d.index, + embedding=d.data, # type: ignore + ) for d in res.data + ], + usage=res.usage, + ) + else: + generator = res + else: + generator = await handler.create_embedding(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, EmbeddingResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/pooling") +@with_cancellation +async def create_pooling(request: PoolingRequest, raw_request: Request): + handler = pooling(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Pooling API") + + generator = await handler.create_pooling(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, PoolingResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/score") +@with_cancellation +async def create_score(request: ScoreRequest, raw_request: Request): + handler = score(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Score API") + + generator = await handler.create_score(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, ScoreResponse): + return 
JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/v1/score") +@with_cancellation +async def create_score_v1(request: ScoreRequest, raw_request: Request): + logger.warning( + "To indicate that Score API is not part of standard OpenAI API, we " + "have moved it to `/score`. Please update your client accordingly.") + + return await create_score(request, raw_request) + + +@router.post("/rerank") +@with_cancellation +async def do_rerank(request: RerankRequest, raw_request: Request): + handler = rerank(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Rerank (Score) API") + generator = await handler.do_rerank(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, RerankResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/v1/rerank") +@with_cancellation +async def do_rerank_v1(request: RerankRequest, raw_request: Request): + logger.warning_once( + "To indicate that the rerank API is not part of the standard OpenAI" + " API, we have located it at `/rerank`. Please update your client" + "accordingly. 
(Note: Conforms to JinaAI rerank API)") + + return await do_rerank(request, raw_request) + + +@router.post("/v2/rerank") +@with_cancellation +async def do_rerank_v2(request: RerankRequest, raw_request: Request): + return await do_rerank(request, raw_request) + + +TASK_HANDLERS: Dict[str, Dict[str, tuple]] = { + "generate": { + "messages": (ChatCompletionRequest, create_chat_completion), + "default": (CompletionRequest, create_completion), + }, + "embed": { + "messages": (EmbeddingChatRequest, create_embedding), + "default": (EmbeddingCompletionRequest, create_embedding), + }, + "score": { + "default": (RerankRequest, do_rerank) + }, + "rerank": { + "default": (RerankRequest, do_rerank) + }, + "reward": { + "messages": (PoolingChatRequest, create_pooling), + "default": (PoolingCompletionRequest, create_pooling), + }, + "classify": { + "messages": (PoolingChatRequest, create_pooling), + "default": (PoolingCompletionRequest, create_pooling), + }, +} + +if envs.VLLM_SERVER_DEV_MODE: + + @router.post("/reset_prefix_cache") + async def reset_prefix_cache(raw_request: Request): + """ + Reset the prefix cache. Note that we currently do not check if the + prefix cache is successfully reset in the API server. + """ + logger.info("Resetting prefix cache...") + await engine_client(raw_request).reset_prefix_cache() + return Response(status_code=200) + + +@router.post("/invocations") +async def invocations(raw_request: Request): + """ + For SageMaker, routes requests to other handlers based on model `task`. + """ + body = await raw_request.json() + task = raw_request.app.state.task + + if task not in TASK_HANDLERS: + raise HTTPException( + status_code=400, + detail=f"Unsupported task: '{task}' for '/invocations'. 
" + f"Expected one of {set(TASK_HANDLERS.keys())}") + + handler_config = TASK_HANDLERS[task] + if "messages" in body: + request_model, handler = handler_config["messages"] + else: + request_model, handler = handler_config["default"] + + # this is required since we lose the FastAPI automatic casting + request = request_model.model_validate(body) + return await handler(request, raw_request) + + +if envs.VLLM_TORCH_PROFILER_DIR: + logger.warning( + "Torch Profiler is enabled in the API server. This should ONLY be " + "used for local development!") + + @router.post("/start_profile") + async def start_profile(raw_request: Request): + logger.info("Starting profiler...") + await engine_client(raw_request).start_profile() + logger.info("Profiler started.") + return Response(status_code=200) + + @router.post("/stop_profile") + async def stop_profile(raw_request: Request): + logger.info("Stopping profiler...") + await engine_client(raw_request).stop_profile() + logger.info("Profiler stopped.") + return Response(status_code=200) + + +if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "Lora dynamic loading & unloading is enabled in the API server. 
" + "This should ONLY be used for local development!") + + @router.post("/v1/load_lora_adapter") + async def load_lora_adapter(request: LoadLoraAdapterRequest, + raw_request: Request): + handler = models(raw_request) + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + @router.post("/v1/unload_lora_adapter") + async def unload_lora_adapter(request: UnloadLoraAdapterRequest, + raw_request: Request): + handler = models(raw_request) + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + +def build_app(args: Namespace) -> FastAPI: + if args.disable_fastapi_docs: + app = FastAPI(openapi_url=None, + docs_url=None, + redoc_url=None, + lifespan=lifespan) + else: + app = FastAPI(lifespan=lifespan) + app.include_router(router) + app.root_path = args.root_path + + mount_metrics(app) + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(_, exc): + err = ErrorResponse(message=str(exc), + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST) + return JSONResponse(err.model_dump(), + status_code=HTTPStatus.BAD_REQUEST) + + if token := envs.VLLM_API_KEY or args.api_key: + + @app.middleware("http") + async def authentication(request: Request, call_next): + if request.method == "OPTIONS": + return await call_next(request) + url_path = request.url.path + if app.root_path and url_path.startswith(app.root_path): + url_path = url_path[len(app.root_path):] + if not 
url_path.startswith("/v1"): + return await call_next(request) + if request.headers.get("Authorization") != "Bearer " + token: + return JSONResponse(content={"error": "Unauthorized"}, + status_code=401) + return await call_next(request) + + if args.enable_request_id_headers: + logger.warning( + "CAUTION: Enabling X-Request-Id headers in the API Server. " + "This can harm performance at high QPS.") + + @app.middleware("http") + async def add_request_id(request: Request, call_next): + request_id = request.headers.get( + "X-Request-Id") or uuid.uuid4().hex + response = await call_next(request) + response.headers["X-Request-Id"] = request_id + return response + + for middleware in args.middleware: + module_path, object_name = middleware.rsplit(".", 1) + imported = getattr(importlib.import_module(module_path), object_name) + if inspect.isclass(imported): + app.add_middleware(imported) # type: ignore[arg-type] + elif inspect.iscoroutinefunction(imported): + app.middleware("http")(imported) + else: + raise ValueError(f"Invalid middleware {middleware}. 
" + f"Must be a function or a class.") + + return app + + +async def init_app_state( + engine_client: EngineClient, + model_config: ModelConfig, + state: State, + args: Namespace, +) -> None: + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + + state.engine_client = engine_client + state.log_stats = not args.disable_log_stats + + resolved_chat_template = load_chat_template(args.chat_template) + logger.info("Using supplied chat template:\n%s", resolved_chat_template) + + state.openai_serving_models = OpenAIServingModels( + engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + ) + await state.openai_serving_models.init_static_loras() + state.openai_serving_chat = OpenAIServingChat( + engine_client, + model_config, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + enable_reasoning=args.enable_reasoning, + reasoning_parser=args.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + ) if model_config.runner_type == "generate" else None + state.openai_serving_completion = OpenAIServingCompletion( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + ) if model_config.runner_type == "generate" else None + 
state.openai_serving_pooling = OpenAIServingPooling( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) if model_config.runner_type == "pooling" else None + state.openai_serving_embedding = OpenAIServingEmbedding( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) if model_config.task == "embed" else None + state.openai_serving_scores = OpenAIServingScores( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger + ) if model_config.task == "score" else None + state.jinaai_serving_reranking = JinaAIServingRerank( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger + ) if model_config.task == "score" else None + state.openai_serving_tokenization = OpenAIServingTokenization( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) + state.task = model_config.task + + +def create_server_socket(addr: Tuple[str, int]) -> socket.socket: + family = socket.AF_INET + if is_valid_ipv6_address(addr[0]): + family = socket.AF_INET6 + + sock = socket.socket(family=family, type=socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(addr) + + return sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + valid_tool_parses = ToolParserManager.tool_parsers.keys() + if 
args.enable_auto_tool_choice \ + and args.tool_call_parser not in valid_tool_parses: + raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valid_tool_parses)} }})") + + valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() + if args.enable_reasoning \ + and args.reasoning_parser not in valid_reasoning_parses: + raise KeyError( + f"invalid reasoning parser: {args.reasoning_parser} " + f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. + # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + async with build_async_engine_client(args) as engine_client: + app = build_app(args) + + model_config = await engine_client.get_model_config() + await init_app_state(engine_client, model_config, app.state, args) + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + # Workaround to work on macOS + fd=sock.fileno() if sys.platform.startswith("darwin") else None, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + await shutdown_task + + sock.close() + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/scripts.py for CLI entrypoints. 
+ parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser = make_arg_parser(parser) + args = parser.parse_args() + validate_parsed_serve_args(args) + + uvloop.run(run_server(args)) diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/cli_args.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/cli_args.py new file mode 100644 index 0000000000000000000000000000000000000000..3054958f3c8abc1e618c21a7fe260169efa6f23f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/cli_args.py @@ -0,0 +1,305 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file contains the command line arguments for the vLLM's +OpenAI-compatible server. It is kept in a separate file for documentation +purposes. +""" + +import argparse +import json +import ssl +from typing import List, Optional, Sequence, Union, get_args + +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, + validate_chat_template) +from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager +from vllm.entrypoints.openai.serving_models import (LoRAModulePath, + PromptAdapterPath) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.utils import FlexibleArgumentParser + + +class LoRAParserAction(argparse.Action): + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + lora_list: List[LoRAModulePath] = [] + for item in values: + if item in [None, '']: # Skip if item is None or empty string + continue + if '=' in item and ',' not in item: # Old format: name=path + name, path = item.split('=') + lora_list.append(LoRAModulePath(name, path)) + else: # Assume 
class PromptAdapterParserAction(argparse.Action):
    """argparse action that turns repeated ``name=path`` items into
    ``PromptAdapterPath`` objects and stores the list on the namespace.

    Wired to ``--prompt-adapters`` with ``nargs='+'``, so argparse always
    hands us a list; a bare string indicates a mis-configured option.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")

        parsed: List[PromptAdapterPath] = []
        # Treat a missing value list the same as an empty one.
        for entry in (values or []):
            name, path = entry.split('=')
            parsed.append(PromptAdapterPath(name, path))
        setattr(namespace, self.dest, parsed)
default=None, + nargs='+', + action=LoRAParserAction, + help="LoRA module configurations in either 'name=path' format" + "or JSON format. " + "Example (old format): ``'name=path'`` " + "Example (new format): " + "``{\"name\": \"name\", \"path\": \"lora_path\", " + "\"base_model_name\": \"id\"}``") + parser.add_argument( + "--prompt-adapters", + type=nullable_str, + default=None, + nargs='+', + action=PromptAdapterParserAction, + help="Prompt adapter configurations in the format name=path. " + "Multiple adapters can be specified.") + parser.add_argument("--chat-template", + type=nullable_str, + default=None, + help="The file path to the chat template, " + "or the template in single-line form " + "for the specified model.") + parser.add_argument( + '--chat-template-content-format', + type=str, + default="auto", + choices=get_args(ChatTemplateContentFormatOption), + help='The format to render message content within a chat template.' + '\n\n' + '* "string" will render the content as a string. ' + 'Example: ``"Hello World"``\n' + '* "openai" will render the content as a list of dictionaries, ' + 'similar to OpenAI schema. ' + 'Example: ``[{"type": "text", "text": "Hello world!"}]``') + parser.add_argument("--response-role", + type=nullable_str, + default="assistant", + help="The role name to return if " + "``request.add_generation_prompt=true``.") + parser.add_argument("--ssl-keyfile", + type=nullable_str, + default=None, + help="The file path to the SSL key file.") + parser.add_argument("--ssl-certfile", + type=nullable_str, + default=None, + help="The file path to the SSL cert file.") + parser.add_argument("--ssl-ca-certs", + type=nullable_str, + default=None, + help="The CA certificates file.") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)." 
+ ) + parser.add_argument( + "--root-path", + type=nullable_str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy." + ) + parser.add_argument( + "--middleware", + type=nullable_str, + action="append", + default=[], + help="Additional ASGI middleware to apply to the app. " + "We accept multiple --middleware arguments. " + "The value should be an import path. " + "If a function is provided, vLLM will add it to the server " + "using ``@app.middleware('http')``. " + "If a class is provided, vLLM will add it to the server " + "using ``app.add_middleware()``. ") + parser.add_argument( + "--return-tokens-as-token-ids", + action="store_true", + help="When ``--max-logprobs`` is specified, represents single tokens " + " as strings of the form 'token_id:{token_id}' so that tokens " + "that are not JSON-encodable can be identified.") + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + help="If specified, will run the OpenAI frontend server in the same " + "process as the model serving engine.") + parser.add_argument( + "--enable-request-id-headers", + action="store_true", + help="If specified, API server will add X-Request-Id header to " + "responses. Caution: this hurts performance at high QPS.") + parser.add_argument( + "--enable-auto-tool-choice", + action="store_true", + default=False, + help="Enable auto tool choice for supported models. Use " + "``--tool-call-parser`` to specify which parser to use.") + parser.add_argument( + "--enable-reasoning", + action="store_true", + default=False, + help="Whether to enable reasoning_content for the model. 
" + "If enabled, the model will be able to generate reasoning content.") + + valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys() + parser.add_argument( + "--reasoning-parser", + type=str, + metavar="{" + ",".join(valid_reasoning_parsers) + "}", + default=None, + help= + "Select the reasoning parser depending on the model that you're using." + " This is used to parse the reasoning content into OpenAI API " + "format. Required for ``--enable-reasoning``.") + + valid_tool_parsers = ToolParserManager.tool_parsers.keys() + parser.add_argument( + "--tool-call-parser", + type=str, + metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in " + "--tool-parser-plugin", + default=None, + help= + "Select the tool call parser depending on the model that you're using." + " This is used to parse the model-generated tool call into OpenAI API " + "format. Required for ``--enable-auto-tool-choice``.") + + parser.add_argument( + "--tool-parser-plugin", + type=str, + default="", + help= + "Special the tool parser plugin write to parse the model-generated tool" + " into OpenAI API format, the name register in this plugin can be used " + "in ``--tool-call-parser``.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + '\n\nDefault: Unlimited') + + parser.add_argument( + "--disable-fastapi-docs", + action='store_true', + default=False, + help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint." 
+ ) + parser.add_argument( + "--enable-prompt-tokens-details", + action='store_true', + default=False, + help="If set to True, enable prompt_tokens_details in usage.") + + return parser + + +def validate_parsed_serve_args(args: argparse.Namespace): + """Quick checks for model serve args that raise prior to loading.""" + if hasattr(args, "subparser") and args.subparser != "serve": + return + + # Ensure that the chat template is valid; raises if it likely isn't + validate_chat_template(args.chat_template) + + # Enable auto tool needs a tool call parser to be valid + if args.enable_auto_tool_choice and not args.tool_call_parser: + raise TypeError("Error: --enable-auto-tool-choice requires " + "--tool-call-parser") + + # Enable reasoning needs a reasoning parser to be valid + if args.enable_reasoning and not args.reasoning_parser: + raise TypeError("Error: --enable-reasoning requires " + "--reasoning-parser") + + # Ref https://api-docs.deepseek.com/guides/reasoning_model + # tool call and reasoning cannot be enabled at the same time. 
class AllowedTokenIdsLogitsProcessor:
    """Logits processor that restricts generation to a fixed set of
    token ids by masking every other vocabulary position to ``-inf``."""

    def __init__(self, allowed_ids: Iterable[int]):
        # The id list is kept only until the boolean mask is built on the
        # first call; afterwards it is dropped to free memory.
        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
        self.mask: Optional[torch.Tensor] = None

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        mask = self.mask
        if mask is None:
            # Lazily materialize the mask on the logits' device: True marks
            # positions to suppress, so clear the allowed ids.
            vocab_size = logits.shape[-1]
            mask = torch.ones((vocab_size, ),
                              dtype=torch.bool,
                              device=logits.device)
            mask[self.allowed_ids] = False
            self.mask = mask
            self.allowed_ids = None
        # In-place fill; the same tensor is returned to the caller.
        logits.masked_fill_(mask, float("-inf"))
        return logits
def logit_bias_logits_processor(
    logit_bias: Dict[int, float],
    token_ids: List[int],
    logits: torch.Tensor,
) -> torch.Tensor:
    """Apply per-token additive biases to ``logits`` in place.

    Args:
        logit_bias: Mapping of token id -> additive bias.
        token_ids: Previously generated ids (unused; required by the
            ``LogitsProcessor`` call signature).
        logits: Next-token logits; mutated in place.

    Returns:
        The (mutated) ``logits`` tensor.
    """
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    return logits


def get_logits_processors(
    logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
    allowed_token_ids: Optional[List[int]],
    tokenizer: AnyTokenizer,
) -> List[LogitsProcessor]:
    """Build the logits processors implied by OpenAI request fields.

    Args:
        logit_bias: Optional token-id -> bias mapping (keys may be strings,
            per the OpenAI API); biases are clamped to [-100, 100].
        allowed_token_ids: Optional whitelist of token ids the model may emit.
        tokenizer: Tokenizer; only its vocabulary size (``len``) is used.

    Returns:
        A possibly-empty list of logits processors.

    Raises:
        ValueError: On non-integer ``logit_bias`` keys, out-of-vocab token
            ids, or an empty ``allowed_token_ids`` set.
    """
    logits_processors: List[LogitsProcessor] = []
    if logit_bias:
        try:
            # Convert token_id to integer.
            # Clamp the bias between -100 and 100 per OpenAI API spec.
            clamped_logit_bias: Dict[int, float] = {
                int(token_id): min(100.0, max(-100.0, bias))
                for token_id, bias in logit_bias.items()
            }
        except ValueError as exc:
            raise ValueError(
                "Found token_id in logit_bias that is not "
                "an integer or string representing an integer") from exc

        # Check that every token_id is within the vocab size. Only the keys
        # matter here (the original iterated .items() and ignored the bias);
        # hoist len(tokenizer) out of the loop.
        vocab_size = len(tokenizer)
        for token_id in clamped_logit_bias:
            if token_id < 0 or token_id >= vocab_size:
                raise ValueError(f"token_id {token_id} in logit_bias contains "
                                 "out-of-vocab token id")

        logits_processors.append(
            partial(logit_bias_logits_processor, clamped_logit_bias))

    if allowed_token_ids is not None:
        logits_processors.append(
            _get_allowed_token_ids_logits_processor(
                frozenset(allowed_token_ids), len(tokenizer)))

    return logits_processors
https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import re +import time +from argparse import Namespace +from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union + +import torch +from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, + ValidationInfo, field_validator, model_validator) +from typing_extensions import Annotated + +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.logger import init_logger +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) +from vllm.sequence import Logprob +from vllm.utils import random_uuid, resolve_obj_by_qualname + +logger = init_logger(__name__) + +# torch is mocked during docs generation, +# so we have to provide the values as literals +_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) +_LONG_INFO: Union["torch.iinfo", Namespace] + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if isinstance(torch, _MockModule): + _LONG_INFO = _MOCK_LONG_INFO + else: + _LONG_INFO = torch.iinfo(torch.long) +except ModuleNotFoundError: + _LONG_INFO = torch.iinfo(torch.long) + +assert _LONG_INFO.min == _MOCK_LONG_INFO.min +assert _LONG_INFO.max == _MOCK_LONG_INFO.max + + +class OpenAIBaseModel(BaseModel): + # OpenAI API does allow extra fields + model_config = ConfigDict(extra="allow") + + # Cache class field names + field_names: ClassVar[Optional[Set[str]]] = None + + @model_validator(mode="wrap") + @classmethod + def __log_extra_fields__(cls, data, handler): + result = handler(data) + if not isinstance(data, dict): + return result + field_names = cls.field_names + if field_names is None: + # Get all class field names and their potential aliases + field_names = set() + for field_name, field in cls.model_fields.items(): + field_names.add(field_name) + if alias := 
getattr(field, 'alias', None): + field_names.add(alias) + cls.field_names = field_names + + # Compare against both field names and aliases + if any(k not in field_names for k in data): + logger.warning( + "The following fields were present in the request " + "but ignored: %s", + data.keys() - field_names) + return result + + +class ErrorResponse(OpenAIBaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + +class ModelPermission(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = False + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: bool = False + + +class ModelCard(OpenAIBaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "vllm" + root: Optional[str] = None + parent: Optional[str] = None + max_model_len: Optional[int] = None + permission: List[ModelPermission] = Field(default_factory=list) + + +class ModelList(OpenAIBaseModel): + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + + +class PromptTokenUsageInfo(OpenAIBaseModel): + cached_tokens: Optional[int] = None + + +class UsageInfo(OpenAIBaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + prompt_tokens_details: Optional[PromptTokenUsageInfo] = None + + +class RequestResponseMetadata(BaseModel): + request_id: str + final_usage_info: Optional[UsageInfo] = None + + +class JsonSchemaResponseFormat(OpenAIBaseModel): + name: str + description: Optional[str] = None + # schema is the field in openai but that causes conflicts with pydantic so + # instead use json_schema with 
an alias + json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema') + strict: Optional[bool] = None + + +class ResponseFormat(OpenAIBaseModel): + # type must be "json_schema", "json_object" or "text" + type: Literal["text", "json_object", "json_schema"] + json_schema: Optional[JsonSchemaResponseFormat] = None + + +class StreamOptions(OpenAIBaseModel): + include_usage: Optional[bool] = True + continuous_usage_stats: Optional[bool] = False + + +class FunctionDefinition(OpenAIBaseModel): + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None + + +class ChatCompletionToolsParam(OpenAIBaseModel): + type: Literal["function"] = "function" + function: FunctionDefinition + + +class ChatCompletionNamedFunction(OpenAIBaseModel): + name: str + + +class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): + function: ChatCompletionNamedFunction + type: Literal["function"] = "function" + + +class LogitsProcessorConstructor(BaseModel): + qualname: str + args: Optional[List[Any]] = None + kwargs: Optional[Dict[str, Any]] = None + + +LogitsProcessors = List[Union[str, LogitsProcessorConstructor]] + + +def get_logits_processors(processors: Optional[LogitsProcessors], + pattern: Optional[str]) -> Optional[List[Any]]: + if processors and pattern: + logits_processors = [] + for processor in processors: + qualname = processor if isinstance(processor, + str) else processor.qualname + if not re.match(pattern, qualname): + raise ValueError( + f"Logits processor '{qualname}' is not allowed by this " + "server. 
See --logits-processor-pattern engine argument " + "for more information.") + try: + logits_processor = resolve_obj_by_qualname(qualname) + except Exception as e: + raise ValueError( + f"Logits processor '{qualname}' could not be resolved: {e}" + ) from e + if isinstance(processor, LogitsProcessorConstructor): + logits_processor = logits_processor(*processor.args or [], + **processor.kwargs or {}) + logits_processors.append(logits_processor) + return logits_processors + elif processors: + raise ValueError( + "The `logits_processors` argument is not supported by this " + "server. See --logits-processor-pattern engine argugment " + "for more information.") + return None + + +class ChatCompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/chat/create + messages: List[ChatCompletionMessageParam] + model: str + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = 0 + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens: Optional[int] = Field( + default=None, + deprecated= + 'max_tokens is deprecated in favor of the max_completion_tokens field') + max_completion_tokens: Optional[int] = None + n: Optional[int] = 1 + presence_penalty: Optional[float] = 0.0 + response_format: Optional[ResponseFormat] = None + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + tools: Optional[List[ChatCompletionToolsParam]] = None + tool_choice: Optional[Union[Literal["none"], Literal["auto"], + ChatCompletionNamedToolChoiceParam]] = "none" + + # NOTE this will be ignored by VLLM -- the model determines the behavior + parallel_tool_calls: 
Optional[bool] = False + user: Optional[str] = None + + # doc: begin-chat-completion-sampling-params + best_of: Optional[int] = None + use_beam_search: bool = False + top_k: Optional[int] = None + min_p: Optional[float] = None + repetition_penalty: Optional[float] = None + length_penalty: float = 1.0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + prompt_logprobs: Optional[int] = None + # doc: end-chat-completion-sampling-params + + # doc: begin-chat-completion-extra-params + echo: bool = Field( + default=False, + description=( + "If true, the new message will be prepended with the last message " + "if they belong to the same role."), + ) + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + continue_final_message: bool = Field( + default=False, + description= + ("If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + "This allows you to \"prefill\" part of the model's response for it. " + "Cannot be used at the same time as `add_generation_prompt`."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. 
" + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + documents: Optional[List[Dict[str, str]]] = Field( + default=None, + description= + ("A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. We recommend that each document should be a dict containing " + "\"title\" and \"text\" keys."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description=("If specified, the output will follow the JSON schema."), + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. 
If set, must be either " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response.")) + logits_processors: Optional[LogitsProcessors] = Field( + default=None, + description=( + "A list of either qualified names of logits processors, or " + "constructor objects, to apply when sampling. A constructor is " + "a JSON object with a required 'qualname' field specifying the " + "qualified name of the processor class/factory, and optional " + "'args' and 'kwargs' fields containing positional and keyword " + "arguments. 
For example: {'qualname': " + "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " + "{'param': 'value'}}.")) + + # doc: end-chat-completion-extra-params + + # Default sampling parameters for chat completion requests + _DEFAULT_SAMPLING_PARAMS: dict = { + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + } + + def to_beam_search_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None + ) -> BeamSearchParams: + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = self.max_completion_tokens or self.max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + n = self.n if self.n is not None else 1 + + # Use minimum of context window, user request & server limit. + max_tokens = min( + val for val in (default_max_tokens, max_tokens, + default_sampling_params.get("max_tokens", None)) + if val is not None) + + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + length_penalty=self.length_penalty, + include_stop_str_in_output=self.include_stop_str_in_output) + + def to_sampling_params( + self, + default_max_tokens: int, + logits_processor_pattern: Optional[str], + default_sampling_params: Optional[dict] = None) -> SamplingParams: + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = self.max_completion_tokens or self.max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + + # Use minimum of context window, user request & server limit. 
+ max_tokens = min( + val for val in (default_max_tokens, max_tokens, + default_sampling_params.get("max_tokens", None)) + if val is not None) + + # Default parameters + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.top_logprobs + + guided_json_object = None + if self.response_format is not None: + if self.response_format.type == "json_object": + guided_json_object = True + elif self.response_format.type == "json_schema": + json_schema = self.response_format.json_schema + assert json_schema is not None + self.guided_json = json_schema.json_schema + if self.guided_decoding_backend is None: + self.guided_decoding_backend = "xgrammar" + + guided_decoding = GuidedDecodingParams.from_optional( + json=self._get_guided_json_from_tool() or self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern) + + return SamplingParams.from_optional( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + temperature=temperature, 
+ top_p=top_p, + top_k=top_k, + min_p=min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.top_logprobs if self.logprobs else None, + prompt_logprobs=prompt_logprobs, + ignore_eos=self.ignore_eos, + max_tokens=max_tokens, + min_tokens=self.min_tokens, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + logits_processors=get_logits_processors(self.logits_processors, + logits_processor_pattern), + include_stop_str_in_output=self.include_stop_str_in_output, + truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, + guided_decoding=guided_decoding, + logit_bias=self.logit_bias) + + def _get_guided_json_from_tool( + self) -> Optional[Union[str, dict, BaseModel]]: + # user has chosen to not use any tool + if self.tool_choice == "none" or self.tools is None: + return None + + # user has chosen to use a named tool + if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_name = self.tool_choice.function.name + tools = {tool.function.name: tool.function for tool in self.tools} + if tool_name not in tools: + raise ValueError( + f"Tool '{tool_name}' has not been passed in `tools`.") + tool = tools[tool_name] + return tool.parameters + + return None + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + 
if (top_logprobs := data.get("top_logprobs")) is not None: + if top_logprobs < 0: + raise ValueError("`top_logprobs` must be a positive value.") + + if not data.get("logprobs"): + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." + ) + + return data + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + if isinstance(data, ValueError): + raise data + + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + # you can only use one kind of guided decoding + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + # you can only either use guided decoding or tools, not both + if guide_count > 1 and data.get("tool_choice", + "none") not in ("none", "auto"): + raise ValueError( + "You can only either use guided decoding or tools, not both.") + return data + + @model_validator(mode="before") + @classmethod + def check_tool_usage(cls, data): + + # if "tool_choice" is not specified but tools are provided, + # default to "auto" tool_choice + if "tool_choice" not in data and data.get("tools"): + data["tool_choice"] = "auto" + + # if "tool_choice" is "none" -- ignore tools if present + if "tool_choice" in data and data["tool_choice"] == "none": + # ensure that no tools are present + data.pop("tools", None) + return data + + # if "tool_choice" is specified -- validation + if "tool_choice" in data: + + # ensure that if "tool choice" is specified, tools are present + if "tools" not in data or data["tools"] is None: + raise ValueError( + "When using `tool_choice`, `tools` must be set.") + + # make sure that tool choice is either a named tool + # OR that it's set to "auto" + if data["tool_choice"] != "auto" and not isinstance( + data["tool_choice"], dict): + 
raise ValueError( + "`tool_choice` must either be a named tool, \"auto\", " + "or \"none\".") + + # ensure that if "tool_choice" is specified as an object, + # it matches a valid tool + if isinstance(data["tool_choice"], dict): + valid_tool = False + specified_function = data["tool_choice"].get("function") + if not specified_function: + raise ValueError( + "Expected field `function` in `tool_choice`." + " Correct usage: `{\"type\": \"function\"," + " \"function\": {\"name\": \"my_function\"}}`") + specified_function_name = specified_function.get("name") + if not specified_function_name: + raise ValueError( + "Expected field `name` in `function` in `tool_choice`." + "Correct usage: `{\"type\": \"function\", " + "\"function\": {\"name\": \"my_function\"}}`") + for tool in data["tools"]: + if tool["function"]["name"] == specified_function_name: + valid_tool = True + break + if not valid_tool: + raise ValueError( + "The tool specified in `tool_choice` does not match any" + " of the specified `tools`") + return data + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + +class CompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/completions/create + model: str + prompt: Union[List[int], List[List[int]], str, List[str]] + best_of: Optional[int] = None + echo: Optional[bool] = False + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[int] = None + max_tokens: Optional[int] = 16 + n: int = 1 + presence_penalty: Optional[float] = 0.0 + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: 
Optional[bool] = False + stream_options: Optional[StreamOptions] = None + suffix: Optional[str] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + user: Optional[str] = None + + # doc: begin-completion-sampling-params + use_beam_search: bool = False + top_k: Optional[int] = None + min_p: Optional[float] = None + repetition_penalty: Optional[float] = None + length_penalty: float = 1.0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + allowed_token_ids: Optional[List[int]] = None + prompt_logprobs: Optional[int] = None + # doc: end-completion-sampling-params + + # doc: begin-completion-extra-params + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + response_format: Optional[ResponseFormat] = Field( + default=None, + description= + ("Similar to chat completion, this parameter specifies the format of " + "output. 
Only {'type': 'json_object'}, {'type': 'json_schema'} or " + "{'type': 'text' } is supported."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description="If specified, the output will follow the JSON schema.", + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + logits_processors: Optional[LogitsProcessors] = Field( + default=None, + description=( + "A list of either qualified names of logits processors, or " + "constructor objects, to apply when sampling. A constructor is " + "a JSON object with a required 'qualname' field specifying the " + "qualified name of the processor class/factory, and optional " + "'args' and 'kwargs' fields containing positional and keyword " + "arguments. 
For example: {'qualname': " + "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " + "{'param': 'value'}}.")) + + # doc: end-completion-extra-params + + # Default sampling parameters for completion requests + _DEFAULT_SAMPLING_PARAMS: dict = { + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + } + + def to_beam_search_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None + ) -> BeamSearchParams: + max_tokens = self.max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + n = self.n if self.n is not None else 1 + + # Use minimum of context window, user request & server limit. + max_tokens = min( + val for val in (default_max_tokens, max_tokens, + default_sampling_params.get("max_tokens", None)) + if val is not None) + + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get("temperature", 1.0) + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + length_penalty=self.length_penalty, + include_stop_str_in_output=self.include_stop_str_in_output) + + def to_sampling_params( + self, + default_max_tokens: int, + logits_processor_pattern: Optional[str], + default_sampling_params: Optional[dict] = None) -> SamplingParams: + max_tokens = self.max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + + # Use minimum of context window, user request & server limit. 
+ max_tokens = min( + val for val in (default_max_tokens, max_tokens, + default_sampling_params.get("max_tokens", None)) + if val is not None) + + # Default parameters + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.logprobs + + echo_without_generation = self.echo and self.max_tokens == 0 + + guided_json_object = None + if (self.response_format is not None + and self.response_format.type == "json_object"): + guided_json_object = True + + guided_decoding = GuidedDecodingParams.from_optional( + json=self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern) + + return SamplingParams.from_optional( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.logprobs, + ignore_eos=self.ignore_eos, + max_tokens=max_tokens if not echo_without_generation else 1, 
+ min_tokens=self.min_tokens, + prompt_logprobs=prompt_logprobs, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + logits_processors=get_logits_processors(self.logits_processors, + logits_processor_pattern), + truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, + guided_decoding=guided_decoding, + logit_bias=self.logit_bias, + allowed_token_ids=self.allowed_token_ids) + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (logprobs := data.get("logprobs")) is not None and logprobs < 0: + raise ValueError("`logprobs` must be a positive value.") + + return data + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + +class EmbeddingCompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # 
https://platform.openai.com/docs/api-reference/embeddings + model: str + input: Union[List[int], List[List[int]], str, List[str]] + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-embedding-pooling-params + additional_data: Optional[Any] = None + # doc: end-embedding-pooling-params + + # doc: begin-embedding-extra-params + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + # doc: end-embedding-extra-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class EmbeddingChatRequest(OpenAIBaseModel): + model: str + messages: List[ChatCompletionMessageParam] + + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-chat-embedding-pooling-params + additional_data: Optional[Any] = None + # doc: end-chat-embedding-pooling-params + + # doc: begin-chat-embedding-extra-params + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. 
" + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + # doc: end-chat-embedding-extra-params + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] + +PoolingCompletionRequest = EmbeddingCompletionRequest +PoolingChatRequest = EmbeddingChatRequest +PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest] + + +class ScoreRequest(OpenAIBaseModel): + model: str + text_1: Union[List[str], str] + text_2: Union[List[str], str] + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-score-pooling-params + additional_data: Optional[Any] = None + # doc: end-score-pooling-params + + # doc: begin-score-extra-params + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). 
Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + # doc: end-score-extra-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class RerankRequest(OpenAIBaseModel): + model: str + query: str + documents: List[str] + top_n: int = Field(default_factory=lambda: 0) + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-rerank-pooling-params + additional_data: Optional[Any] = None + # doc: end-rerank-pooling-params + + # doc: begin-rerank-extra-params + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + # doc: end-rerank-extra-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class RerankDocument(BaseModel): + text: str + + +class RerankResult(BaseModel): + index: int + document: RerankDocument + relevance_score: float + + +class RerankUsage(BaseModel): + total_tokens: int + + +class RerankResponse(OpenAIBaseModel): + id: str + model: str + usage: RerankUsage + results: List[RerankResult] + + +class CompletionLogProbs(OpenAIBaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, + float]]] = Field(default_factory=list) + + +class CompletionResponseChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering 
the EOS token"), + ) + prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None + + +class CompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + + +class CompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class EmbeddingResponseData(OpenAIBaseModel): + index: int + object: str = "embedding" + embedding: Union[List[float], str] + + +class EmbeddingResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[EmbeddingResponseData] + usage: UsageInfo + + +class PoolingResponseData(OpenAIBaseModel): + index: int + object: str = "pooling" + data: Union[List[List[float]], List[float], str] + + +class PoolingResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"pool-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[PoolingResponseData] + usage: UsageInfo + + +class ScoreResponseData(OpenAIBaseModel): + index: 
int + object: str = "score" + score: float + + +class ScoreResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[ScoreResponseData] + usage: UsageInfo + + +class FunctionCall(OpenAIBaseModel): + name: str + arguments: str + + +class ToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + function: FunctionCall + + +class DeltaFunctionCall(BaseModel): + name: Optional[str] = None + arguments: Optional[str] = None + + +# a tool call delta where everything is optional +class DeltaToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + index: int + function: Optional[DeltaFunctionCall] = None + + +class ExtractedToolCallInformation(BaseModel): + # indicate if tools were called + tools_called: bool + + # extracted tool calls + tool_calls: List[ToolCall] + + # content - per OpenAI spec, content AND tool calls can be returned rarely + # But some models will do this intentionally + content: Optional[str] = None + + +class ChatMessage(OpenAIBaseModel): + role: str + reasoning_content: Optional[str] = None + content: Optional[str] = None + tool_calls: List[ToolCall] = Field(default_factory=list) + + +class ChatCompletionLogProb(OpenAIBaseModel): + token: str + logprob: float = -9999.0 + bytes: Optional[List[int]] = None + + +class ChatCompletionLogProbsContent(ChatCompletionLogProb): + top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) + + +class ChatCompletionLogProbs(OpenAIBaseModel): + content: Optional[List[ChatCompletionLogProbsContent]] = None + + +class ChatCompletionResponseChoice(OpenAIBaseModel): + index: int + message: ChatMessage + logprobs: Optional[ChatCompletionLogProbs] = None + # per OpenAI spec this is the 
default + finish_reason: Optional[str] = "stop" + # not part of the OpenAI spec but included in vLLM for legacy reasons + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None + + +class DeltaMessage(OpenAIBaseModel): + role: Optional[str] = None + content: Optional[str] = None + reasoning_content: Optional[str] = None + tool_calls: List[DeltaToolCall] = Field(default_factory=list) + + +class ChatCompletionResponseStreamChoice(OpenAIBaseModel): + index: int + delta: DeltaMessage + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class BatchRequestInput(OpenAIBaseModel): + """ + The per-line object of the batch input file. + + NOTE: Currently only the `/v1/chat/completions` endpoint is supported. + """ + + # A developer-provided per-request id that will be used to match outputs to + # inputs. Must be unique for each request in a batch. + custom_id: str + + # The HTTP method to be used for the request. Currently only POST is + # supported. + method: str + + # The OpenAI API relative URL to be used for the request. Currently + # /v1/chat/completions is supported. + url: str + + # The parameters of the request. 
+ body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest] + + @field_validator('body', mode='plain') + @classmethod + def check_type_for_url(cls, value: Any, info: ValidationInfo): + # Use url to disambiguate models + url = info.data['url'] + if url == "/v1/chat/completions": + return ChatCompletionRequest.model_validate(value) + if url == "/v1/embeddings": + return TypeAdapter(EmbeddingRequest).validate_python(value) + if url == "/v1/score": + return ScoreRequest.model_validate(value) + return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest, + ScoreRequest]).validate_python(value) + + +class BatchResponseData(OpenAIBaseModel): + # HTTP status code of the response. + status_code: int = 200 + + # An unique identifier for the API request. + request_id: str + + # The body of the response. + body: Optional[Union[ChatCompletionResponse, EmbeddingResponse, + ScoreResponse]] = None + + +class BatchRequestOutput(OpenAIBaseModel): + """ + The per-line object of the batch output and error files + """ + + id: str + + # A developer-provided per-request id that will be used to match outputs to + # inputs. + custom_id: str + + response: Optional[BatchResponseData] + + # For requests that failed with a non-HTTP error, this will contain more + # information on the cause of the failure. + error: Optional[Any] + + +class TokenizeCompletionRequest(OpenAIBaseModel): + model: str + prompt: str + + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + + +class TokenizeChatRequest(OpenAIBaseModel): + model: str + messages: List[ChatCompletionMessageParam] + + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. 
" + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + continue_final_message: bool = Field( + default=False, + description= + ("If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + "This allows you to \"prefill\" part of the model's response for it. " + "Cannot be used at the same time as `add_generation_prompt`."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. 
" + "Will be accessible by the chat template."), + ) + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + +TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest] + + +class TokenizeResponse(OpenAIBaseModel): + count: int + max_model_len: int + tokens: List[int] + + +class DetokenizeRequest(OpenAIBaseModel): + model: str + tokens: List[int] + + +class DetokenizeResponse(OpenAIBaseModel): + prompt: str + + +class LoadLoraAdapterRequest(BaseModel): + lora_name: str + lora_path: str + + +class UnloadLoraAdapterRequest(BaseModel): + lora_name: str + lora_int_id: Optional[int] = Field(default=None) diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/run_batch.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/run_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..675d3cdcf97155073c07d8afbdd5e5878b23ba78 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/run_batch.py @@ -0,0 +1,342 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from http import HTTPStatus +from io import StringIO +from typing import Awaitable, Callable, List, Optional + +import aiohttp +import torch +from prometheus_client import start_http_server +from tqdm import tqdm + +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger, logger +# yapf: disable +from vllm.entrypoints.openai.protocol import (BatchRequestInput, + BatchRequestOutput, + BatchResponseData, + ChatCompletionResponse, + EmbeddingResponse, ErrorResponse, + ScoreResponse) +# yapf: enable +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from 
vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) +from vllm.entrypoints.openai.serving_score import OpenAIServingScores +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, random_uuid +from vllm.version import __version__ as VLLM_VERSION + + +def parse_args(): + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible batch runner.") + parser.add_argument( + "-i", + "--input-file", + required=True, + type=str, + help= + "The path or url to a single input file. Currently supports local file " + "paths, or the http protocol (http or https). If a URL is specified, " + "the file should be available via HTTP GET.") + parser.add_argument( + "-o", + "--output-file", + required=True, + type=str, + help="The path or url to a single output file. Currently supports " + "local file paths, or web (http or https) urls. If a URL is specified," + " the file should be available via HTTP PUT.") + parser.add_argument("--response-role", + type=nullable_str, + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=True`.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' 
+ '\n\nDefault: Unlimited') + + parser.add_argument("--enable-metrics", + action="store_true", + help="Enable Prometheus metrics") + parser.add_argument( + "--url", + type=str, + default="0.0.0.0", + help="URL to the Prometheus metrics server " + "(only needed if enable-metrics is set).", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port number for the Prometheus metrics server " + "(only needed if enable-metrics is set).", + ) + parser.add_argument( + "--enable-prompt-tokens-details", + action='store_true', + default=False, + help="If set to True, enable prompt_tokens_details in usage.") + + return parser.parse_args() + + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +class BatchProgressTracker: + + def __init__(self): + self._total = 0 + self._pbar: Optional[tqdm] = None + + def submitted(self): + self._total += 1 + + def completed(self): + if self._pbar: + self._pbar.update() + + def pbar(self) -> tqdm: + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + self._pbar = tqdm(total=self._total, + unit="req", + desc="Running batch", + mininterval=5, + disable=not enable_tqdm, + bar_format=_BAR_FORMAT) + return self._pbar + + +async def read_file(path_or_url: str) -> str: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, \ + session.get(path_or_url) as resp: + return await resp.text() + else: + with open(path_or_url, encoding="utf-8") as f: + return f.read() + + +async def write_file(path_or_url: str, data: str) -> None: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): 
        # Stream the payload to the remote location with a single HTTP PUT.
        async with aiohttp.ClientSession() as session, \
                session.put(path_or_url, data=data.encode("utf-8")):
            pass
    else:
        # We should make this async, but as long as this is always run as a
        # standalone program, blocking the event loop won't effect performance
        # in this particular case.
        with open(path_or_url, "w", encoding="utf-8") as f:
            f.write(data)


def make_error_request_output(request: BatchRequestInput,
                              error_msg: str) -> BatchRequestOutput:
    """Wrap *error_msg* in a BatchRequestOutput carrying a 400 status."""
    batch_output = BatchRequestOutput(
        id=f"vllm-{random_uuid()}",
        custom_id=request.custom_id,
        response=BatchResponseData(
            status_code=HTTPStatus.BAD_REQUEST,
            request_id=f"vllm-batch-{random_uuid()}",
        ),
        error=error_msg,
    )
    return batch_output


async def make_async_error_request_output(
        request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
    """Awaitable variant of make_error_request_output() so error results can
    be gathered together with real request futures (see main())."""
    return make_error_request_output(request, error_msg)


async def run_request(serving_engine_func: Callable,
                      request: BatchRequestInput,
                      tracker: BatchProgressTracker) -> BatchRequestOutput:
    """Run one batch line through *serving_engine_func* and wrap the result.

    Marks the request completed on *tracker* regardless of outcome. Streaming
    responses are rejected: anything other than a concrete response object or
    an ErrorResponse produces an error output.
    """
    response = await serving_engine_func(request.body)

    if isinstance(response,
                  (ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
        batch_output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                body=response, request_id=f"vllm-batch-{random_uuid()}"),
            error=None,
        )
    elif isinstance(response, ErrorResponse):
        batch_output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                status_code=response.code,
                request_id=f"vllm-batch-{random_uuid()}"),
            error=response,
        )
    else:
        batch_output = make_error_request_output(
            request, error_msg="Request must not be sent in stream mode")

    tracker.completed()
    return batch_output


async def main(args):
    """Read a JSONL batch file, run every request, and write a JSONL output.

    Each input line is a BatchRequestInput; requests are dispatched by URL to
    the chat/embedding/score serving handlers and awaited concurrently.
    """
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)

    model_config = await engine.get_model_config()
    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    # Create the openai serving objects. Each handler is None when the model's
    # runner/task does not support the corresponding API.
    openai_serving_models = OpenAIServingModels(
        engine_client=engine,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=None,
        prompt_adapters=None,
    )
    openai_serving_chat = OpenAIServingChat(
        engine,
        model_config,
        openai_serving_models,
        args.response_role,
        request_logger=request_logger,
        chat_template=None,
        chat_template_content_format="auto",
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    ) if model_config.runner_type == "generate" else None
    openai_serving_embedding = OpenAIServingEmbedding(
        engine,
        model_config,
        openai_serving_models,
        request_logger=request_logger,
        chat_template=None,
        chat_template_content_format="auto",
    ) if model_config.task == "embed" else None
    openai_serving_scores = (OpenAIServingScores(
        engine,
        model_config,
        openai_serving_models,
        request_logger=request_logger,
    ) if model_config.task == "score" else None)

    tracker = BatchProgressTracker()
    logger.info("Reading batch from %s...", args.input_file)

    # Submit all requests in the file to the engine "concurrently".
    response_futures: List[Awaitable[BatchRequestOutput]] = []
    for request_json in (await read_file(args.input_file)).strip().split("\n"):
        # Skip empty lines.
        request_json = request_json.strip()
        if not request_json:
            continue

        request = BatchRequestInput.model_validate_json(request_json)

        # Determine the type of request and run it.
        if request.url == "/v1/chat/completions":
            handler_fn = (None if openai_serving_chat is None else
                          openai_serving_chat.create_chat_completion)
            if handler_fn is None:
                response_futures.append(
                    make_async_error_request_output(
                        request,
                        error_msg=
                        "The model does not support Chat Completions API",
                    ))
                continue

            response_futures.append(run_request(handler_fn, request, tracker))
            tracker.submitted()
        elif request.url == "/v1/embeddings":
            handler_fn = (None if openai_serving_embedding is None else
                          openai_serving_embedding.create_embedding)
            if handler_fn is None:
                response_futures.append(
                    make_async_error_request_output(
                        request,
                        error_msg="The model does not support Embeddings API",
                    ))
                continue

            response_futures.append(run_request(handler_fn, request, tracker))
            tracker.submitted()
        elif request.url == "/v1/score":
            handler_fn = (None if openai_serving_scores is None else
                          openai_serving_scores.create_score)
            if handler_fn is None:
                response_futures.append(
                    make_async_error_request_output(
                        request,
                        error_msg="The model does not support Scores API",
                    ))
                continue

            response_futures.append(run_request(handler_fn, request, tracker))
            tracker.submitted()
        else:
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg=
                    "Only /v1/chat/completions, /v1/embeddings, and /v1/score "
                    "are supported in the batch endpoint.",
                ))

    with tracker.pbar():
        responses = await asyncio.gather(*response_futures)

    # Serialize one JSON object per line, mirroring the input format.
    output_buffer = StringIO()
    for response in responses:
        print(response.model_dump_json(), file=output_buffer)

    output_buffer.seek(0)
    await write_file(args.output_file, output_buffer.read().strip())


if __name__ == "__main__":
    args = parse_args()

    logger.info("vLLM batch processing API version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Start the Prometheus metrics server. LLMEngine uses the Prometheus client
    # to publish metrics at the /metrics endpoint.
    if args.enable_metrics:
        logger.info("Prometheus metrics enabled")
        # NOTE: reuses the batch runner's --port/--url for the metrics
        # endpoint; there is no separate HTTP API in this program.
        start_http_server(port=args.port, addr=args.url)
    else:
        logger.info("Prometheus metrics disabled")

    asyncio.run(main(args))
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_chat.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_chat.py
new file mode 100644
index 0000000000000000000000000000000000000000..107220d548afc0adf506fb24c6b1a10ea2a3b232
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_chat.py
@@ -0,0 +1,955 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
                    Optional)
from typing import Sequence as GenericSequence
from typing import Union

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                         ConversationMessage)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
    DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
    RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
                                                       ReasoningParserManager)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls

logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):
    """OpenAI-compatible Chat Completions handler on top of an EngineClient.

    Supports streaming and non-streaming responses, optional tool calling
    ("auto" and named tool choice) via pluggable tool parsers, and optional
    reasoning-content extraction via pluggable reasoning parsers.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        response_role: str,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        return_tokens_as_token_ids: bool = False,
        enable_reasoning: bool = False,
        reasoning_parser: Optional[str] = None,
        enable_auto_tools: bool = False,
        tool_parser: Optional[str] = None,
        enable_prompt_tokens_details: bool = False,
    ) -> None:
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

        # set up tool use
        self.enable_auto_tools: bool = enable_auto_tools
        if self.enable_auto_tools:
            logger.info(
                "\"auto\" tool choice has been enabled please note that while"
                " the parallel_tool_calls client option is preset for "
                "compatibility reasons, it will be ignored.")

        # Resolve the reasoning parser by registry name; failure to resolve is
        # a configuration error, surfaced eagerly at construction time.
        self.enable_reasoning: bool = enable_reasoning
        self.reasoning_parser: Optional[Callable[[AnyTokenizer],
                                                 ReasoningParser]] = None
        if self.enable_reasoning:
            try:
                self.reasoning_parser = (
                    ReasoningParserManager.get_reasoning_parser(
                        reasoning_parser))
            except Exception as e:
                raise TypeError("Error: --enable-reasoning requires "
                                f"reasoning_parser:'{reasoning_parser}' "
                                "which has not been registered") from e
        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
        if self.enable_auto_tools:
            try:
                if (tool_parser == "pythonic" and
                        model_config.model.startswith("meta-llama/Llama-3.2")):
                    logger.warning(
                        "Llama3.2 models may struggle to emit valid pythonic"
                        " tool calls")
                self.tool_parser = ToolParserManager.get_tool_parser(
                    tool_parser)
            except Exception as e:
                raise TypeError("Error: --enable-auto-tool-choice requires "
                                f"tool_parser:'{tool_parser}' which has not "
                                "been registered") from e

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        diff_sampling_param = self.model_config.get_diff_sampling_param()
        if diff_sampling_param:
            logger.info("Overwriting default chat sampling param with: %s",
                        diff_sampling_param)

    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """
        Chat Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        Chat Completion API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_name = self.models.model_name(lora_request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            tool_parser = self.tool_parser

            # validation for OpenAI tools
            # tool_choice = "required" is not supported
            if request.tool_choice == "required":
                return self.create_error_response(
                    "tool_choice = \"required\" is not supported!")

            # because of issues with pydantic we need to potentially
            # re-serialize the tool_calls field of the request
            # for more info: see comment in `maybe_serialize_tool_calls`
            if isinstance(tokenizer, MistralTokenizer):
                maybe_serialize_tool_calls(request)

            if (request.tool_choice == "auto" and
                    not (self.enable_auto_tools and tool_parser is not None)
                    and not isinstance(tokenizer, MistralTokenizer)):
                # for hf tokenizers, "auto" tools requires
                # --enable-auto-tool-choice and --tool-call-parser
                return self.create_error_response(
                    "\"auto\" tool choice requires "
                    "--enable-auto-tool-choice and --tool-call-parser to be set"
                )

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            # Render the chat template into engine-ready prompts.
            (
                conversation,
                request_prompts,
                engine_prompts,
            ) = await self._preprocess_chat(
                request,
                tokenizer,
                request.messages,
                chat_template=request.chat_template or self.chat_template,
                chat_template_content_format=self.chat_template_content_format,
                add_generation_prompt=request.add_generation_prompt,
                continue_final_message=request.continue_final_message,
                tool_dicts=tool_dicts,
                documents=request.documents,
                chat_template_kwargs=request.chat_template_kwargs,
                tool_parser=tool_parser,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        request_id = "chatcmpl-" \
            f"{self._base_request_id(raw_request, request.request_id)}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                # Default max_tokens is the remaining context budget.
                default_max_tokens = self.max_model_len - len(
                    engine_prompt["prompt_token_ids"])
                # Build default sampling params
                default_sampling_params = (
                    self.model_config.get_diff_sampling_param())
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens, default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
                        default_max_tokens,
                        self.model_config.logits_processor_pattern,
                        default_sampling_params)

                self._log_inputs(request_id,
                                 request_prompts[i],
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        prompt_adapter_request=prompt_adapter_request,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # Chat requests produce exactly one engine prompt.
        assert len(generators) == 1
        result_generator, = generators

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, model_name,
                conversation, tokenizer, request_metadata)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, model_name,
                conversation, tokenizer, request_metadata)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        """Role for the response: the configured response role when a
        generation prompt is added, otherwise the last message's role."""
        if request.add_generation_prompt:
            return self.response_role
        return request.messages[-1]["role"]

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Yield the SSE stream ("data: ..." lines, ending with
        "data: [DONE]") for a streaming chat completion, including optional
        tool-call and reasoning-content deltas and usage chunks."""
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        num_prompt_tokens = 0
        num_cached_tokens = None

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
            and self._should_stream_with_auto_tool_parsing(request))

        should_stream_with_reasoning_parsing = (
            self._should_stream_with_reasoning_parsing(request))

        all_previous_token_ids: Optional[List[List[int]]]

        # Only one of these will be used, thus previous_texts and
        # all_previous_token_ids will not be used twice in the same iteration.
        if tool_choice_auto or should_stream_with_reasoning_parsing:
            # These are only required in "auto" tool choice case
            previous_texts = [""] * num_choices
            all_previous_token_ids = [[]] * num_choices
        else:
            previous_texts, all_previous_token_ids = None, None

        try:
            # There is no need to check if the reasoning_parser is None
            # because the should_stream_with_reasoning_parsing check
            # already ensures that the reasoning_parser is not None.
            # but the pre-commit hook requires it.
            if should_stream_with_reasoning_parsing and \
                    self.reasoning_parser is not None:
                reasoning_parser = self.reasoning_parser(tokenizer)
        except RuntimeError as e:
            logger.exception("Error in reasoning parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

        # Prepare the tool parser if it's needed
        try:
            if tool_choice_auto and self.tool_parser:
                tool_parsers: List[Optional[ToolParser]] = [
                    self.tool_parser(tokenizer)
                ] * num_choices
            else:
                tool_parsers = [None] * num_choices
        except Exception as e:
            logger.exception("Error in tool parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage
            include_continuous_usage = include_usage and \
                stream_options.continuous_usage_stats
        else:
            include_usage, include_continuous_usage = False, False

        try:
            async for res in result_generator:
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)
                    if res.encoder_prompt_token_ids is not None:
                        num_prompt_tokens += len(res.encoder_prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # if continuous usage stats are requested, add it
                        if include_continuous_usage:
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=0,
                                total_tokens=num_prompt_tokens)

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content: Union[str, List[Dict[str, str]]] = ""
                        if conversation and "content" in conversation[
                                -1] and conversation[-1].get("role") == role:
                            last_msg_content = conversation[-1]["content"] or ""

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if include_continuous_usage:
                                    chunk.usage = UsageInfo(
                                        prompt_tokens=num_prompt_tokens,
                                        completion_tokens=0,
                                        total_tokens=num_prompt_tokens)

                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index
                    tool_parser = tool_parsers[i]

                    if finish_reason_sent[i]:
                        continue

                    if request.logprobs and request.top_logprobs is not None:
                        assert output.logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text

                    if not delta_text and not output.token_ids and \
                            not previous_num_tokens[i]:
                        # Chunked prefill case, don't return empty chunks
                        continue

                    delta_message: Optional[DeltaMessage]

                    # handle streaming deltas for tools with named tool_choice
                    if tool_choice_function_name:
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(function=DeltaFunctionCall(
                                name=tool_choice_function_name,
                                arguments=delta_text),
                                          index=i)
                        ])

                    # handle streaming deltas for tools with "auto" tool choice
                    elif tool_choice_auto:
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        assert tool_parser is not None
                        #TODO optimize manipulation of these lists
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)

                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_text,
                                current_text=current_text,
                                delta_text=delta_text,
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
                                delta_token_ids=output.token_ids,
                                request=request))

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids
                    # reasoning_content cannot be enabled with tool_choice.
                    # If it is, the tool_choice will be used instead.
                    elif self.enable_reasoning:
                        # handle reasoning_content delta
                        assert reasoning_parser is not None
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)

                        delta_message = (reasoning_parser.
                                         extract_reasoning_content_streaming(
                                             previous_text,
                                             current_text,
                                             delta_text,
                                             previous_token_ids,
                                             current_token_ids,
                                             output.token_ids,
                                         ))

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids

                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    # set the previous values for the next iteration
                    previous_num_tokens[i] += len(output.token_ids)

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls or the parser otherwise
                    # wasn't ready to send a token, then
                    #   get the next token without streaming a chunk
                    if delta_message is None:
                        continue

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)

                    # if the model is finished generating
                    else:
                        # check to make sure we haven't "forgotten" to stream
                        # any tokens that were generated but previously
                        # matched by partial json parsing
                        # only happens if we are NOT using guided decoding
                        auto_tools_called = False
                        if tool_parser:
                            auto_tools_called = len(
                                tool_parser.prev_tool_call_arr) > 0
                            index = len(tool_parser.prev_tool_call_arr
                                        ) - 1 if auto_tools_called else 0
                        else:
                            index = 0

                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
                            latest_delta_len = 0
                            if ((isinstance(
                                    delta_message.tool_calls[0].function,
                                    DeltaFunctionCall)) and isinstance(
                                        delta_message.tool_calls[0].function.
                                        arguments, str)):
                                latest_delta_len = len(
                                    delta_message.tool_calls[0].function.
                                    arguments)

                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
                                    "arguments", {}),
                                ensure_ascii=False)

                            # get what we've streamed so far for arguments
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]
                            if (latest_delta_len > 0):
                                actual_call = actual_call[:-latest_delta_len]

                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)
                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])

                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason
                            if not auto_tools_called else "tool_calls",
                            stop_reason=output.stop_reason)

                        finish_reason_sent[i] = True

                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)

                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=num_prompt_tokens + completion_tokens,
                        )

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
            if include_usage:
                completion_tokens = sum(previous_num_tokens)
                final_usage = \
                    UsageInfo(prompt_tokens=num_prompt_tokens,
                              completion_tokens=completion_tokens,
                              total_tokens=num_prompt_tokens +
                              completion_tokens)
                if self.enable_prompt_tokens_details and num_cached_tokens:
                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
                        cached_tokens=num_cached_tokens)

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            num_completion_tokens = sum(previous_num_tokens)
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
                total_tokens=num_prompt_tokens + num_completion_tokens)

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in chat completion stream generator.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Drain *result_generator* and build a complete (non-streaming)
        ChatCompletionResponse, extracting tool calls and reasoning content
        from the final output when the corresponding parsers are enabled."""

        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            # Only the last RequestOutput matters; it carries the full text.
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            should_stream_with_reasoning_parsing = (
                self._should_stream_with_reasoning_parsing(request))

            # In the OpenAI API the finish_reason is "tools_called"
            # if the tool choice is auto and the model produced a tool
            # call. The same is not true for named function calls
            auto_tools_called = False

            if should_stream_with_reasoning_parsing and \
                    self.reasoning_parser is not None:
                try:
                    reasoning_parser = self.reasoning_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in reasoning parser creation.")
                    return self.create_error_response(str(e))

                reasoning_content, content = (
                    reasoning_parser.extract_reasoning_content(
                        output.text, request=request))

                if reasoning_content:
                    message = ChatMessage(role=role,
                                          content=content,
                                          reasoning_content=reasoning_content)
                else:
                    message = ChatMessage(role=role, content=output.text)

            # if auto tools are not enabled, and a named tool choice using
            # outlines is not being used
            elif (not self.enable_auto_tools
                  or not self.tool_parser) and not isinstance(
                      request.tool_choice, ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(role=role, content=output.text)

            # if the request uses tools and specified a tool choice
            elif request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:

                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":

                message = ChatMessage(role=role, content=output.text)

            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:

                try:
                    tool_parser = self.tool_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in tool parser creation.")
                    return self.create_error_response(str(e))

                tool_call_info = tool_parser.extract_tool_calls(
                    output.text, request=request)
                # In the OpenAI API the finish_reason is "tools_called"
                # if the tool choice is auto and the model produced a tool
                # call. The same is not true for named function calls
                auto_tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)

                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later.
                    message = ChatMessage(role=role, content=output.text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. "
                    "Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls" if auto_tools_called else
                output.finish_reason if output.finish_reason else "stop",
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            last_msg_content: Union[str, List[Dict[str, str]]] = ""
            if conversation and "content" in conversation[-1] and conversation[
                    -1].get("role") == role:
                last_msg_content = conversation[-1]["content"] or ""
            if isinstance(last_msg_content, list):
                last_msg_content = "\n".join(msg['text']
                                             for msg in last_msg_content)

            # Prepend the echoed input to every choice's content.
            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                          completion_tokens=num_generated_tokens,
                          total_tokens=num_prompt_tokens +
                          num_generated_tokens)
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens)

        request_metadata.final_usage_info = usage

        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response

    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
        """Convert one step's logprob dict into ChatCompletionLogProb entries,
        keeping at most *top_logprobs* of them (logprobs clamped at -9999)."""
        return [
            ChatCompletionLogProb(token=(token := self._get_decoded_token(
                p[1],
                p[0],
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)),
                                  logprob=max(p[1].logprob, -9999.0),
                                  bytes=list(
                                      token.encode("utf-8", errors="replace")))
            for i, p in enumerate(logprobs.items())
            if top_logprobs and i < top_logprobs
        ]

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs."""
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprob info for this step: fall back to decoding the
                # token id directly.
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
            else:
                step_token = step_top_logprobs[token_id]
                step_decoded = step_token.decoded_token

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_token,
                            token_id,
                            tokenizer,
                            self.return_tokens_as_token_ids,
                        ),
                        logprob=max(step_token.logprob, -9999.0),
                        bytes=None if step_decoded is None else list(
                            step_decoded.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs,
                            num_output_top_logprobs,
                            tokenizer,
                        ),
                    ))

        return ChatCompletionLogProbs(content=logprobs_content)

    def _should_stream_with_auto_tool_parsing(self,
                                              request: ChatCompletionRequest):
        """
        Utility function to check if streamed tokens should go through the tool
        call parser that was configured.

        We only want to do this IF user-provided tools are set, a tool parser
        is configured, "auto" tool choice is enabled, and the request's tool
        choice field indicates that "auto" tool choice should be used.
        """
        return (request.tools and self.tool_parser and self.enable_auto_tools
                and request.tool_choice in ['auto', None])

    def _should_stream_with_reasoning_parsing(self,
                                              request: ChatCompletionRequest):
        """
        Utility function to check if streamed tokens should go through the
        reasoning parser that was configured.

        We only want to do this IF reasoning is enabled and a reasoning
        parser is configured.
        """
        return self.enable_reasoning and self.reasoning_parser is not None

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: Optional[DeltaMessage],
        output: CompletionOutput,
    ) -> bool:
        """
        Check to see if we should check for unstreamed tool arguments tokens.
        This is only applicable when auto tool parsing is enabled, the delta
        is a tool call with arguments.
        """

        # yapf: disable
        return bool(
            # if there is a delta message that includes tool calls which
            # include a function that has arguments
            output.finish_reason is not None
            and self.enable_auto_tools and self.tool_parser and delta_message
            and delta_message.tool_calls and delta_message.tool_calls[0]
            and delta_message.tool_calls[0].function
            and delta_message.tool_calls[0].function.arguments is not None
        )
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_completion.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_completion.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7ad263e7fbe5049dcab508139b67d0bbc2d5aa9
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_completion.py
@@ -0,0 +1,547 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple, Union, cast

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from 
vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (CompletionLogProbs, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + ErrorResponse, + RequestResponseMetadata, + UsageInfo) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import merge_async_iterators + +logger = init_logger(__name__) + + +class OpenAIServingCompletion(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info( + "Overwriting default completion sampling param with: %s", + diff_sampling_param) + + async def create_completion( + self, + request: CompletionRequest, + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/completions/create + for the API specification. This API mimics the OpenAI Completion API. 
+ + NOTE: Currently we do not support the following feature: + - suffix (the language models we currently support do not support + suffix) + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + # Return error for unsupported features. + if request.suffix is not None: + return self.create_error_response( + "suffix is not currently supported") + + request_id = f"cmpl-{self._base_request_id(raw_request)}" + created_time = int(time.time()) + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + request_prompts, engine_prompts = await self._preprocess_completion( + request, + tokenizer, + request.prompt, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. 
+ generators: List[AsyncGenerator[RequestOutput, None]] = [] + try: + for i, engine_prompt in enumerate(engine_prompts): + sampling_params: Union[SamplingParams, BeamSearchParams] + default_max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"]) + # Build default sampling params + default_sampling_params = ( + self.model_config.get_diff_sampling_param()) + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + default_max_tokens, default_sampling_params) + else: + sampling_params = request.to_sampling_params( + default_max_tokens, + self.model_config.logits_processor_pattern, + default_sampling_params) + + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + if isinstance(sampling_params, BeamSearchParams): + generator = self.engine_client.beam_search( + prompt=engine_prompt, + request_id=request_id, + params=sampling_params, + ) + else: + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id_item, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators(*generators) + + model_name = self.models.model_name(lora_request) + num_prompts = len(engine_prompts) + + # Similar to the OpenAI API, when n != best_of, we do not stream the + # results. In addition, we do not stream the results when use + # beam search. 
+ stream = (request.stream + and (request.best_of is None or request.n == request.best_of) + and not request.use_beam_search) + + # Streaming response + if stream: + return self.completion_stream_generator( + request, + result_generator, + request_id, + created_time, + model_name, + num_prompts=num_prompts, + tokenizer=tokenizer, + request_metadata=request_metadata) + + # Non-streaming response + final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts + try: + async for i, res in result_generator: + final_res_batch[i] = res + + for i, final_res in enumerate(final_res_batch): + assert final_res is not None + + # The output should contain the input text + # We did not pass it into vLLM engine to avoid being redundant + # with the inputs token IDs + if final_res.prompt is None: + final_res.prompt = request_prompts[i]["prompt"] + + final_res_batch_checked = cast(List[RequestOutput], + final_res_batch) + + response = self.request_output_to_completion_response( + final_res_batch_checked, + request, + request_id, + created_time, + model_name, + tokenizer, + request_metadata, + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. 
+ if request.stream: + response_json = response.model_dump_json() + + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + return fake_stream_generator() + + return response + + async def completion_stream_generator( + self, + request: CompletionRequest, + result_generator: AsyncIterator[Tuple[int, RequestOutput]], + request_id: str, + created_time: int, + model_name: str, + num_prompts: int, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> AsyncGenerator[str, None]: + num_choices = 1 if request.n is None else request.n + previous_text_lens = [0] * num_choices * num_prompts + previous_num_tokens = [0] * num_choices * num_prompts + has_echoed = [False] * num_choices * num_prompts + num_prompt_tokens = [0] * num_prompts + + stream_options = request.stream_options + if stream_options: + include_usage = stream_options.include_usage + include_continuous_usage = include_usage and \ + stream_options.continuous_usage_stats + else: + include_usage, include_continuous_usage = False, False + + try: + async for prompt_idx, res in result_generator: + prompt_token_ids = res.prompt_token_ids + prompt_logprobs = res.prompt_logprobs + prompt_text = res.prompt + + # Prompt details are excluded from later streamed outputs + if res.prompt_token_ids is not None: + num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids) + + delta_token_ids: GenericSequence[int] + out_logprobs: Optional[GenericSequence[Optional[Dict[ + int, Logprob]]]] + + for output in res.outputs: + i = output.index + prompt_idx * num_choices + + assert request.max_tokens is not None + if request.echo and not has_echoed[i]: + assert prompt_token_ids is not None + assert prompt_text is not None + if request.max_tokens == 0: + # only return the prompt + delta_text = prompt_text + delta_token_ids = prompt_token_ids + out_logprobs = prompt_logprobs + else: + assert prompt_logprobs is not None + # echo the prompt 
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids, *output.token_ids
                            ]
                            out_logprobs = [
                                *prompt_logprobs,
                                *(output.logprobs or []),
                            ]
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                    if not delta_text and not delta_token_ids \
                        and not previous_num_tokens[i]:
                        # Chunked prefill case, don't return empty chunks
                        continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            # Text offsets continue from previous chunks.
                            initial_text_offset=previous_text_lens[i],
                        )
                    else:
                        logprobs = None

                    previous_text_lens[i] += len(output.text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens)

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Assemble the non-streaming CompletionResponse from finished
        engine outputs, handling echo and logprob conversion."""
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = final_res.prompt_logprobs
            if prompt_logprobs:
                # Clamp -inf prompt logprobs to a JSON-serializable floor.
                for logprob_dict in prompt_logprobs:
                    if logprob_dict:
                        for logprob_values in logprob_dict.values():
                            if logprob_values.logprob == float('-inf'):
                                logprob_values.logprob = -9999.0
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: Optional[GenericSequence[Optional[Dict[int,
                                                                 Logprob]]]]

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo:
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        # Echo-only: no generated tokens were requested.
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        request_metadata.final_usage_info = usage

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: AnyTokenizer,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API."""
        out_text_offset: List[int] = []
        out_token_logprobs: List[Optional[float]] = []
        out_tokens: List[str] = []
        out_top_logprobs: List[Optional[Dict[str, float]]] = []

        last_token_len = 0

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprob info for this step; emit placeholders.
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=self.return_tokens_as_token_ids,
                )
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append({
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    self._get_decoded_token(
                        top_lp[1],
                        top_lp[0],
                        tokenizer,
                        return_as_token_id=self.return_tokens_as_token_ids):
                    max(top_lp[1].logprob, -9999.0)
                    for i, top_lp in enumerate(step_top_logprobs.items())
                    if num_output_top_logprobs >= i
                })

            # Text offsets are cumulative over the decoded token lengths.
            if len(out_text_offset) == 0:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_embedding.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..45f8ad90ddcb3d67d56e49ccfa39bcc4ec2d135d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_embedding.py
@@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import base64
import time
from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast

import numpy as np
from fastapi import Request
from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                              EmbeddingRequest,
                                              EmbeddingResponse,
                                              EmbeddingResponseData,
                                              ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput)
from vllm.utils import merge_async_iterators

logger = init_logger(__name__)


def _get_embedding(
    output: EmbeddingOutput,
    encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
    """Render an embedding as a float list or base64 string per the
    request's `encoding_format`."""
    if encoding_format == "float":
        return output.embedding
    elif encoding_format == "base64":
        # Force to use float32 for base64 encoding
        # to match the OpenAI python client behavior
        embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
        return base64.b64encode(embedding_bytes).decode("utf-8")

    assert_never(encoding_format)


class OpenAIServingEmbedding(OpenAIServing):
    """OpenAI-compatible /v1/embeddings handler."""

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
    ) -> None:
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = request.encoding_format
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"embd-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        truncate_prompt_tokens = None

        if request.truncate_prompt_tokens is not None:
            if request.truncate_prompt_tokens <= self.max_model_len:
                truncate_prompt_tokens = request.truncate_prompt_tokens
            else:
                return self.create_error_response(
                    "truncate_prompt_tokens value is "
                    "greater than max_model_len."
                    " Please, select a smaller truncation size.")

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for embedding models")

            if isinstance(request, EmbeddingChatRequest):
                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    # In embedding requests, we are not generating tokens,
                    # so there is no need to append extra tokens to the input
                    add_generation_prompt=False,
                    continue_final_message=False,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                (request_prompts,
                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.input,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
        try:
            pooling_params = request.to_pooling_params()

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_embedding_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
                encoding_format,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_embedding_response(
        self,
        final_res_batch: List[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
    ) -> EmbeddingResponse:
        """Convert finished pooling outputs into an EmbeddingResponse,
        accumulating prompt-token usage across all inputs."""
        items: List[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            embedding_res = EmbeddingRequestOutput.from_base(final_res)

            item = EmbeddingResponseData(
                index=idx,
                embedding=_get_embedding(embedding_res.outputs,
                                         encoding_format),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_engine.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d39fdcb748330545539b1e71c8cae88d71483c6
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_engine.py
@@ -0,0 +1,524 @@
# SPDX-License-Identifier: Apache-2.0

import json
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                    Optional, Sequence, Tuple, TypedDict, Union)

from fastapi import Request
from pydantic import Field
from starlette.datastructures import Headers
from typing_extensions import Annotated

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         ChatTemplateContentFormatOption,
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
from vllm.inputs import TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of, make_async, random_uuid

logger = init_logger(__name__)

# Request unions used to pick validation behavior per endpoint family.
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, ScoreRequest,
                              TokenizeCompletionRequest]

ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]

AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest]


class TextTokensPrompt(TypedDict):
    # Decoded prompt text paired with its token ids.
    prompt: str
    prompt_token_ids: List[int]


RequestPrompt = Union[List[int], str, TextTokensPrompt]


class OpenAIServing:
    """Shared base for the OpenAI-compatible endpoint handlers: model/adapter
    resolution, prompt tokenization/validation, and error responses."""

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__()

        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # Single-threaded executor: tokenization runs off the event loop
        # but stays serialized.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input, executor=self._tokenizer_executor)
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=self._tokenizer_executor)

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an OpenAI-style ErrorResponse body."""
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Like create_error_response, but serialized for an SSE stream."""
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """Return None when request.model is served (base model, LoRA, or
        prompt adapter); otherwise a 404 ErrorResponse."""
        if self._is_model_supported(request.model):
            return None
        if request.model in [
                lora.lora_name for lora in self.models.lora_requests
        ]:
            return None
        if request.model in [
                prompt_adapter.prompt_adapter_name
                for prompt_adapter in self.models.prompt_adapter_requests
        ]:
            return None
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
        """Resolve request.model to (lora_request, prompt_adapter_request);
        (None, None) means the base model."""
        if self._is_model_supported(request.model):
            return None, None
        for lora in self.models.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.models.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt (with optional truncation) and validate its
        length against the model context window."""
        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids

        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        """Validate a pre-tokenized prompt; truncation keeps the last
        `truncate_prompt_tokens` ids."""
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Enforce per-request-type context-length limits and return the
        validated TextTokensPrompt."""
        token_num = len(input_ids)

        # Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):

            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> List[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API `_
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
+ """ + # Although our type checking is based on mypy, + # VSCode Pyright extension should still work properly + # "is True" is required for Pyright to perform type narrowing + # See: https://github.com/microsoft/pyright/issues/7672 + return [ + self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens) + if prompt_input["is_tokens"] is False else + self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens) + for prompt_input in parse_and_batch_prompt(input_or_inputs) + ] + + async def _preprocess_completion( + self, + request: CompletionLikeRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Union[str, List[str], List[int], List[List[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]: + request_prompts = await self._tokenize_prompt_input_or_inputs_async( + request, + tokenizer, + input_or_inputs, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + + engine_prompts = [ + TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"]) + for request_prompt in request_prompts + ] + + return request_prompts, engine_prompts + + async def _preprocess_chat( + self, + request: ChatLikeRequest, + tokenizer: AnyTokenizer, + messages: List[ChatCompletionMessageParam], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tool_dicts: Optional[List[Dict[str, Any]]] = None, + documents: Optional[List[Dict[str, str]]] = None, + chat_template_kwargs: Optional[Dict[str, Any]] = None, + tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, + truncate_prompt_tokens: 
Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = False, + ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt], + List[TokensPrompt]]: + resolved_content_format = resolve_chat_template_content_format( + chat_template, + chat_template_content_format, + tokenizer, + ) + conversation, mm_data_future = parse_chat_messages_futures( + messages, + self.model_config, + tokenizer, + content_format=resolved_content_format, + ) + + _chat_template_kwargs: Dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tool_dicts, + documents=documents, + ) + _chat_template_kwargs.update(chat_template_kwargs or {}) + + request_prompt: Union[str, List[int]] + is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) + if is_mistral_tokenizer: + request_prompt = apply_mistral_chat_template( + tokenizer, + messages=messages, + **_chat_template_kwargs, + ) + else: + request_prompt = apply_hf_chat_template( + tokenizer, + conversation=conversation, + **_chat_template_kwargs, + ) + + mm_data = await mm_data_future + + # tool parsing is done only if a tool_parser has been set and if + # tool_choice is not "none" (if tool_choice is "none" but a tool_parser + # is set, we want to prevent parsing a tool_call hallucinated by the LLM + should_parse_tools = tool_parser is not None and (hasattr( + request, "tool_choice") and request.tool_choice != "none") + + if should_parse_tools: + if not isinstance(request, ChatCompletionRequest): + msg = "Tool usage is only supported for Chat Completions API" + raise NotImplementedError(msg) + + request = tool_parser(tokenizer).adjust_request( # type: ignore + request=request) + + if isinstance(request_prompt, str): + prompt_inputs = await self._tokenize_prompt_input_async( + request, + tokenizer, + request_prompt, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + else: + # For 
MistralTokenizer + assert is_list_of(request_prompt, int), ( + "Prompt has to be either a string or a list of token ids") + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(request_prompt), + prompt_token_ids=request_prompt) + + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if mm_data is not None: + engine_prompt["multi_modal_data"] = mm_data + + return conversation, [request_prompt], [engine_prompt] + + def _log_inputs( + self, + request_id: str, + inputs: RequestPrompt, + params: Optional[Union[SamplingParams, PoolingParams, + BeamSearchParams]], + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> None: + if self.request_logger is None: + return + + if isinstance(inputs, str): + prompt = inputs + prompt_token_ids = None + elif isinstance(inputs, list): + prompt = None + prompt_token_ids = inputs + else: + prompt = inputs["prompt"] + prompt_token_ids = inputs["prompt_token_ids"] + + self.request_logger.log_inputs( + request_id, + prompt, + prompt_token_ids, + params=params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + async def _get_trace_headers( + self, + headers: Headers, + ) -> Optional[Mapping[str, str]]: + is_tracing_enabled = await self.engine_client.is_tracing_enabled() + + if is_tracing_enabled: + return extract_trace_headers(headers) + + if contains_trace_headers(headers): + log_tracing_disabled_warning() + + return None + + @staticmethod + def _base_request_id(raw_request: Optional[Request], + default: Optional[str] = None) -> Optional[str]: + """Pulls the request id to use from a header, if provided""" + default = default or random_uuid() + if raw_request is None: + return default + + return raw_request.headers.get("X-Request-Id", default) + + @staticmethod + def _get_decoded_token(logprob: Logprob, + token_id: int, + tokenizer: AnyTokenizer, + return_as_token_id: bool = False) -> str: + if return_as_token_id: 
+ return f"token_id:{token_id}" + + if logprob.decoded_token is not None: + return logprob.decoded_token + return tokenizer.decode(token_id) + + def _is_model_supported(self, model_name): + return self.models.is_base_model(model_name) diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_models.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_models.py new file mode 100644 index 0000000000000000000000000000000000000000..f917a48519016c7300cac638796649bc79fc7a5d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_models.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import pathlib +from dataclasses import dataclass +from http import HTTPStatus +from typing import List, Optional, Union + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.openai.protocol import (ErrorResponse, + LoadLoraAdapterRequest, + ModelCard, ModelList, + ModelPermission, + UnloadLoraAdapterRequest) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.utils import AtomicCounter + +logger = init_logger(__name__) + + +@dataclass +class BaseModelPath: + name: str + model_path: str + + +@dataclass +class PromptAdapterPath: + name: str + local_path: str + + +@dataclass +class LoRAModulePath: + name: str + path: str + base_model_name: Optional[str] = None + + +class OpenAIServingModels: + """Shared instance to hold data about the loaded base model(s) and adapters. 
+ + Handles the routes: + - /v1/models + - /v1/load_lora_adapter + - /v1/unload_lora_adapter + """ + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + *, + lora_modules: Optional[List[LoRAModulePath]] = None, + prompt_adapters: Optional[List[PromptAdapterPath]] = None, + ): + super().__init__() + + self.base_model_paths = base_model_paths + self.max_model_len = model_config.max_model_len + self.engine_client = engine_client + + self.static_lora_modules = lora_modules + self.lora_requests: List[LoRARequest] = [] + self.lora_id_counter = AtomicCounter(0) + + self.prompt_adapter_requests = [] + if prompt_adapters is not None: + for i, prompt_adapter in enumerate(prompt_adapters, start=1): + with pathlib.Path(prompt_adapter.local_path, + "adapter_config.json").open() as f: + adapter_config = json.load(f) + num_virtual_tokens = adapter_config["num_virtual_tokens"] + self.prompt_adapter_requests.append( + PromptAdapterRequest( + prompt_adapter_name=prompt_adapter.name, + prompt_adapter_id=i, + prompt_adapter_local_path=prompt_adapter.local_path, + prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + + async def init_static_loras(self): + """Loads all static LoRA modules. + Raises if any fail to load""" + if self.static_lora_modules is None: + return + for lora in self.static_lora_modules: + load_request = LoadLoraAdapterRequest(lora_path=lora.path, + lora_name=lora.name) + load_result = await self.load_lora_adapter( + request=load_request, base_model_name=lora.base_model_name) + if isinstance(load_result, ErrorResponse): + raise ValueError(load_result.message) + + def is_base_model(self, model_name): + return any(model.name == model_name for model in self.base_model_paths) + + def model_name(self, lora_request: Optional[LoRARequest] = None) -> str: + """Returns the appropriate model name depending on the availability + and support of the LoRA or base model. 
+ Parameters: + - lora: LoRARequest that contain a base_model_name. + Returns: + - str: The name of the base model or the first available model path. + """ + if lora_request is not None: + return lora_request.lora_name + return self.base_model_paths[0].name + + async def show_available_models(self) -> ModelList: + """Show available models. This includes the base model and all + adapters""" + model_cards = [ + ModelCard(id=base_model.name, + max_model_len=self.max_model_len, + root=base_model.model_path, + permission=[ModelPermission()]) + for base_model in self.base_model_paths + ] + lora_cards = [ + ModelCard(id=lora.lora_name, + root=lora.local_path, + parent=lora.base_model_name if lora.base_model_name else + self.base_model_paths[0].name, + permission=[ModelPermission()]) + for lora in self.lora_requests + ] + prompt_adapter_cards = [ + ModelCard(id=prompt_adapter.prompt_adapter_name, + root=self.base_model_paths[0].name, + permission=[ModelPermission()]) + for prompt_adapter in self.prompt_adapter_requests + ] + model_cards.extend(lora_cards) + model_cards.extend(prompt_adapter_cards) + return ModelList(data=model_cards) + + async def load_lora_adapter( + self, + request: LoadLoraAdapterRequest, + base_model_name: Optional[str] = None + ) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_load_lora_adapter_request(request) + if error_check_ret is not None: + return error_check_ret + + lora_name, lora_path = request.lora_name, request.lora_path + unique_id = self.lora_id_counter.inc(1) + lora_request = LoRARequest(lora_name=lora_name, + lora_int_id=unique_id, + lora_path=lora_path) + if base_model_name is not None and self.is_base_model(base_model_name): + lora_request.base_model_name = base_model_name + + # Validate that the adapter can be loaded into the engine + # This will also pre-load it for incoming requests + try: + await self.engine_client.add_lora(lora_request) + except BaseException as e: + error_type = "BadRequestError" + 
status_code = HTTPStatus.BAD_REQUEST + if isinstance(e, ValueError) and "No adapter found" in str(e): + error_type = "NotFoundError" + status_code = HTTPStatus.NOT_FOUND + + return create_error_response(message=str(e), + err_type=error_type, + status_code=status_code) + + self.lora_requests.append(lora_request) + logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name, + lora_path) + return f"Success: LoRA adapter '{lora_name}' added successfully." + + async def unload_lora_adapter( + self, + request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_unload_lora_adapter_request(request + ) + if error_check_ret is not None: + return error_check_ret + + lora_name = request.lora_name + self.lora_requests = [ + lora_request for lora_request in self.lora_requests + if lora_request.lora_name != lora_name + ] + logger.info("Removed LoRA adapter: name '%s'", lora_name) + return f"Success: LoRA adapter '{lora_name}' removed successfully." 
+ + async def _check_load_lora_adapter_request( + self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: + # Check if both 'lora_name' and 'lora_path' are provided + if not request.lora_name or not request.lora_path: + return create_error_response( + message="Both 'lora_name' and 'lora_path' must be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name already exists + if any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return create_error_response( + message= + f"The lora adapter '{request.lora_name}' has already been " + "loaded.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + return None + + async def _check_unload_lora_adapter_request( + self, + request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]: + # Check if either 'lora_name' or 'lora_int_id' is provided + if not request.lora_name and not request.lora_int_id: + return create_error_response( + message= + "either 'lora_name' and 'lora_int_id' needs to be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name exists + if not any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return create_error_response( + message= + f"The lora adapter '{request.lora_name}' cannot be found.", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + + return None + + +def create_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + return ErrorResponse(message=message, + type=err_type, + code=status_code.value) diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_pooling.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_pooling.py new file mode 100644 index 
0000000000000000000000000000000000000000..01a3d211f6ba633988782cbd7af6d71e556f72b2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_pooling.py @@ -0,0 +1,235 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import base64 +import time +from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast + +import numpy as np +from fastapi import Request +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, + PoolingChatRequest, + PoolingRequest, PoolingResponse, + PoolingResponseData, UsageInfo) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import PoolingOutput, PoolingRequestOutput +from vllm.utils import merge_async_iterators + +logger = init_logger(__name__) + + +def _get_data( + output: PoolingOutput, + encoding_format: Literal["float", "base64"], +) -> Union[List[float], str]: + if encoding_format == "float": + return output.data.tolist() + elif encoding_format == "base64": + # Force to use float32 for base64 encoding + # to match the OpenAI python client behavior + pooling_bytes = np.array(output.data, dtype="float32").tobytes() + return base64.b64encode(pooling_bytes).decode("utf-8") + + assert_never(encoding_format) + + +class OpenAIServingPooling(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: + super().__init__(engine_client=engine_client, + 
model_config=model_config, + models=models, + request_logger=request_logger) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + async def create_pooling( + self, + request: PoolingRequest, + raw_request: Optional[Request] = None, + ) -> Union[PoolingResponse, ErrorResponse]: + """ + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + encoding_format = request.encoding_format + if request.dimensions is not None: + return self.create_error_response( + "dimensions is currently not supported") + + model_name = request.model + request_id = f"pool-{self._base_request_id(raw_request)}" + created_time = int(time.time()) + + truncate_prompt_tokens = None + + if request.truncate_prompt_tokens is not None: + if request.truncate_prompt_tokens <= self.max_model_len: + truncate_prompt_tokens = request.truncate_prompt_tokens + else: + return self.create_error_response( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for pooling models") + + if isinstance(request, PoolingChatRequest): + ( + _, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. 
+ chat_template_content_format, + # In pooling requests, we are not generating tokens, + # so there is no need to append extra tokens to the input + add_generation_prompt=False, + continue_final_message=False, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + else: + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.input, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. + generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] + try: + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators(*generators) + + num_prompts = len(engine_prompts) + + # Non-streaming response + final_res_batch: List[Optional[PoolingRequestOutput]] + final_res_batch = [None] * num_prompts + try: + async for i, res in result_generator: + final_res_batch[i] = res + + assert all(final_res is not None for final_res in final_res_batch) + + final_res_batch_checked = cast(List[PoolingRequestOutput], + 
final_res_batch) + + response = self.request_output_to_pooling_response( + final_res_batch_checked, + request_id, + created_time, + model_name, + encoding_format, + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response + + def request_output_to_pooling_response( + self, + final_res_batch: List[PoolingRequestOutput], + request_id: str, + created_time: int, + model_name: str, + encoding_format: Literal["float", "base64"], + ) -> PoolingResponse: + items: List[PoolingResponseData] = [] + num_prompt_tokens = 0 + + for idx, final_res in enumerate(final_res_batch): + item = PoolingResponseData( + index=idx, + data=_get_data(final_res.outputs, encoding_format), + ) + prompt_token_ids = final_res.prompt_token_ids + + items.append(item) + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return PoolingResponse( + id=request_id, + created=created_time, + model=model_name, + data=items, + usage=usage, + ) diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_rerank.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..366df71217e9101c6d7b381bdf18efeef64752ff --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_rerank.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument, + RerankRequest, RerankResponse, + RerankResult, 
RerankUsage) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import TokensPrompt +from vllm.logger import init_logger +from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.utils import make_async, merge_async_iterators + +logger = init_logger(__name__) + + +class JinaAIServingRerank(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + async def do_rerank( + self, + request: RerankRequest, + raw_request: Optional[Request] = None + ) -> Union[RerankResponse, ErrorResponse]: + """ + Rerank API based on JinaAI's rerank API; implements the same + API interface. Designed for compatibility with off-the-shelf + tooling, since this is a common standard for reranking APIs + + See example client implementations at + https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py + numerous clients use this standard. 
+ """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = f"rerank-{self._base_request_id(raw_request)}" + truncate_prompt_tokens = request.truncate_prompt_tokens + query = request.query + documents = request.documents + request_prompts = [] + engine_prompts = [] + top_n = request.top_n if request.top_n > 0 else len(documents) + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for scoring models") + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + if not self.model_config.is_cross_encoder: + raise ValueError("Model is not cross encoder.") + + if truncate_prompt_tokens is not None and \ + truncate_prompt_tokens > self.max_model_len: + raise ValueError( + f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " + f"is greater than max_model_len ({self.max_model_len})." 
+ f" Please, select a smaller truncation size.") + for doc in documents: + request_prompt = f"{query}{tokenizer.sep_token}{doc}" + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + prompt_inputs = await tokenize_async(text=query, + text_pair=doc, + **tokenization_kwargs) + + input_ids = prompt_inputs["input_ids"] + text_token_prompt = \ + self._validate_input(request, input_ids, request_prompt) + engine_prompt = TokensPrompt( + prompt_token_ids=text_token_prompt["prompt_token_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + request_prompts.append(request_prompt) + engine_prompts.append(engine_prompt) + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. 
+ generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] + + try: + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + result_generator = merge_async_iterators(*generators) + + num_prompts = len(engine_prompts) + + # Non-streaming response + final_res_batch: List[Optional[PoolingRequestOutput]] + final_res_batch = [None] * num_prompts + + try: + async for i, res in result_generator: + final_res_batch[i] = res + + assert all(final_res is not None for final_res in final_res_batch) + + final_res_batch_checked = cast(List[PoolingRequestOutput], + final_res_batch) + + response = self.request_output_to_rerank_response( + final_res_batch_checked, request_id, model_name, documents, + top_n) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response + + def request_output_to_rerank_response( + self, final_res_batch: List[PoolingRequestOutput], request_id: str, + model_name: str, documents: List[str], + top_n: int) -> RerankResponse: + """ + Convert the output of do_rank to a RerankResponse + """ + results: List[RerankResult] = [] + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + 
def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str],
                                                            str]) -> List:
    """Normalize two text inputs into a list of ``(text_1, text_2)`` pairs.

    A single string (or dict) on either side is treated as a one-element
    list; a singleton first side is broadcast against the second side
    (1:N). Otherwise the two sides must pair element-for-element (N:N).

    Raises:
        ValueError: If either side is empty, or if the lengths are
            incompatible (neither 1:1, 1:N nor N:N).
    """
    # Promote bare prompts to singleton lists, and copy list inputs so the
    # caller's containers are never mutated.
    queries = [text_1] if isinstance(text_1, (str, dict)) else list(text_1)
    targets = [text_2] if isinstance(text_2, (str, dict)) else list(text_2)

    if len(queries) > 1 and len(queries) != len(targets):
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if not queries:
        raise ValueError("At least one text element must be given")
    if not targets:
        raise ValueError("At least one text_pair element must be given")

    # Broadcast a single query across every target (the 1:N case).
    if len(queries) == 1:
        queries = queries * len(targets)

    return list(zip(queries, targets))
ValueError( + "MistralTokenizer not supported for cross-encoding") + + if not self.model_config.is_cross_encoder: + raise ValueError("Model is not cross encoder.") + + if truncate_prompt_tokens is not None and \ + truncate_prompt_tokens > self.max_model_len: + raise ValueError( + f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " + f"is greater than max_model_len ({self.max_model_len})." + f" Please, select a smaller truncation size.") + + input_pairs = make_pairs(request.text_1, request.text_2) + for q, t in input_pairs: + request_prompt = f"{q}{tokenizer.sep_token}{t}" + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + prompt_inputs = await tokenize_async(text=q, + text_pair=t, + **tokenization_kwargs) + + input_ids = prompt_inputs["input_ids"] + text_token_prompt = \ + self._validate_input(request, input_ids, request_prompt) + engine_prompt = TokensPrompt( + prompt_token_ids=text_token_prompt["prompt_token_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + request_prompts.append(request_prompt) + engine_prompts.append(engine_prompt) + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. 
+ generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] + + try: + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators(*generators) + + num_prompts = len(engine_prompts) + + # Non-streaming response + final_res_batch: List[Optional[PoolingRequestOutput]] + final_res_batch = [None] * num_prompts + + try: + async for i, res in result_generator: + final_res_batch[i] = res + + assert all(final_res is not None for final_res in final_res_batch) + + final_res_batch_checked = cast(List[PoolingRequestOutput], + final_res_batch) + + response = self.request_output_to_score_response( + final_res_batch_checked, + request_id, + created_time, + model_name, + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response + + def request_output_to_score_response( + self, + final_res_batch: List[PoolingRequestOutput], + request_id: str, + created_time: int, + model_name: str, + ) -> ScoreResponse: + items: List[ScoreResponseData] = [] + num_prompt_tokens = 0 + + for idx, final_res in enumerate(final_res_batch): + classify_res = ScoringRequestOutput.from_base(final_res) + + 
item = ScoreResponseData( + index=idx, + score=classify_res.outputs.score, + ) + prompt_token_ids = final_res.prompt_token_ids + + items.append(item) + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return ScoreResponse( + id=request_id, + created=created_time, + model=model_name, + data=items, + usage=usage, + ) diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_tokenization.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..6c79adf90c8ad13e9afb640c278ebdec9d6c59ff --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_tokenization.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Final, List, Optional, Union + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (DetokenizeRequest, + DetokenizeResponse, + ErrorResponse, + TokenizeChatRequest, + TokenizeRequest, + TokenizeResponse) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class OpenAIServingTokenization(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: + super().__init__(engine_client=engine_client, + 
model_config=model_config, + models=models, + request_logger=request_logger) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + async def create_tokenize( + self, + request: TokenizeRequest, + raw_request: Request, + ) -> Union[TokenizeResponse, ErrorResponse]: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"tokn-{self._base_request_id(raw_request)}" + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if isinstance(request, TokenizeChatRequest): + ( + _, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + chat_template_kwargs=request.chat_template_kwargs, + add_special_tokens=request.add_special_tokens, + ) + else: + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.prompt, + add_special_tokens=request.add_special_tokens, + ) + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + input_ids: List[int] = [] + for i, engine_prompt in enumerate(engine_prompts): + self._log_inputs(request_id, + request_prompts[i], + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + # Silently ignore prompt adapter since it does not affect + # tokenization (Unlike in Embeddings API where an error is raised) + + input_ids.extend(engine_prompt["prompt_token_ids"]) + + return TokenizeResponse(tokens=input_ids, + count=len(input_ids), + 
async def listen_for_disconnect(request: Request) -> None:
    """Block until the client sends an ``http.disconnect`` message."""
    while True:
        if (await request.receive())["type"] == "http.disconnect":
            break


def with_cancellation(handler_func):
    """Decorator that cancels a route handler when the client disconnects.

    ``request.is_disconnected`` does not work with middleware, so this
    follows the pattern from ``starlette.StreamingResponse``: two tasks run
    concurrently — one waiting for an http disconnect message, the other
    doing the handler's work — and whichever finishes first cancels the
    other.

    A core assumption is that the request body has already been consumed,
    which is safe for fastapi handlers whose body was parsed into a
    pydantic model. Elsewhere this decorator is unsafe: while watching for
    a disconnect it consumes and discards every incoming message.

    When the handler returns a ``StreamingResponse``, this wrapper stops
    listening and the response object takes over disconnect handling.
    """

    # functools.wraps preserves the signature so fastapi still sees a
    # normal route handler with correct request type hints.
    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        # The raw request is either the second positional argument or the
        # `raw_request` keyword argument.
        raw_request = args[1] if len(args) > 1 else kwargs["raw_request"]

        work = asyncio.create_task(handler_func(*args, **kwargs))
        watchdog = asyncio.create_task(listen_for_disconnect(raw_request))

        finished, unfinished = await asyncio.wait(
            {work, watchdog}, return_when=asyncio.FIRST_COMPLETED)
        for task in unfinished:
            task.cancel()

        # Client disconnected first: the handler was cancelled, nothing to
        # return.
        return work.result() if work in finished else None

    return wrapper
differ diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22a8374fcf728ae62a7e82040720367a827b939e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/hasher.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/hasher.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9229e9c092770f53168da44dbd8b025afb33f03c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/hasher.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/parse.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/parse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c89dcbacc9888f124e3d5b20383c21901b3fcbe Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/parse.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/processing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/processing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..587570e96f96f5e2e72abb1d268de82b4d3a5242 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/processing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/registry.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/registry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..328131d687cbf1b688f2efa32b1de71fb95a48f0 Binary files 
/dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/registry.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/video.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/video.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87366baa902609beffdab4763e4cd2ef870b6a29 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/video.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/audio.py b/.venv/lib/python3.11/site-packages/vllm/multimodal/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..f379ec1682a3c99eeecbda7a08b6f9097882c920 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/multimodal/audio.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 + +import base64 +from io import BytesIO +from pathlib import Path + +import numpy as np +import numpy.typing as npt + +from vllm.inputs.registry import InputContext +from vllm.utils import PlaceholderModule + +from .base import MediaIO, MultiModalPlugin +from .inputs import AudioItem, ModalityData, MultiModalKwargs + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +try: + import soundfile +except ImportError: + soundfile = PlaceholderModule("soundfile") # type: ignore[assignment] + + +class AudioPlugin(MultiModalPlugin): + """Plugin for audio data.""" + + def get_data_key(self) -> str: + return "audio" + + def _default_input_mapper( + self, + ctx: InputContext, + data: ModalityData[AudioItem], + **mm_processor_kwargs, + ) -> MultiModalKwargs: + raise NotImplementedError("There is no default audio input mapper") + + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + raise NotImplementedError( + "There is no default maximum multimodal tokens") + + +def resample_audio( + audio: 
npt.NDArray[np.floating], + *, + orig_sr: float, + target_sr: float, +) -> npt.NDArray[np.floating]: + return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) + + +class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): + + def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: + return librosa.load(BytesIO(data), sr=None) + + def load_base64( + self, + media_type: str, + data: str, + ) -> tuple[npt.NDArray, float]: + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]: + return librosa.load(filepath, sr=None) + + def encode_base64(self, media: tuple[npt.NDArray, float]) -> str: + audio, sr = media + + with BytesIO() as buffer: + soundfile.write(buffer, audio, sr, format="WAV") + data = buffer.getvalue() + + return base64.b64encode(data).decode('utf-8') diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/base.py b/.venv/lib/python3.11/site-packages/vllm/multimodal/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c48d07ba365ba62a56c99842726d17a3261cc15c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/multimodal/base.py @@ -0,0 +1,463 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple, + Optional, Sequence, Tuple, Type, TypeVar, Union) + +from torch import nn + +from vllm.inputs import InputContext +from vllm.logger import init_logger +from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, + resolve_mm_processor_kwargs) + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.sequence import SequenceGroupMetadata + +from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs, + PlaceholderRange) + +logger = init_logger(__name__) + +MultiModalInputMapper = Callable[[InputContext, ModalityData[object]], + MultiModalKwargs] +""" 
+Return a dictionary to be passed as keyword arguments to +:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers +and processors in HuggingFace Transformers. + +If the data is not supported, throw :exc:`TypeError`. +""" + +MultiModalTokensCalc = Union[int, Callable[[InputContext], int]] +""" +Calculate the maximum number of multimodal tokens input to the language +model. This does not include tokens that correspond to the input text. +""" + +_T = TypeVar("_T") +N = TypeVar("N", bound=Type[nn.Module]) + + +class MultiModalPlugin(ABC): + """ + Base class that defines data processing logic for a specific modality. + + In particular, we adopt a registry pattern to dispatch data processing + according to the model being used (considering that different models may + process the same data differently). This registry is in turn used by + :class:`~MultiModalRegistry` which acts at a higher level + (i.e., the modality of the data). + """ + + def __init__(self) -> None: + self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]() + self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]() + + @abstractmethod + def get_data_key(self) -> str: + """ + Get the data key corresponding to the modality. + """ + raise NotImplementedError + + @abstractmethod + def _default_input_mapper( + self, + ctx: InputContext, + data: ModalityData[Any], + **mm_processor_kwargs, + ) -> MultiModalKwargs: + """ + Return a dictionary to be passed as keyword arguments to + :meth:`~torch.nn.Module.forward`. This is similar in concept to + tokenizers and processors in HuggingFace Transformers. + + If the data is not supported, throw :exc:`TypeError`. + """ + raise NotImplementedError + + def register_input_mapper( + self, + mapper: Optional[MultiModalInputMapper] = None, + ): + """ + Register an input mapper to a model class. 
+ + When the model receives input data that matches the modality served by + this plugin (see :meth:`get_data_key`), the provided function is + invoked to transform the data into a dictionary of model inputs. + + If `None` is provided, then the default input mapper is used instead. + """ + + def wrapper(model_cls: N) -> N: + if self._input_mappers.contains(model_cls, strict=True): + logger.warning( + "Model class %s already has an input mapper " + "registered to %s. It is overwritten by the new one.", + model_cls, + self, + ) + + self._input_mappers[model_cls] = (mapper + or self._default_input_mapper) + + return model_cls + + return wrapper + + def map_input( + self, + model_config: "ModelConfig", + data: ModalityData[Any], + mm_processor_kwargs: Optional[dict[str, Any]], + ) -> MultiModalKwargs: + """ + Transform the data into a dictionary of model inputs using the + input mapper registered for that model. + + The model is identified by ``model_config``. + + Raises: + TypeError: If the data type is not supported. + """ + + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + + mapper = self._input_mappers.get(model_cls) + + if mapper is None: + raise KeyError(f"No input mapper in {self} is registered for " + f"model class {model_cls.__name__}.") + + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + # In the case of the default mapper, we have to get resource + # processor through its HuggingFace autoclass; since this goes + # through **kwargs, we can't inspect it the same way, so we allow + # drop mm_processor_kwargs based on signature inspection + # if we're using the default mapper. + # + # This should be safe in general due to the sanitation, since the + # transformers resource should filter unused kwargs anyway. 
+ uses_default_mapper = mapper == self._default_input_mapper + mm_processor_kwargs = resolve_mm_processor_kwargs( + model_config.mm_processor_kwargs, + mm_processor_kwargs, + callable=mapper, + allow_var_kwargs=uses_default_mapper, + ) + return mapper(InputContext(model_config), data, **mm_processor_kwargs) + + @abstractmethod + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + """ + Calculate the maximum number of tokens, corresponding to a single + instance of multimodal data, that are passed to the language model. + """ + raise NotImplementedError + + def _validate_max_multimodal_tokens(self, max_mm_tokens: int): + if max_mm_tokens < 1: + raise ValueError("You should set the number of tokens to a " + f"positive integer. Found: {max_mm_tokens}") + + def register_max_multimodal_tokens( + self, + max_mm_tokens: Optional[MultiModalTokensCalc] = None, + ): + """ + Register the maximum number of tokens, corresponding to a single + instance of multimodal data, that are passed to the language model + for a model class. + + If `None` is provided, then the default calculation is used instead. + """ + + def wrapper(model_cls: N) -> N: + if self._max_mm_tokens.contains(model_cls, strict=True): + logger.warning( + "Model class %s already calculates maximum number of " + "tokens in %s. It is overwritten by the new one.", + model_cls, + self, + ) + + if isinstance(max_mm_tokens, int): + self._validate_max_multimodal_tokens(max_mm_tokens) + + self._max_mm_tokens[model_cls] = ( + max_mm_tokens or self._default_max_multimodal_tokens) + + return model_cls + + return wrapper + + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: + """ + Get the maximum number of multi-modal tokens + for profiling the memory usage of a model. + + If this registry is not applicable to the model, `0` is returned. + + The model is identified by ``model_config``. 
+ """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + from vllm.model_executor.models import supports_multimodal + + model_cls, _ = get_model_architecture(model_config) + + if not supports_multimodal(model_cls): + return 0 + + max_mm_tokens = self._max_mm_tokens.get(model_cls) + if max_mm_tokens is None: + return 0 + + if callable(max_mm_tokens): + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + max_mm_tokens, overrides=model_config.mm_processor_kwargs) + max_mm_tokens = max_mm_tokens(InputContext(model_config), + **mm_processor_kwargs) + + self._validate_max_multimodal_tokens(max_mm_tokens) + + return max_mm_tokens + + +class MultiModalPlaceholderMap: + """ + Relates multi-modal embeddings to their corresponding placeholders. + """ + + class IndexMap(NamedTuple): + src: list[int] + dest: list[int] + + src_ranges: list[range] + """ + The indices of the multi-modal embeddings that will replace the + corresponding placeholder embeddings pointed to by ``dest_ranges``. + """ + + src_len: int + """ + The total number of flattened multi-modal embeddings. + """ + + dest_ranges: list[range] + """ + The indices of the placeholder embeddings that will be replaced by the + multimodal embeddings. + """ + + dest_len: int + """ + The total number of embeddings in the destination tensor. + """ + + def __init__(self): + self.src_ranges = [] + self.src_len = 0 + self.dest_ranges = [] + self.dest_len = 0 + + @classmethod + def from_seq_group( + cls, seq_group: "SequenceGroupMetadata", positions: range + ) -> Tuple[Optional[MultiModalDataDict], dict[str, + "MultiModalPlaceholderMap"]]: + """ + Returns the multi-modal items that intersect with the portion of a + prompt (``seq_group``) represented by ``positions``, as well as a + ``MultiModalPlaceholderMap`` that relates the multi-modal embedding + vectors to their corresponding placeholders. + + Examples: + + .. 
code-block:: + + Prompt: |AAAA BBBB What's in these images?| + Positions: |.................................| + + images = [A, B] + src_ranges = [(0, 4), (4, 8)] + dest_ranges = [(0, 4), (5, 9)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | ..... | + + images = [A, B] + src_ranges = [(2, 4), (4, 6)] + dest_ranges = [(0, 2), (3, 5)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | ......... | + + images = [B] + src_ranges = [(0, 4)] + dest_ranges = [(0, 4)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | .......................| + + images = [] + src_ranges = [] + dest_ranges = [] + """ + seq_mm_data = seq_group.multi_modal_data + seq_mm_placeholders = seq_group.multi_modal_placeholders + + if not seq_mm_data or not seq_mm_placeholders: + return seq_mm_data, {} + + # For merged processor, we directly use mm_kwargs as mm_data + if isinstance(seq_mm_data, MultiModalKwargs): + placeholder_maps = dict[str, MultiModalPlaceholderMap]() + + for modality, placeholders in seq_mm_placeholders.items(): + placeholder_map = MultiModalPlaceholderMap() + + if positions: + placeholder_map.append_items_from_seq_group( + positions, + # Dummy, since we don't care about intersecting items + [None] * len(placeholders), + placeholders, + ) + + placeholder_maps[modality] = placeholder_map + + return seq_mm_data, placeholder_maps + + mm_data = {**seq_mm_data} + placeholder_maps = defaultdict[str, MultiModalPlaceholderMap]( + MultiModalPlaceholderMap) + + for modality, placeholders in seq_mm_placeholders.items(): + mm_items = mm_data.pop(modality) + if not isinstance(mm_items, list): + mm_items = [mm_items] + + if positions: + intersecting_items = placeholder_maps[modality] \ + .append_items_from_seq_group( + positions, + mm_items, + placeholders, + ) + + if intersecting_items: + mm_data[modality] = intersecting_items + + return mm_data, placeholder_maps + + def append_items_from_seq_group( + self, + positions: range, + multi_modal_items: 
list[_T], + multi_modal_placeholders: Sequence[PlaceholderRange], + ) -> list[_T]: + """ + Adds the multi-modal items that intersect ```positions`` to this + placeholder map and returns the intersecting items. + """ + intersecting_items = [] + + if len(multi_modal_items) != len(multi_modal_placeholders): + raise ValueError( + "Multi-modal placeholders and items must have the same length." + ) + for placeholder_dict, mm_item in zip(multi_modal_placeholders, + multi_modal_items): + placeholder = range( + placeholder_dict["offset"], + placeholder_dict["offset"] + placeholder_dict["length"], + ) + intersection = range( + max(positions.start, placeholder.start), + min(positions.stop, placeholder.stop), + ) + + if not intersection: + # Skip this multi-modal item. + continue + + token_embedding_range = range( + intersection.start - positions.start, + intersection.stop - positions.start, + ) + + multimodal_embedding_range = range( + intersection.start - placeholder.start + self.src_len, + intersection.stop - placeholder.start + self.src_len, + ) + + intersecting_items.append(mm_item) + self.dest_ranges.append(token_embedding_range) + self.src_ranges.append(multimodal_embedding_range) + self.src_len += len(placeholder) + + self.dest_len += len(positions) + return intersecting_items + + def extend(self, other: "MultiModalPlaceholderMap"): + """ + Adds the placeholders from another ``MultiModalPlaceholderMap`` to this + instance based on the source and destination tensors being + concatenated. + """ + + self.src_ranges.extend( + range(self.src_len + r.start, self.src_len + r.stop) + for r in other.src_ranges) + self.src_len += other.src_len + self.dest_ranges.extend( + range(self.dest_len + r.start, self.dest_len + r.stop) + for r in other.dest_ranges) + self.dest_len += other.dest_len + + def index_map(self) -> "IndexMap": + """ + Finalizes the placeholder map into lists of indices that can be used to + index the source and destination tensors. 
+ """ + + src_indices = [i for r in self.src_ranges for i in r] + dest_indices = [i for r in self.dest_ranges for i in r] + + if len(src_indices) != len(dest_indices): + raise ValueError( + f"The number of source ({len(src_indices)}) and destination " + f"indices ({len(dest_indices)}) must be the same.") + + return MultiModalPlaceholderMap.IndexMap(src=src_indices, + dest=dest_indices) + + +class MediaIO(ABC, Generic[_T]): + + @abstractmethod + def load_bytes(self, data: bytes) -> _T: + raise NotImplementedError + + @abstractmethod + def load_base64(self, media_type: str, data: str) -> _T: + """ + List of media types: + https://www.iana.org/assignments/media-types/media-types.xhtml + """ + raise NotImplementedError + + @abstractmethod + def load_file(self, filepath: Path) -> _T: + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/vllm/plugins/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/plugins/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1f256106889c11c5316ad08a164588d139dc1c1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/plugins/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__init__.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..588f098bf16ab7733f75e03c2931b31ff4dc5490 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/batch_expansion.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/batch_expansion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f2116ac60728b052602ad23b4846a6a438ecb9c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/batch_expansion.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/draft_model_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/draft_model_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0f956c79818fff834c6c21bd278d0a7dea0e28b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/draft_model_runner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/interfaces.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/interfaces.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fb28c8c22f9646cce2096f0876e9e961173050d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/interfaces.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/medusa_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/medusa_worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82b1c11633029b0762fc125c7d1a7ab43240be15 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/medusa_worker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/metrics.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..f09a75eb579000d92e1db51ebffbdbb743a6e042 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/metrics.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mlp_speculator_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mlp_speculator_worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa3bf8ea708e2b4a71e15ef051f18bdef7f90831 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mlp_speculator_worker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mqa_scorer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mqa_scorer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7a693ea92af46d2b4b739da83318433fe23cc1d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mqa_scorer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/multi_step_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/multi_step_worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..818ca506c13fb927e7636afb04376de86a9a90bf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/multi_step_worker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/ngram_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/ngram_worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..236572f676cde82037968ac43923cf3c3c5243d2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/ngram_worker.cpython-311.pyc 
differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/proposer_worker_base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/proposer_worker_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75e297b6b73ea26fc0d584c0973373c56f082ffe Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/proposer_worker_base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/smaller_tp_proposer_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/smaller_tp_proposer_worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c943e829eaf96434b6d103f37200b47292a23f0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/smaller_tp_proposer_worker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/spec_decode_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/spec_decode_worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b403889fe3a3da9f40a1ce104aa5b87947755b6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/spec_decode_worker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/target_model_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/target_model_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e06c137d95a00c96e79f2c3fb382cc7de6cc31b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/target_model_runner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/top1_proposer.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/top1_proposer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0832ee69b175bc3f3edf13e557e3f8737c490eab Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/top1_proposer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..266d53f2698e71892c336661f196e5c993db3aaa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/batch_expansion.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/batch_expansion.py new file mode 100644 index 0000000000000000000000000000000000000000..e08ed742a5225186880dc60dc86017bcdc334bd7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/batch_expansion.py @@ -0,0 +1,505 @@ +# SPDX-License-Identifier: Apache-2.0 + +from array import array +from itertools import chain, count +from typing import Iterator, List, Optional, Tuple + +import torch + +from vllm import SamplingParams +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE, + ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeScorer, SpeculativeScores) +from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len + +SeqId = int +TargetSeqId = int +TokenId = int + +DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams() + + +class BatchExpansionTop1Scorer(SpeculativeScorer): + """Implements a speculative scorer that uses batch expansion to get + probabilities of speculative tokens 
according to the scoring model. + + Batch expansion converts a list of sequences and multiple query positions + to a new batch of sequences, each with a single query position. This allows + for MQA-like scoring in speculative decoding without requiring an MQA + kernel. + + It is strictly less efficient than MQA scoring. + + It only supports scoring the top1 proposal tokens of the proposer, instead + of topk/tree. + """ + + @nvtx_range("BatchExpansionTop1Scorer.score_proposals") + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + """Score the proposed tokens via the scorer model. + + This converts each input sequence to a set of k+1 target sequences. The + target sequences have the unique continuations to be scored and a + unique sequence ID that is different from all input sequence ids. + + If a speculative sequence length would exceed the max model length, then + no speculation is produced for that sequence. + + Args: + execute_model_req: The execution request. + proposals: The speculative proposals to score. + Returns: + SpeculativeScores: The scores of each speculative token, along with + which sequences were ignored during scoring. + """ + + # TODO(cade) perform this on GPU to remove blocking call. + proposal_lens_list = proposals.proposal_lens.tolist() + proposal_token_ids_list = proposals.proposal_token_ids.tolist() + + # Filter the list to ignore invalid proposals. 
+ proposal_token_ids_list_without_skips = [ + proposals for proposals in proposal_token_ids_list + if VLLM_INVALID_TOKEN_ID not in proposals + ] + + (spec_indices, non_spec_indices, target_seq_group_metadata_list, + num_scoring_tokens) = self._expand_batch( + seq_group_metadata_list=execute_model_req.seq_group_metadata_list, + proposal_token_ids_list=proposal_token_ids_list_without_skips, + proposal_lens_list=proposal_lens_list, + ) + + target_sampler_output = self._scorer_worker.execute_model( + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list)) + assert len(target_sampler_output) == 1, "expected single-step output" + target_sampler_output = target_sampler_output[0] + + if not non_spec_indices: + # All sequence groups in batch have spec decoding enabled + return self._contract_batch_all_spec( + target_sampler_output=target_sampler_output, + proposals=proposals, + ) + else: + # Batch has a mix of spec decode enabled and disabled seq groups + return self._contract_batch( + execute_model_req.seq_group_metadata_list, + target_sampler_output=target_sampler_output, + proposals=proposals, + num_scoring_tokens=num_scoring_tokens, + non_spec_indices=non_spec_indices, + spec_indices=spec_indices, + k=execute_model_req.num_lookahead_slots, + ) + + def _expand_batch( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids_list: List[List[TokenId]], + proposal_lens_list: List[int], + ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]: + """Given the input sequences and potentially multiple corresponding + proposal tokens, create a new batch where each sequence has a single + query token. + """ + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. It can be + # done by supporting per-sequence proposal lens. 
+ (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \ + split_batch_by_proposal_len( + seq_group_metadata_list, proposal_lens_list) + + spec_expanded_seqs = self._create_scoring_model_input( + seq_group_metadata_list=spec_seqs, + proposal_token_ids=proposal_token_ids_list, + # NOTE: We determine the seq ids in the expanded batch using the + # full seq_group_metadata_list, instead of only spec_seqs. + target_seq_ids_iter=self._create_target_seq_id_iterator( + seq_ids=get_all_seq_ids(seq_group_metadata_list)), + ) + + num_scoring_tokens = len(spec_expanded_seqs) + # Batch speculative and non-speculative (e.g. chunked prefill) requests + # but make sure order is prefill|decode due to backend requirement. + target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs + + return (spec_indices, non_spec_indices, target_seq_group_metadata_list, + num_scoring_tokens) + + def _contract_non_speculative( + self, scores: SpeculativeScores, + seq_group_metadata_list: List[SequenceGroupMetadata], + non_spec_indices: List[int], non_spec_outputs: SpeculativeScores, + has_prompt_log: bool) -> SpeculativeScores: + """ + Augment input `scores` with non-speculative requests outputs. + This includes decode requests with speculation turned off, as well + as prefill requests when `enable_chunked_prefill` is set. + For the latter, prefills are further separated into terminal and + non-terminal chunks (from which no token is sampled). + """ + if not non_spec_indices: + return scores + + if has_prompt_log: + # When prompt_logprobs is enabled, prefills yield output token + # (and respective prob) in the last entry (prompt|out): + # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..]. + # With chunked prefill, non-terminal chunks have -1 on each + # position: they're still picked, but they're discarded later. 
+ seq_meta = seq_group_metadata_list + nospec_sizes = torch.tensor([ + seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1 + for i in non_spec_indices + ]) + nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1) + else: + # In this case only sampled tokens are returned, select all. + nospec_sampled_token_idxs = list( + range(len(non_spec_outputs.token_ids))) + + scores.token_ids[non_spec_indices, :1] = \ + non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1) + scores.probs[non_spec_indices, :1, :] = \ + non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1) + scores.logprobs[non_spec_indices, :1, :] = \ + non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1) + if scores.hidden_states is not None: + assert non_spec_outputs.hidden_states is not None + scores.hidden_states[non_spec_indices, :1, :] = \ + non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1) + return scores + + def _contract_batch( + self, + contracted_seq_group_metadata_list: List[SequenceGroupMetadata], + target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], + k: int) -> SpeculativeScores: + """Contract the expanded batch back into its original size. + This maps the scores of speculative tokens back to their original + sequences. + + contracted_bs is the original batch size, and the batch size that the + target_sampler_output will be contracted to. + """ + contracted_bs = len(contracted_seq_group_metadata_list) + (target_token_ids, target_probs, target_logprobs, target_hidden_states, + non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs, + non_spec_target_hidden_states) = self._split_scoring_output( + target_sampler_output, num_scoring_tokens) + + # Map distinct sequences used to score each token + # of shape [batch_size * k + 1] back to [batch_size, k + 1]. 
+ expanded_batch_size, k = proposals.proposal_token_ids.shape + + # The number of tokens in the expanded batch used for speculation is + # equal to the total expanded batch size minus the number of samples for + # non-speculative sequences, prefill chunks with no out tokens included + non_spec_expanded_bs = len(non_spec_indices) + spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs + + target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1) + target_probs = target_probs.reshape(*target_token_ids.shape, + self._vocab_size) + target_logprobs = target_logprobs.reshape(target_probs.shape) + + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + *target_token_ids.shape, target_hidden_states.shape[-1]) + + all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), + fill_value=-1) + all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size) + all_logprobs = target_logprobs.new_full(size=all_probs.shape, + fill_value=-float("inf")) + + if target_sampler_output.hidden_states is not None: + all_hidden_states = target_hidden_states.new_zeros( + size=(contracted_bs, k + 1, target_hidden_states.shape[-1])) + else: + all_hidden_states = None + + has_prompt_log = any((sg.sampling_params.prompt_logprobs + and sg.sampling_params.prompt_logprobs > 0) + for sg in contracted_seq_group_metadata_list) + # When prompt logprobs is enabled, lens of returned tensors go from + # n_sampled (requests with do_sample=True) to n_prompt+n_prefills. + # We adjust stride accordingly to get the generated tokens and + # their probs, but pass on prompt_logprobs as is. + prompt_logprobs = None + if (not self._scorer_worker.model_runner.disable_logprobs\ + and has_prompt_log): + prompt_logprobs = [ + o.prompt_logprobs for o in target_sampler_output.outputs + ] + elif not has_prompt_log: + # When prompt logprobs are not to be returned, + # we can ignore non-terminal chunks (no out token). 
+ non_spec_indices = [ + idx for idx in non_spec_indices + if contracted_seq_group_metadata_list[idx].do_sample + ] + + # "Contract" speculative. + if spec_indices: + all_tokens[spec_indices] = target_token_ids + all_probs[spec_indices] = target_probs + all_logprobs[spec_indices] = target_logprobs + if all_hidden_states is not None: + all_hidden_states[spec_indices] = target_hidden_states + + spec_scores = SpeculativeScores(probs=all_probs, + token_ids=all_tokens, + logprobs=all_logprobs, + hidden_states=all_hidden_states, + prompt_logprobs=prompt_logprobs) + + non_spec_outputs = SpeculativeScores( + probs=non_spec_target_probs, + token_ids=non_spec_target_token_ids, + logprobs=non_spec_target_logprobs, + hidden_states=non_spec_target_hidden_states) + # Contract remaining nonspec entries based on non_spec_indices, if any. + return self._contract_non_speculative( + spec_scores, contracted_seq_group_metadata_list, non_spec_indices, + non_spec_outputs, has_prompt_log) + + def _contract_batch_all_spec( + self, + target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + """Contract the expanded batch back into its original size. + This maps the scores of speculative tokens back to their original + sequences. + + It assumes all sequences in the batch were previously expanded. + """ + + # Map distinct sequences used to score each token + # of shape [batch_size * k + 1] back to [batch_size, k + 1]. 
+ contracted_bs, k = proposals.proposal_token_ids.shape + + # Reshape tensors to original batch size + target_token_ids = target_sampler_output.sampled_token_ids.reshape( + contracted_bs, k + 1) + target_probs = target_sampler_output.sampled_token_probs.reshape( + *target_token_ids.shape, self._vocab_size) + target_logprobs = target_sampler_output.logprobs.reshape( + target_probs.shape) + target_hidden_states = target_sampler_output.hidden_states + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + *target_token_ids.shape, target_hidden_states.shape[-1]) + + return SpeculativeScores(probs=target_probs, + token_ids=target_token_ids, + logprobs=target_logprobs, + hidden_states=target_hidden_states, + prompt_logprobs=None) + + def _create_scoring_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + target_seq_ids_iter: Iterator[TargetSeqId], + ) -> List[SequenceGroupMetadata]: + """Given the original input sequences and proposed tokens from the draft + model, create a list of target sequences that can be used for scoring. + + target_seq_ids_iter provides sequence ids for the expanded batch, + fulfilling the requirement that no seq id in the expanded batch is equal + to the seq id in the original batch. 
+ """ + + if not seq_group_metadata_list: + return [] + + target_seq_group_metadata = list( + chain.from_iterable( + self._create_target_seq_group_metadata( + seq_group_metadata, + proposal_token_ids, + i, + target_seq_ids_iter, + ) for i, seq_group_metadata in enumerate( + seq_group_metadata_list))) + + return target_seq_group_metadata + + def _create_target_seq_group_metadata( + self, + input_seq_group_metadata: SequenceGroupMetadata, + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + batch_index: int, + target_seq_ids_iter: Iterator[TargetSeqId], + ) -> List[SequenceGroupMetadata]: + """Given an input sequence group metadata and a list of draft tokens, + create a list of target SequenceGroupMetadata, one for each + token id that needs to be scored. + + Naive speculative decoding requires K target model scores, one for each + draft model token. However one can add a bonus token such that if each + token is accepted, then a final token may be sampled from the model. + This function creates K+1 target SequenceGroupMetadata to take + advantage of the bonus token. 
+ """ + assert len(input_seq_group_metadata.seq_data) == 1, ( + "Beam search " + "not supported in speculative decoding") + input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys())) + + token_ids_to_score = self._get_token_ids_to_score( + proposal_token_ids[batch_index]) + + sampling_params = input_seq_group_metadata.sampling_params + target_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for i, token_ids in enumerate(token_ids_to_score): + target_seq_group_metadata_list.append( + self._create_single_target_seq_group_metadata( + input_seq_group_metadata, + input_seq_id, + next(target_seq_ids_iter), + token_ids, + sampling_params=sampling_params, + )) + + return target_seq_group_metadata_list + + @staticmethod + def _create_single_target_seq_group_metadata( + seq_group_metadata: SequenceGroupMetadata, + seq_id: SeqId, + target_seq_id: TargetSeqId, + token_ids: List[TokenId], + sampling_params: SamplingParams, + ) -> SequenceGroupMetadata: + """Create a single target SequenceGroupMetadata. + + Args: + seq_group_metadata: The metadata for the input sequence. + seq_id: The input sequence ID. + target_seq_id: The corresponding target sequence ID. + token_ids: The list of token ids that are to be appended to the + input sequence. + """ + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_token_ids = seq_data.prompt_token_ids_array + new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] + mrope_position_delta = seq_data.mrope_position_delta + + new_seq_data_dict = { + target_seq_id: + SequenceData( + prompt_token_ids, + _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, + new_output_token_ids), + ), + } + # This is a hack. Technically, spec decoding should compute + # num_lookahead slots at one shot, but instead, it expands the batch + # and evaluate one by one right now. context_len is seq_len - 1 because + # the kv cache is filled by a previous batch in the batch expansion. 
+ for data in new_seq_data_dict.values(): + data.update_num_computed_tokens(data.get_len() - 1) + data.mrope_position_delta = mrope_position_delta + + return SequenceGroupMetadata( + request_id=seq_group_metadata.request_id, + is_prompt=seq_group_metadata.is_prompt, + seq_data=new_seq_data_dict, + sampling_params=sampling_params, + block_tables={ + target_seq_id: seq_group_metadata.block_tables[seq_id], + }, + lora_request=None, + token_chunk_size=1, + ) + + @staticmethod + def _split_scoring_output( + sampler_output: SamplerOutput, num_scoring_tokens: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: + """Split the target model output into speculative and non-speculative + output. + """ + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. It can be + # done by supporting per-sequence proposal lens. + # + # First samples are non-speculative, latter samples are from speculative + # scoring (prefill|decode order). 
+ split_sizes = (sampler_output.sampled_token_ids.numel() - + num_scoring_tokens, num_scoring_tokens) + (non_spec_probs, + spec_probs) = sampler_output.sampled_token_probs.split(split_sizes) + (non_spec_sampled_tokens, spec_sampled_tokens + ) = sampler_output.sampled_token_ids.flatten().split(split_sizes) + (non_spec_logprobs, + spec_logprobs) = sampler_output.logprobs.split(split_sizes) + + if sampler_output.hidden_states is not None: + (non_spec_hidden_states, spec_hidden_states + ) = sampler_output.hidden_states.split(split_sizes) + else: + non_spec_hidden_states, spec_hidden_states = None, None + + return (spec_sampled_tokens, spec_probs, spec_logprobs, + spec_hidden_states, non_spec_sampled_tokens, non_spec_probs, + non_spec_logprobs, non_spec_hidden_states) + + @staticmethod + def _create_target_seq_id_iterator( + seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: + """Create an iterator for creating target sequence ids. + Target sequence ids are distinct from sequence ids because we create a + distinct target sequence id for each proposal token to be scored. + + This implementation increments a counter starting at 1 + max of all + provided input sequence ids. + """ + return count(start=max(seq_ids) + 1) + + @staticmethod + def _get_token_ids_to_score( + full_spec_token_ids: List[TokenId] # shape: [k] + ) -> List[List[TokenId]]: + """Given an int tensor of proposal token ids, return a list of + token ids that should be scored. + + Returns k+1 output lists. The additional one is used for generating the + bonus token. 
+ + Example: + Input: [0, 1, 2, 3] (k=4) + Output: (k+1 lists) + [] + [0] + [0, 1] + [0, 1, 2] + [0, 1, 2, 3] + """ + empty_token_ids: List[TokenId] = [] + + token_ids_to_score = [empty_token_ids] + token_ids_to_score.extend(full_spec_token_ids[:i + 1] + for i in range(len(full_spec_token_ids))) + return token_ids_to_score diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/interfaces.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/interfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..dd085ad77638462535cf2c3d9def11e8647de965 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/interfaces.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional, Set, Union + +import torch + +from vllm.sequence import ExecuteModelRequest, PromptLogprobs +from vllm.worker.worker_base import WorkerBase + + +@dataclass +class SpeculativeProposals: + """Datastructure used to represent proposal tokens from some proposer. It + also tracks how many speculative tokens each sequence has. + """ + + # Speculative proposal tokens. + proposal_token_ids: torch.Tensor + + # Probabilities of the proposal tokens according to the proposer. + proposal_probs: torch.Tensor + + # The valid length of each proposal; can be zero. + proposal_lens: torch.Tensor + + # A flag to mark that there's no available proposals + no_proposals: bool = False + + def __repr__(self): + return (f"SpeculativeProposals(" + f"proposal_token_ids={self.proposal_token_ids}, " + f"proposal_probs={self.proposal_probs.shape}, " + f"proposal_lens={self.proposal_lens})") + + +@dataclass +class SpeculativeScores: + """Datastructure used to represent the scores of speculative tokens + according to the scoring model. + """ + + # Probabilities of the speculative tokens according to the scoring model. 
+ probs: torch.Tensor + + # Log-probabilities of the speculative tokens according to the scoring + # model. These values can be used to generate Logprob objects that are + # returned to the user. + logprobs: torch.Tensor + + # Token ids sampled from the scoring model. Used for speculative bonus + # tokens and also non-speculative normal decoding. + token_ids: torch.Tensor + + # Optional last hidden states from the scoring model. + hidden_states: Optional[torch.Tensor] = None + + # Scoring model may also return logprobs for prompt tokens + # for each request, when chunked prefill is enabled. + prompt_logprobs: Optional[List[PromptLogprobs]] = None + + def __repr__(self): + return (f"SpeculativeScores(" + f"probs={self.probs.shape}, " + f"token_ids={self.token_ids.shape})") + + +class SpeculativeProposer(ABC): + + @abstractmethod + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + # If set, this contains all sequence IDs that were assigned + # bonus tokens in their last forward pass. 
+ seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + raise NotImplementedError + + +class SpeculativeScorer(ABC): + + def __init__(self, scorer_worker: WorkerBase, + device: Union[torch.device, str], vocab_size: int): + self._scorer_worker = scorer_worker + if isinstance(device, torch.device): + device = device.type + self._device = device + self._vocab_size = vocab_size + + @abstractmethod + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/medusa_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/medusa_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..0b62a988e8b267aed2ecad09cba59eec38808b74 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/medusa_worker.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 + +import weakref +from typing import List, Optional, Set, Tuple + +import torch + +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase +from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.worker_base import DelegateWorkerBase + + +class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase): + """Worker for Medusa. + """ + + def __init__(self, *args, **kwargs): + DelegateWorkerBase.__init__(self, *args, **kwargs) + # Lazy initialization list. 
+ self._proposer: Top1Proposer + + def init_device(self): + self.worker.init_device() + + self._proposer = Top1Proposer( + weakref.proxy(self), # type: ignore[arg-type] + self.device, + self.vocab_size, + max_proposal_len=self.max_model_len, + ) + + def set_include_gpu_probs_tensor(self): + pass + + def set_should_modify_greedy_probs_inplace(self): + pass + + @torch.inference_mode() + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + # Unused parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass to generate sample_len future tokens. + Returns the list of sampler output, one per layer, along with indicator + of whether torch tensor in sampler output need to be transposed in + latter sampler_output_to_torch logic. + + For medusa worker, this indicator shall be False. + """ + self._raise_if_unsupported(execute_model_req) + + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + seq_lens, query_lens = self._prepare_input_tensors( + seq_group_metadata_list) + + generators = self.model_runner.get_generators( + execute_model_req.finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.model_runner.pin_memory, generators) + + model_outputs = self.model_runner.model.generate_proposals( + previous_hidden_states=execute_model_req.previous_hidden_states. 
+ hidden_states, + sampling_metadata=sampling_metadata) + + return model_outputs, False + + def _prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[List[int], List[int]]: + if not seq_group_metadata_list: + return [], [] + + seq_lens: List[int] = [] + query_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + is_prompt = seq_group_metadata.is_prompt + + for seq_data in seq_group_metadata.seq_data.values(): + seq_data_len = seq_data.get_len() + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + seq_len = min( + seq_data_len, + context_len + seq_group_metadata.token_chunk_size) + seq_lens.append(seq_len) + query_lens.append(seq_len - context_len) + else: + seq_lens.append(seq_data_len) + query_lens.append(1) + + return seq_lens, query_lens + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """MedusaWorker does not yet implement support for cache swap + operations or beam search. 
+ """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "MedusaWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "MedusaWorker does not support beam search.") diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/mlp_speculator_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/mlp_speculator_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..bdaf31895e25dee2e8ec47ff6f0a41f2c208f623 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/mlp_speculator_worker.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Set, Tuple + +import torch + +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase + + +class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker): + """Worker for MLPSpeculator models. + + Not currently compatible with LoRA or chunked prefill. + """ + + @torch.inference_mode() + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass to generate sample_len future tokens. 
+ Returns the list of sampler output, one per layer, along with indicator + of whether torch tensor in sampler output need to be transposed in + latter sampler_output_to_torch logic. + + For mlp spec worker, this indicator shall be True. + """ + self._raise_if_unsupported(execute_model_req) + + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + (input_tokens, seq_lens, + query_lens) = self._prepare_input_tensors(seq_group_metadata_list) + + generators = self.model_runner.get_generators( + execute_model_req.finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.model_runner.pin_memory, generators) + + model_outputs = self.model_runner.model.generate_proposals( + input_ids=input_tokens, + previous_hidden_states=execute_model_req.previous_hidden_states. + hidden_states, + num_predict_tokens=sample_len, + sampling_metadata=sampling_metadata) + + assert len(model_outputs) == sample_len + + return model_outputs, True + + def _prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, List[int], List[int]]: + if not seq_group_metadata_list: + return torch.empty(0, device=self.device), [], [] + + input_tokens: List[int] = [] + seq_lens: List[int] = [] + query_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + is_prompt = seq_group_metadata.is_prompt + + for seq_data in seq_group_metadata.seq_data.values(): + seq_data_len = seq_data.get_len() + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + seq_len = min( + seq_data_len, + context_len + seq_group_metadata.token_chunk_size) + tokens = seq_data.get_token_ids()[context_len:seq_len] + seq_lens.append(seq_len) + input_tokens.extend(tokens) + query_lens.append(seq_len - context_len) + else: + seq_lens.append(seq_data_len) + input_tokens.append(seq_data.get_last_token_id()) + query_lens.append(1) + + 
input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + return input_tokens_tensor, seq_lens, query_lens diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/mqa_scorer.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/mqa_scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..6275c460ecefa0aaca2fe2d6be7e3dc90ccd3aa0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/mqa_scorer.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.sequence import (ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeScorer, SpeculativeScores) + +SeqId = int +TargetSeqId = int + + +class MQAScorer(SpeculativeScorer): + + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + target_seq_group_metadata_list = [] + target_seq_id_start = max( + get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1 + all_proposal_tokens = proposals.proposal_token_ids.tolist() + all_proposal_lengths = proposals.proposal_lens.tolist() + for i, seq_group_metadata in enumerate( + execute_model_req.seq_group_metadata_list): + if all_proposal_lengths[i] == 0: + # Keep prompt seqs untouched (keep computed_tokens for chunks). 
+ target_seq_group_metadata_list.append(seq_group_metadata) + continue + + seq_data_dict = seq_group_metadata.seq_data + assert len(seq_data_dict) == 1 + seq_id = next(iter(seq_data_dict.keys())) + + seq_data: SequenceData = seq_data_dict[seq_id] + prompt_token_ids = seq_data.get_prompt_token_ids() + output_token_ids = seq_data.get_output_token_ids() + proposal_token_ids = all_proposal_tokens[ + i][:all_proposal_lengths[i]] + new_output_token_ids = [*output_token_ids, *proposal_token_ids] + + target_seq_id = target_seq_id_start + i + new_seq_data = SequenceData.from_seqs( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ) + new_seq_data.update_num_computed_tokens( + len(prompt_token_ids) + len(output_token_ids) - 1) + + # Ensure that the new decode sequence has at least one token. + assert len(output_token_ids) >= 1 + new_seq_data_dict = {target_seq_id: new_seq_data} + + new_seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group_metadata.request_id, + is_prompt=seq_group_metadata.is_prompt, + seq_data=new_seq_data_dict, + sampling_params=seq_group_metadata.sampling_params, + block_tables={ + target_seq_id: seq_group_metadata.block_tables[seq_id], + }, + lora_request=None, + ) + target_seq_group_metadata_list.append(new_seq_group_metadata) + + target_sampler_output = self._scorer_worker.execute_model( + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list)) + + target_sampler_output = target_sampler_output[0] + + k = execute_model_req.num_lookahead_slots + bs = len(execute_model_req.seq_group_metadata_list) + target_token_ids = target_sampler_output.sampled_token_ids + target_probs = target_sampler_output.sampled_token_probs + target_logprobs = target_sampler_output.logprobs + prompt_logprobs = None + + # If all requests have the same number of query tokens, we can avoid + # the for loop to build output for better performance. 
+ if min(all_proposal_lengths) == k: + # Regular decodes only. + assert all(not sg.is_prompt + for sg in target_seq_group_metadata_list + if sg.is_prompt) + bs, _ = proposals.proposal_token_ids.shape + all_tokens = target_token_ids.reshape(bs, k + 1) + all_probs = target_probs.reshape(bs, k + 1, self._vocab_size) + all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size) + else: + # We either have decodes with different lens or prefill+decodes. + all_tokens = target_token_ids.new_full(size=(bs, k + 1), + fill_value=-1) + all_probs = target_probs.new_zeros(*all_tokens.shape, + self._vocab_size) + all_logprobs = target_logprobs.new_full(size=all_probs.shape, + fill_value=-float("inf")) + target_token_ids = target_token_ids.flatten() + + # When prompt logprobs is enabled, lens of returned tensors go from + # n_sampled (requests with do_sample=True) to n_prompt+n_prefills. + # We adjust stride accordingly to get the generated tokens and + # their probs, but pass on prompt_logprobs as is, since it may be + # that n_prompts >> K. + has_prompt_log = any((sg.sampling_params.prompt_logprobs + and sg.sampling_params.prompt_logprobs > 0) + for sg in target_seq_group_metadata_list) + # TODO (NickLucche) we should surface `disable_logprobs` as to not + # break abstraction to get its value. + if (not self._scorer_worker.model_runner.disable_logprobs\ + and has_prompt_log): + prompt_logprobs = [ + o.prompt_logprobs for o in target_sampler_output.outputs + ] + + # Split loop into prefill|decode for readability. + start_loc, i = 0, 0 + while i < len(target_seq_group_metadata_list + ) and target_seq_group_metadata_list[i].is_prompt: + seq_meta = target_seq_group_metadata_list[i] + end_loc = start_loc + if has_prompt_log: + end_loc += seq_meta.token_chunk_size + elif seq_meta.do_sample: + end_loc += 1 + + # Skip chunks with no output tokens. + if seq_meta.do_sample: + # Get sampled token (last position in chunk) and its prob. 
+ all_tokens[i, 0] = target_token_ids[end_loc - 1] + all_probs[i, 0] = target_probs[end_loc - 1] + all_logprobs[i, 0] = target_logprobs[end_loc - 1] + + i += 1 + start_loc = end_loc + # Decodes. + while i < len(target_seq_group_metadata_list): + proposed_len, seq_meta = all_proposal_lengths[ + i], target_seq_group_metadata_list[i] + output_len = proposed_len + 1 + end_loc = start_loc + output_len + all_tokens[ + i, :output_len] = target_token_ids[start_loc:end_loc] + all_probs[i, :output_len] = target_probs[start_loc:end_loc] + all_logprobs[ + i, :output_len] = target_logprobs[start_loc:end_loc] + start_loc = end_loc + i += 1 + + hidden_states = None + if target_sampler_output.hidden_states is not None: + hidden_states = target_sampler_output.hidden_states.reshape( + bs, (k + 1), -1) + + return SpeculativeScores(probs=all_probs, + token_ids=all_tokens, + logprobs=all_logprobs, + hidden_states=hidden_states, + prompt_logprobs=prompt_logprobs) diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/multi_step_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/multi_step_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..5474917a6fab7f436cd2e8905e9777de0abc0727 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/multi_step_worker.py @@ -0,0 +1,388 @@ +# SPDX-License-Identifier: Apache-2.0 + +import copy +import weakref +from typing import Dict, List, Set, Tuple + +import torch + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform +from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, + SequenceGroupMetadata) + +if current_platform.is_cuda_alike(): + from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner + +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeProposer) +from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase +from vllm.spec_decode.top1_proposer import 
Top1Proposer +from vllm.worker.worker_base import DelegateWorkerBase + + +class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase): + """The MultiStepWorker is equivalent to a Worker except that it allows + multiple forward passes in a single call, assuming the scheduler has + allocated enough space to store the additional KV. This reduces overhead + by invoking the scheduler less. + + The MultiStepWorker does not support cache swap operations, or beam search. + Cache swap operations do not require large modifications. On the other hand, + beam search requires memory allocations during sequence forks and thus + requires more thought for MultiStepWorker support. + """ + + def __init__(self, *args, **kwargs): + DelegateWorkerBase.__init__(self, *args, **kwargs) + # Lazy initialization list. + self._proposer: SpeculativeProposer + + def init_device(self) -> None: + self.worker.init_device() + self._proposer = Top1Proposer( + weakref.proxy(self), # type: ignore[arg-type] + self.device, + self.vocab_size, + max_proposal_len=self.max_model_len, + ) + + def set_include_gpu_probs_tensor(self) -> None: + # Need include_gpu_probs_tensor for MultiStepWorker + self.model_runner.model.sampler.include_gpu_probs_tensor = True + + def set_should_modify_greedy_probs_inplace(self) -> None: + self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( + True) + + @torch.inference_mode() + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass sample_len times. Returns the list of + sampler output, one per model forward pass, along with indicator of + whether torch tensor in sampler output need to be transposed in latter + sampler_output_to_torch logic. + + For multi step worker, this indicator shall be True. 
+ """ + self._raise_if_unsupported(execute_model_req) + # Expand the batch for sequences with a bonus token. + # Perform a forward pass on the expanded batch and filter the + # response to retain only the original sequences' responses. + expanded_request, indices_of_seq_with_bonus_tokens =\ + self._expand_execute_model_request( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + # Run model sample_len times. + model_outputs: List[SamplerOutput] = [] + if current_platform.is_cuda_alike() and isinstance( + self.model_runner, TP1DraftModelRunner + ) and self.model_runner.supports_gpu_multi_step(expanded_request): + # Here we run the draft_model_runner with multi-step prepare + # on the GPU directly + expanded_request.num_steps = sample_len + self.model_runner.set_indices_of_seq_with_bonus_tokens( + indices_of_seq_with_bonus_tokens) + model_outputs = self.execute_model( + execute_model_req=expanded_request) + else: + # Here we run multi-step directly, with every step prepared + # on the CPU. + # TODO: Remove this branch once DraftModelRunner supports TP>1 + # and other restrictions that are part of DraftModelRunner's + # supports_gpu_multi_step(..) 
+ for _ in range(sample_len): + model_output: List[SamplerOutput] = self.worker.execute_model( + execute_model_req=expanded_request) + assert (len(model_output) == 1 + ), "composing multistep workers not supported" + model_output = model_output[0] + + self._append_new_tokens( + model_output, expanded_request.seq_group_metadata_list, + indices_of_seq_with_bonus_tokens) + model_outputs.append(model_output) + + # move indices to device to avoid stream sync + indices_of_seq_with_bonus_tokens = torch.tensor( + indices_of_seq_with_bonus_tokens, device=self.device) + filtered_model_outputs = self._filter_model_output( + model_outputs, indices_of_seq_with_bonus_tokens) + return filtered_model_outputs, True + + @staticmethod + def _expand_execute_model_request( + execute_model_req: ExecuteModelRequest, + seq_with_bonus_token_in_last_step: set, + ) -> Tuple[ExecuteModelRequest, List[int]]: + """ + Expands the execute model request based on sequences with bonus + tokens. + + For each sequence with a bonus token, this method creates a new + sequence without the bonus token and adds it to the execute model + request. The original sequence groups are also retained. The indices + of the original sequence groups are returned for further processing. + + Args: + execute_model_req (ExecuteModelRequest): The original execute + model request. + seq_with_bonus_token_in_last_step (set): Set of sequence IDs that + contain bonus tokens. + + Returns: + Tuple[ExecuteModelRequest, List[int]]: The updated execute model + request with expanded sequences and a list of indices corresponding + to the original sequence groups. 
+ """ + updated_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + updated_execute_model_req = execute_model_req.clone( + updated_seq_group_metadata_list) + indices_of_original_sequence_groups = [] + for seq_group in execute_model_req.seq_group_metadata_list: + seq_group_has_bonus_tokens = False + for seq_id, _ in seq_group.seq_data.items(): + # Identify sequences with bonus tokens in the sequence group. + if seq_id in seq_with_bonus_token_in_last_step: + seq_group_has_bonus_tokens = True + break + if seq_group_has_bonus_tokens: + #Create new sequences without the last bonus token. These new + # sequence have the same sequence id as the original sequence. + # We create a new sequence group and add them there. + updated_seq_group_without_bonus_token = \ + MultiStepWorker._copy_seq_metadata_excluding_last_token( + seq_group, seq_with_bonus_token_in_last_step) + updated_seq_group_metadata_list.append( + updated_seq_group_without_bonus_token) + # Add the original sequence group. + updated_seq_group_metadata_list.append( + MultiStepWorker._shallow_copy_seq_group_metadata(seq_group)) + # Record the index of the original sequence group. + indices_of_original_sequence_groups.append( + len(updated_seq_group_metadata_list) - 1) + + updated_execute_model_req.seq_group_metadata_list =\ + updated_seq_group_metadata_list + + if isinstance(updated_execute_model_req.previous_hidden_states, + HiddenStates): + updated_execute_model_req.previous_hidden_states\ + .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step) + + return updated_execute_model_req, indices_of_original_sequence_groups + + @staticmethod + def _filter_model_output( + expanded_batch_outputs: List[SamplerOutput], + output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]: + """ + Filters the model output to include only the specified sequence + outputs. 
This method contracts the expanded batch output from the + model to retain the outputs of only those sequences indicated by the + provided indices. + + Args: + expanded_batch_output (List[SamplerOutput]): The expanded output + batch from the model. + output_indices_to_retain (torch.Tensor): Indices of the model + outputs to retain. + + Returns: + List[SamplerOutput]: A list containing the filtered model + outputs for the specified indices. + """ + return [ + SamplerOutput( + outputs=[ + expanded_batch_output.outputs[i] + for i in output_indices_to_retain + ] if len(expanded_batch_output.outputs) > 0 else [], + sampled_token_probs=( + expanded_batch_output. + sampled_token_probs[output_indices_to_retain] + if expanded_batch_output.sampled_token_probs is not None + else None), + logprobs=( + expanded_batch_output.logprobs[output_indices_to_retain] + if expanded_batch_output.logprobs is not None else None), + sampled_token_ids=(expanded_batch_output. + sampled_token_ids[output_indices_to_retain] + if expanded_batch_output.sampled_token_ids + is not None else None)) + for expanded_batch_output in expanded_batch_outputs + ] + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: set, + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + @staticmethod + def _append_new_tokens( + model_output: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + indices_of_seq_with_bonus_tokens: List[int]) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done outside of the worker, but it is + required if the worker is to perform multiple forward passes. 
+ """ + count = 0 + for index, (seq_group_metadata, sequence_group_outputs) in enumerate( + zip(seq_group_metadata_list, model_output)): + seq_group_metadata.is_prompt = False + + for seq_output in sequence_group_outputs.samples: + # NOTE: Beam search is not supported, so we can assume that + # parent_seq_id == seq_id. + seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] + + token_id = seq_output.output_token + token_logprob = seq_output.logprobs[token_id] + # Determine the actual token ID to be generated, + # considering bonus tokens + if index != indices_of_seq_with_bonus_tokens[count]: + bonus_seq_metadata = seq_group_metadata_list[ + indices_of_seq_with_bonus_tokens[count]] + _, bonus_token_seq_data = next( + iter(bonus_seq_metadata.seq_data.items())) + token_id = bonus_token_seq_data.output_token_ids[-1] + else: + count += 1 + + seq.append_token_id(token_id, token_logprob.logprob) + seq.update_num_computed_tokens(1) + + @staticmethod + def _shallow_copy_seq_group_metadata( + seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata: + """Copy input data structures to remove side-effects when input data + structures are shared with other modules. + + Helpful when the vLLM scheduler runs in the same process as the worker. + The alternative is deep-copying (or other form of deep copy); this has + performance downsides. + """ + # Shallow-copy the SequenceGroupMetadata. This allows us to + # append tokens and change is_prompt without external side-effects. + # We must shallow-copy seq_group_metadata as is_prompt could change. 
+ new_seq_group_metadata = copy.copy(seq_group_metadata) + + # We must shallow-copy seq_data as we will append token ids + new_seq_data: Dict[int, SequenceData] = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + new_seq_data[seq_id] = copy.copy(old_seq_data) + new_seq_data[seq_id].output_token_ids =\ + old_seq_data.output_token_ids[:] + + new_seq_group_metadata.seq_data = new_seq_data + return new_seq_group_metadata + + @staticmethod + def _copy_seq_metadata_excluding_last_token( + seq_group_metadata: SequenceGroupMetadata, + seq_ids_to_copy: Set[int], + ) -> SequenceGroupMetadata: + """ + Creates a shallow copy of the given SequenceGroupMetadata, retaining + only the sequence IDs specified in seq_ids_to_copy. For each of these + sequence IDs, all output_token_ids except the last one are copied. + Sequence IDs not in seq_ids_to_copy are excluded from the copy. + + Parameters: + seq_group_metadata (SequenceGroupMetadata): The original sequence + group metadata. + seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the + copy. + + Returns: + SequenceGroupMetadata: A shallow copy of the sequence group metadata + with the specified modifications. + """ + # Shallow-copy the SequenceGroupMetadata. + new_seq_group_metadata = copy.copy(seq_group_metadata) + # Shallow-copy seq_data and modify the output_token_ids. + new_seq_data: Dict[int, SequenceData] = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + if (seq_id in seq_ids_to_copy): + new_seq_data[seq_id] = copy.copy(old_seq_data) + # Copy all the output token ids except the last. + # Also reduce num_computed_tokens by 1 since we are not + # including the last output token. + # NOTE: num_computed_tokens is not directly used by the + # speculative decoding workers, as it is only relevant for + # chunked prefill, which is disabled for speculative decoding. + # However, to maintain consistency in num_computed_tokens, + # we update it here. 
+ new_seq_data[seq_id].output_token_ids =\ + old_seq_data.output_token_ids[:-1] + new_seq_data[seq_id].update_num_computed_tokens(-1) + new_seq_group_metadata.seq_data = new_seq_data + return new_seq_group_metadata + + def _assert_enough_kv_space( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + num_steps: int) -> None: + """Assert there are enough physical blocks per sequence to store the + current KV plus additional KV from num_steps tokens. + """ + assert self.model_runner.block_size is not None + for seq_group_metadata in seq_group_metadata_list: + # Only one seq_id is guaranteed because there is no beam search. + seq_id = list(seq_group_metadata.seq_data.keys())[0] + seq = seq_group_metadata.seq_data[seq_id] + + # After num_steps, the seq len will be the current seq len + # plus one token per step. + final_seq_len = seq.get_len() + num_steps + + # We will have final_seq_len - 1 KV because vLLM saves KV for a + # token in the iteration after the token was generated. + required_num_kv_slots = final_seq_len - 1 + + # The allocated number of kv slots is the number of allocated blocks + # times the number of slots of block. + number_physical_blocks = len( + seq_group_metadata.block_tables[seq_id]) + allocated_kv_slots = (number_physical_blocks * + self.model_runner.block_size) + + if required_num_kv_slots > allocated_kv_slots: + request_id = seq_group_metadata.request_id + raise ValueError( + "The worker attempted to run " + f"{num_steps} times but found insufficient KV space for " + f"{request_id=} {seq_id=}. ({allocated_kv_slots=} " + f"{required_num_kv_slots=}).") + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """MultiStepWorker does not yet implement support for cache swap + operations or beam search. 
+ """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "MultiStepWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "MultiStepWorker does not support beam search.") diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/ngram_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/ngram_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..86390c99c2fbced6163eac2374cac7afe681b602 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/ngram_worker.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 + +import weakref +from typing import List, Optional, Set, Tuple + +import torch +import torch.nn as nn + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase +from vllm.spec_decode.top1_proposer import Top1Proposer + + +class _DummyModel(nn.Module): + pass + + +class NGramWorker(NonLLMProposerWorkerBase): + """NGramWorker provides a light drafter without need for model. + + Current NGramWorker only implements prompt lookup decoding, + and in future we may also do RAG type drafter and other scenarios + which don't rely on LLM model to give proposals. + """ + + def __init__(self, *args, **kwargs): + # Get local_rank/vocab_size from kwargs attribute + self.local_rank = kwargs["local_rank"] + self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size() + self.device_type = kwargs.get("device_type", "cuda") + + # Lazy initialization list. 
+ self._proposer: Top1Proposer + + def set_ngram_window_size(self, ngram_prompt_lookup_min: int, + ngram_prompt_lookup_max: int): + # Search valid candidate window between + # ngram_prompt_lookup_min/ngram_prompt_lookup_max + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min + + def init_device(self): + self.device = torch.device(f"{self.device_type}:{self.local_rank}") + + # Current NGramWorker only supports Top1Proposer + self._proposer = Top1Proposer( + weakref.proxy(self), # type: ignore[arg-type] + device=self.device, + vocab_size=self.vocab_size, + ) + + def load_model(self) -> None: + pass # Dummy + + def get_model(self) -> nn.Module: + return _DummyModel() + + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + # Unused parameter. NGramWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]: + """NGram match algo to pick proposal candidate. Returns the list of + sampler output, one per SequenceGroupMetadata. + + For ngram worker, we already done needed transposed internal, so the + indicator pass to sampler_output_to_torch shall be False. + """ + self._raise_if_unsupported(execute_model_req) + + has_spec_out = False + token_id_list: List[Optional[torch.Tensor]] = [] + token_prob_list: List[Optional[torch.Tensor]] = [] + for idx, seq_group_metadata in enumerate( + execute_model_req.seq_group_metadata_list): + seq_data = next(iter(seq_group_metadata.seq_data.values())) + + seq_len = seq_data.get_len() + # When seq_len is less than 3072 (3K), we use CPU to perform + # the ngram match. Otherwise, we use the device specified in + # the model config (normally GPU). 3072 is a rough threshold + # based on profiling on H100, and it can be adjusted based + # on the actual performance on different hardware. 
+ cur_device = "cpu" if seq_len < 3072 else self.device + input_ids = torch.as_tensor(seq_data.get_token_ids(), + dtype=torch.long, + device=cur_device) + input_length = seq_data.get_len() + + for ngram_size in range( + min(self.ngram_prompt_lookup_max, input_length - 1), + self.ngram_prompt_lookup_min - 1, + -1, + ): + ngram_tensor = input_ids[-ngram_size:] + if ngram_size == 1: + # Do not match itself and do not use unfold and all + matches = (input_ids[:-1] == ngram_tensor) + else: + windows = input_ids.unfold(dimension=0, + size=ngram_size, + step=1) + # Do not match itself + matches = (windows[:-1] == ngram_tensor).all(dim=-1) + + # first_match includes "values" (bool), indicating whether + # the match is found, and "indices", indicating the index + # of the first match. + first_match = matches.max(dim=-1) + if first_match.values.item(): + proposal_start_idx = first_match.indices.add_(ngram_size) + spec_indices = ( + proposal_start_idx).repeat(sample_len) + torch.arange( + sample_len, device=cur_device) + spec_indices.clamp_(max=input_ids.shape[-1] - 1) + res = input_ids.gather(dim=-1, + index=spec_indices).to(self.device) + token_id_list.append(res) + token_prob_list.append( + torch.nn.functional.one_hot( + res, + num_classes=self.vocab_size).to(torch.float32)) + has_spec_out = True + break + else: + token_id_list.append(None) + token_prob_list.append(None) + + if not has_spec_out: + return None, False + + outputs: List[Optional[SamplerOutput]] = [] + for idx in range(len(execute_model_req.seq_group_metadata_list)): + if token_id_list[idx] is None: + outputs.append(None) + else: + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_prob_list[idx], + logprobs=torch.zeros((sample_len, self.vocab_size), + dtype=torch.float32, + device=self.device), + sampled_token_ids=token_id_list[idx], + )) + + return outputs, False + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + # Unused parameter. 
NGramWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """NGramWorker does not yet implement support for cache swap + operations or beam search. + """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "NGramWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "NGramWorker does not support beam search.") diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/proposer_worker_base.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/proposer_worker_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2bebf80fadae5e3e637053f95740340bd6a98f7f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/proposer_worker_base.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import List, Optional, Set, Tuple + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.interfaces import SpeculativeProposer +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + + +class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer): + """Interface for proposer workers""" + + @abstractmethod + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + 
sample_len: int,
        # A set containing all sequence IDs that were assigned bonus tokens
        # in their last forward pass. This set is used to backfill the KV cache
        # with the key-value pairs of the penultimate token in the sequences.
        # This parameter is only used by the MultiStepWorker, which relies on
        # the KV cache for token generation. It is not used by workers that
        # do not utilize the KV cache.
        seq_ids_with_bonus_token_in_last_step: Set[int]
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        raise NotImplementedError

    def set_include_gpu_probs_tensor(self) -> None:
        """Implementation optional"""
        pass

    def set_should_modify_greedy_probs_inplace(self) -> None:
        """Implementation optional"""
        pass


class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
    """Proposer worker which does not use a model with kvcache"""

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """get_spec_proposals is used to get the proposals"""
        # Intentionally a no-op: a non-LLM proposer produces its speculations
        # via get_spec_proposals, not via a model forward pass.
        return []

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """This is never called on the proposer, only the target model"""
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        # Nothing to allocate: this proposer holds no KV cache (see class
        # docstring).
        pass

    def get_cache_block_size_bytes(self) -> int:
        # No KV cache, so a cache block occupies zero bytes.
        return 0
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/smaller_tp_proposer_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/smaller_tp_proposer_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1466ba5db756d59f9ea4e709d27a343fd0943a7
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/smaller_tp_proposer_worker.py
@@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Set, Tuple

import torch
import torch.nn as nn

from vllm.distributed.parallel_state import (get_tp_group,
                                             init_model_parallel_group,
                                             patch_tensor_parallel_group)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase

logger = init_logger(__name__)


class _DummyModel(nn.Module):
    # Placeholder module returned by get_model() on dummy (non-participating)
    # ranks, where no real draft model is loaded.
    pass


class SmallerTpProposerWorker(ProposerWorkerBase):
    """Class which allows a speculative draft model to run with smaller tensor
    parallel degree than target model.
    This reduces the communication overhead of small draft models.

    To implement this feature, this class differs behavior based on is_dummy
    flag, where dummy means worker that does not participate draft generation.
    Participating workers use a smaller tp group by patching vLLM's tensor
    parallel group temporarily during forward passes of draft models.
    """

    @classmethod
    def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int,
                          target_tensor_parallel_size: int):
        """Wrap the worker in a SmallerTpProposerWorker if necessary.

        Returns the worker unchanged when draft and target TP sizes match;
        otherwise wraps it so that only the first draft_tensor_parallel_size
        ranks participate in draft generation.
        """
        if draft_tensor_parallel_size == target_tensor_parallel_size:
            return worker

        # gpu ranks that will generate draft tokens together
        draft_ranks = list(range(draft_tensor_parallel_size))

        logger.info("Wrapping {%s} in {%s}", type(worker), cls)
        return cls(worker, draft_ranks)

    def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]):
        """Create a SmallerTpProposerWorker.

        Args:
            worker (MultiStepWorker): an actual worker wrapped with this class
            draft_ranks (List[int]): if this value is given, only the GPU ranks
                written in this value participate in draft generation
        """
        self._worker = worker
        self._draft_ranks = draft_ranks

        # init during init_device
        self._is_dummy = False
        # Smaller TP process group; populated in init_device on participating
        # (non-dummy) ranks only.
        self._tp_group = None

    def _patch_tensor_parallel_group(self):
        """Temporarily patch the global tp group state with its own tp group
        state.
        """
        return patch_tensor_parallel_group(self._tp_group)

    def init_device(self) -> None:
        # Ranks outside draft_ranks become dummies and skip all draft work.
        self._is_dummy = get_tp_group().rank not in self._draft_ranks

        # dummy workers do nothing
        if self._is_dummy:
            return

        # creates tp process group containing only a subset of gpu ranks
        local_rank = get_tp_group().local_rank
        tp_backend = torch.distributed.get_backend(get_tp_group().device_group)
        self._tp_group = init_model_parallel_group([self._draft_ranks],
                                                   local_rank, tp_backend)

        with self._patch_tensor_parallel_group():
            self._worker.init_device()

    def set_include_gpu_probs_tensor(self) -> None:
        if self._is_dummy:
            return

        # Need include_gpu_probs_tensor for multi_step_worker
        self._worker.set_include_gpu_probs_tensor()

    def set_should_modify_greedy_probs_inplace(self) -> None:
        if self._is_dummy:
            return

        self._worker.set_should_modify_greedy_probs_inplace()

    def load_model(self) -> None:
        if self._is_dummy:
            return

        # Load under the patched (smaller) TP group so the draft model shards
        # across draft_ranks only.
        with self._patch_tensor_parallel_group():
            self._worker.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        if self._is_dummy:
            # this case is not used now
            return -1, -1

        with self._patch_tensor_parallel_group():
            return self._worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        if self._is_dummy:
            return

        with self._patch_tensor_parallel_group():
            self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        # Do not check _is_dummy, as it's always called by get_spec_proposals
        return self._worker.sampler_output(
            execute_model_req, sample_len,
            seq_ids_with_bonus_token_in_last_step)

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.
        """
        if self._is_dummy:
            # Dummy ranks return an empty proposal container.
            return SpeculativeProposals(None, None, None)

        with self._patch_tensor_parallel_group():
            return self._worker.get_spec_proposals(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def get_model(self) -> nn.Module:
        if self._is_dummy:
            return _DummyModel()

        with self._patch_tensor_parallel_group():
            return self._worker.get_model()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        if self._is_dummy:
            return []

        with self._patch_tensor_parallel_group():
            return self._worker.execute_model(execute_model_req)

    def get_cache_block_size_bytes(self) -> int:
        if self._is_dummy:
            # by returning zero, target worker can use the entire kv cache space
            return 0

        return self._worker.get_cache_block_size_bytes()

    @property
    def vocab_size(self) -> int:
        # Delegates to the wrapped worker; safe on dummy ranks only if the
        # wrapped worker exposes vocab_size without device state -- TODO confirm.
        return self._worker.vocab_size
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/spec_decode_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/spec_decode_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..8653bece8b5a59b616f41cad0bc8f4b201f6ac06
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/spec_decode_worker.py
@@ -0,0 +1,1282 @@
# SPDX-License-Identifier: Apache-2.0

import 
copy
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import torch
import torch.nn as nn

from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
                           HiddenStates, SequenceGroupMetadata,
                           get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
                                   get_all_num_logprobs,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.utils import resolve_obj_by_qualname
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

logger = init_logger(__name__)


def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Helper method that is the entrypoint for Executors which use
    WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
    """
    vllm_config: VllmConfig = kwargs.get("vllm_config")
    speculative_config: SpeculativeConfig = vllm_config.speculative_config
    assert speculative_config is not None

    if vllm_config.parallel_config.pipeline_parallel_size > 1:
        raise NotImplementedError("Speculative decoding is currently "
                                  "incompatible with pipeline parallelism")

    # Snapshot kwargs for the draft worker BEFORE the target-only
    # model_runner_cls entry is injected into kwargs below.
    draft_worker_kwargs = kwargs.copy()

    kwargs["model_runner_cls"] = TargetModelRunner
    target_worker_config = copy.deepcopy(vllm_config)
    target_worker_config.parallel_config.worker_cls =\
        target_worker_config.parallel_config.sd_worker_cls
    cls = resolve_obj_by_qualname(
        target_worker_config.parallel_config.worker_cls)
    target_worker = cls(*args, **kwargs)
    # Set the disable_logprobs variable in the TargetModelRunner instance
    # as per its value specified in the SpeculativeConfig.
    target_worker.model_runner.disable_logprobs =\
        speculative_config.disable_logprobs

    draft_worker_config = copy.deepcopy(vllm_config)
    draft_worker_config.model_config = speculative_config.draft_model_config
    # NOTE(review): relies on the private VllmConfig._get_quantization_config
    # helper; keep in sync with VllmConfig if that API changes.
    draft_worker_config.quant_config = VllmConfig._get_quantization_config(
        draft_worker_config.model_config,
        vllm_config.load_config,
    )
    speculative_config.draft_parallel_config.worker_cls =\
        draft_worker_config.parallel_config.sd_worker_cls
    draft_worker_config.parallel_config = speculative_config.draft_parallel_config  # noqa
    # TODO allow draft-model specific load config.

    # Override draft-model specific worker args.
    draft_worker_kwargs.update(
        vllm_config=draft_worker_config,
        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
    )

    spec_decode_worker = SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
        disable_by_batch_size=speculative_config.
        speculative_disable_by_batch_size,
        draft_token_acceptance_method=speculative_config.
        draft_token_acceptance_method,
        typical_acceptance_sampler_posterior_threshold=speculative_config.
        typical_acceptance_sampler_posterior_threshold,
        typical_acceptance_sampler_posterior_alpha=speculative_config.
        typical_acceptance_sampler_posterior_alpha,
        disable_logprobs=speculative_config.disable_logprobs,
        disable_log_stats=speculative_config.disable_log_stats,
    )

    return spec_decode_worker


# Reminder: Please update docs/source/features/compatibility_matrix.md
# if the feature combo becomes valid
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Only draft-model proposal is implemented (contributions for more forms are
    welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
    future work.
    * All sequences in a batch must have the same proposal length, or zero.
This + can be improved by having per-sequence speculation in the future. + * The scoring forward pass is done without an MQA kernel, which is + suboptimal especially as the batch size, proposal length, and sequence + lengths grow. Contributions to add a MQA scoring are welcome once + correctness tests pass. + More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit. + """ + + @classmethod + def create_worker( + cls, + scorer_worker: WorkerBase, + draft_worker_kwargs: Dict[str, Any], + disable_mqa_scorer: bool, + disable_by_batch_size: Optional[int], + draft_token_acceptance_method: str, + typical_acceptance_sampler_posterior_threshold: float, + typical_acceptance_sampler_posterior_alpha: float, + disable_logprobs: bool, + disable_log_stats: bool, + ) -> "SpecDecodeWorker": + + allow_zero_draft_token_step = True + ngram_prompt_lookup_max = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_max")) + ngram_prompt_lookup_min = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + draft_model_config = draft_worker_kwargs["vllm_config"].model_config + draft_parallel_config: ParallelConfig = draft_worker_kwargs[ + 'vllm_config'].parallel_config + if ngram_prompt_lookup_max > 0: + draft_worker_kwargs[ + "device_type"] = scorer_worker.device_config.device.type + proposer_worker = NGramWorker(**draft_worker_kwargs) + proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, + ngram_prompt_lookup_max) + else: + draft_tp = draft_parallel_config.tensor_parallel_size + target_tp = scorer_worker.parallel_config.tensor_parallel_size + + if draft_model_config.hf_config.model_type == "mlp_speculator": + proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) + elif draft_model_config.hf_config.model_type == "medusa": + proposer_worker = MedusaWorker(**draft_worker_kwargs) + else: + if draft_tp == 1: + if current_platform.is_cuda_alike(): + draft_worker_kwargs[ + "model_runner_cls"] = TP1DraftModelRunner + else: + if 
draft_model_config.hf_config.model_type == "eagle": + raise NotImplementedError( + "EAGLE does not support TP > 1 yet") + + allow_zero_draft_token_step = False + proposer_worker = MultiStepWorker(**draft_worker_kwargs) + + proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( + proposer_worker, draft_tp, target_tp) + + logger.info("Configuring SpecDecodeWorker with proposer=%s", + type(proposer_worker)) + + spec_decode_sampler: SpecDecodeBaseSampler = None + if draft_token_acceptance_method == "rejection_sampler": + spec_decode_sampler = RejectionSampler() + elif draft_token_acceptance_method == "typical_acceptance_sampler": + spec_decode_sampler = TypicalAcceptanceSampler( + posterior_threshold=\ + typical_acceptance_sampler_posterior_threshold, + posterior_alpha=typical_acceptance_sampler_posterior_alpha, + ) + logger.info( + "[Speculative Decoding] Configuring" + " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) + + if not disable_mqa_scorer: + if scorer_worker.model_runner.attn_backend.get_name( + ) != "FLASH_ATTN": + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "MQA is only available with flash attn backend.") + + if draft_model_config and \ + draft_model_config.max_model_len < \ + scorer_worker.model_config.max_model_len: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "draft model max_model_len is smaller than the target " + "model max_model_len.") + + if not scorer_worker.model_runner.model_config.enforce_eager: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "target model is not running in eager mode.") + + return SpecDecodeWorker( + proposer_worker, + scorer_worker, + disable_mqa_scorer=disable_mqa_scorer, + disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, + disable_by_batch_size=disable_by_batch_size, + spec_decode_sampler=spec_decode_sampler, + 
allow_zero_draft_token_step=allow_zero_draft_token_step) + + def __init__( + self, + proposer_worker: ProposerWorkerBase, + scorer_worker: WorkerBase, + spec_decode_sampler: SpecDecodeBaseSampler, + disable_mqa_scorer: bool = False, + disable_logprobs: bool = False, + disable_log_stats: bool = False, + metrics_collector: Optional[AsyncMetricsCollector] = None, + disable_by_batch_size: Optional[int] = None, + allow_zero_draft_token_step: Optional[bool] = True, + ): + """ + Create a SpecDecodeWorker. + + Args: + proposer_worker: A worker that can produce speculative tokens for + sequences. + scorer_worker: A worker that produces probabilities of speculative + tokens according to some base model. Typically a vanilla vLLM + Worker. + spec_decode_sampler: A Torch module used to perform acceptance + sampling of the draft tokens in the verification step of + speculative decoding. Currently we support two different + types of sampler namely RejectionSampler and + TypicalAcceptanceSampler. 'spec_decode_sampler' is either an + instance of RejectionSampler or TypicalAcceptanceSampler. + disable_mqa_scorer: If set to True, disable the MQA scorer and use + the BatchExpansionTop1Scorer instead. + disable_logprobs: If set to True, token log probabilities will + not be output in both the draft worker and the target worker. + If set to False, log probabilities will be output by both. + disable_log_stats: If set to True, disable periodic printing of + speculative stage times. + disable_by_batch_size: If the batch size is larger than this, + disable speculative decoding for new incoming requests. + metrics_collector: Helper class for collecting metrics; can be set + for testing purposes. 
+ allow_zero_draft_token_step: whether to allow a step where the draft + model generates no draft token; should disallow when the tp of + draft model is larger than 1 (TODO: #5814) + """ + self.proposer_worker = proposer_worker + self.scorer_worker = scorer_worker + scorer_runner = getattr(self.scorer_worker, "model_runner", None) + self.generators = scorer_runner.get_generators( + ) if scorer_runner else None + self.disable_by_batch_size = disable_by_batch_size or float("inf") + self.spec_decode_sampler = spec_decode_sampler + self._allow_zero_draft_token_step = allow_zero_draft_token_step + self._metrics = AsyncMetricsCollector( + self.spec_decode_sampler + ) if metrics_collector is None else metrics_collector + # Tracks the sequence IDs that received a bonus token ID in + # their last forward pass. Needed only if KV cache is being + # used for token generation such as in the case of MultiStepWorker. + self._seq_with_bonus_token_in_last_step: Set[int] = set() + # Tracks the currently active request ids and the sequence IDs + # corresponding to them + self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set) + # Tracks if the proposer worker uses the KV cache or not. + + self.probs_dtype = self.spec_decode_sampler.probs_dtype + self.token_id_dtype = self.spec_decode_sampler.token_id_dtype + # Lazy initialization. + self.scorer: SpeculativeScorer + self.disable_mqa_scorer = disable_mqa_scorer + + # Hidden states from target model to pass to proposer + # in the subsequent step. + self.previous_hidden_states: Optional[HiddenStates] = None + self._disable_logprobs = disable_logprobs + self._disable_log_stats = disable_log_stats + + def init_device(self) -> None: + """Initialize both scorer and proposer models. + """ + # The scorer worker model is initialized first in case the proposer + # model has a smaller TP degree than the target worker. 
+ self.scorer_worker.init_device() + self.proposer_worker.init_device() + + # NOTE(cade): load_model is not part of the WorkerBase interface. + self.scorer_worker.load_model() + self.proposer_worker.load_model() + + self._metrics.init_tensors(self.rank, device_type=self.device) + self.spec_decode_sampler.init_tensors(self.rank, + device_type=self.device) + + scorer_cls: Type[SpeculativeScorer] + if self.disable_mqa_scorer: + scorer_cls = BatchExpansionTop1Scorer + logger.info("[Speculative Decoding] Use batch " + "expansion for scoring proposals.") + else: + scorer_cls = MQAScorer + logger.info( + "[Speculative Decoding] Use MQA scorer for scoring proposals.") + + self.scorer = scorer_cls(scorer_worker=self.scorer_worker, + device=self.device, + vocab_size=self._vocab_size) + + self._configure_model_sampler_for_spec_decode() + + def load_model(self, *args, **kwargs): + pass + + def _configure_model_sampler_for_spec_decode(self): + """Configure model sampler to emit GPU tensors. This allows spec decode + to keep data on device without transferring to CPU and serializing, + which significantly reduces overhead of sampling during verification. + + NOTE(cade): This breaks abstraction boundaries pretty badly. The better + design is to have the "move to CPU and serialize" sampling decision be + done outside of the model/sampler; this way the "last-mile" worker + object which interfaces with the scheduler can serialize and incur the + performance hit as necessary. This allows us to run the worker several + iterations in a row without incurring the "move to CPU and serialize" + performance penalty. + + Since this requires a large change to vLLM, we defer it to later and + temporarily accept this broken abstraction boundary. + + NOTE(cade): This will require a special check if the proposer worker + does not have a sampler (e.g. ngram speculation). 
+ """ + (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor + ) = True + (self.scorer_worker.model_runner.model.sampler. + should_modify_greedy_probs_inplace) = True + self.proposer_worker.set_include_gpu_probs_tensor() + self.proposer_worker.set_should_modify_greedy_probs_inplace() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of cache blocks to use. + + This is done by profiling the scorer model (which is typically the + larger of the two). Then the total memory which would be used by the + scorer cache is divided evenly between the proposer and scorer model KV, + such that the number of blocks is equal in both KV caches. + """ + num_gpu_blocks, num_cpu_blocks = ( + self.scorer_worker.determine_num_available_blocks()) + + scorer_cache_block_size_bytes = ( + self.scorer_worker.get_cache_block_size_bytes()) + proposer_cache_block_size_bytes = ( + self.proposer_worker.get_cache_block_size_bytes()) + + new_num_gpu_blocks = split_num_cache_blocks_evenly( + scorer_cache_block_size_bytes, proposer_cache_block_size_bytes, + num_gpu_blocks) + return new_num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the cache engine of the scorer and proposer workers. + """ + self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def get_model(self) -> nn.Module: + return self.scorer_worker.get_model() + + @torch.inference_mode() + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Perform speculative decoding on the input batch. + """ + if self.rank != self._driver_rank: + self._run_non_driver_rank() + return [] + + if execute_model_req is None: + # This signals that there's no more requests to process for now. 
+ # All workers are running infinite loop with broadcast_tensor_dict, + # and it stops the loop when the driver broadcasts an empty input. + # Send an empty input to notify all other workers to stop their + # execution loop. + broadcast_tensor_dict({}, src=0) + return [] + + self._track_finished_requests(execute_model_req) + disable_all_speculation = self._should_disable_all_speculation( + execute_model_req) + num_lookahead_slots = execute_model_req.num_lookahead_slots + all_prompt = True + atleast_one_prompt = False + all_zero_spec_tokens = True + for sgm in execute_model_req.seq_group_metadata_list: + all_prompt = all_prompt and sgm.is_prompt + atleast_one_prompt = atleast_one_prompt or sgm.is_prompt + all_zero_spec_tokens = all_zero_spec_tokens and ( + sgm.num_speculative_tokens == 0) + + if all_prompt and execute_model_req.seq_group_metadata_list: + assert num_lookahead_slots == 0, ( + "Prompt only runs should have num_lookahead_slots equal to 0. " + "This should never happen, please file a bug at " + "https://github.com/vllm-project/vllm/issues") + # Speculative decoding is disabled in the following cases: + # 1. Prefill phase: Speculative decoding is not + # used during the prefill phase. + # 2. Auto-disable enabled: The running queue size exceeds + # the specified threshold. + # 3. No request: There are no requests in the batch, or + # none of the requests in the batch have spec decoding enabled. + # In any of these cases, the proposer and scorer workers + # are called normally. + # We expect `num_speculative_tokens` to be None for prefills. + no_spec = (num_lookahead_slots == 0 or disable_all_speculation + or all_zero_spec_tokens) + + # Broadcast how many lookahead slots are scheduled for this step, and + # whether all speculation is disabled, to all non-driver workers. + + # This is required as if the number of draft model runs changes + # dynamically, the non-driver workers won't know unless we perform a + # communication to inform them. 
+ + # no_spec is used to signal non-driver worker about prefill vs decode + # stage. This is needed to ensure that order of execution of proposer + # and scorer is same in both driver and non-driver workers (i.e., + # scorer -> proposer for prefill and proposer -> scorer in decode). This + # order is needed to support models like EAGLE that take scorer states + # as inputs. + broadcast_dict = dict( + num_lookahead_slots=num_lookahead_slots, + no_spec=no_spec, + disable_all_speculation=disable_all_speculation, + # When both chunked prefill and speculative decoding are enabled + # it is possible that the same batch contains both prefill + # and decodes. If that happens in the scorer we run the batch + # as one single forward pass. However, in the proposer we + # run them as 2 different batches - one for prefill and + # the other for decodes. The variable indicates to the non-driver + # worker that there are prefills as part of the speculative batch + # and hence it needs to run an extra prefill forward pass. + run_spec_proposer_for_prefill=atleast_one_prompt, + ) + broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) + + assert execute_model_req.seq_group_metadata_list is not None, ( + "speculative decoding requires non-None seq_group_metadata_list") + + self._maybe_disable_speculative_tokens( + disable_all_speculation, execute_model_req.seq_group_metadata_list) + + if no_spec: + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all_speculation) + return self._run_speculative_decoding_step(execute_model_req, + num_lookahead_slots) + + @torch.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop to perform speculative decoding + in parallel worker.""" + while self._run_non_driver_rank(): + pass + + def _should_disable_all_speculation( + self, execute_model_req: ExecuteModelRequest) -> bool: + # When the batch size is too large, disable speculative decoding + # to stop trading off throughput for latency. 
        # Tail of _should_disable_all_speculation (def is above this view):
        # speculation is disabled for the whole batch once the running queue
        # reaches the configured threshold.
        return (execute_model_req.running_queue_size
                >= self.disable_by_batch_size)

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """If speculation is disabled for this step, permanently turn off
        speculative decoding for every request in the batch by zeroing its
        ``num_speculative_tokens``."""
        if not disable_all_speculation:
            return

        for seq_group_metadata in seq_group_metadata_list:
            # Once num_speculative_tokens is set to 0, the spec decode
            # of this request will be disabled forever.
            # TODO(comaniac): We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within spec decode worker.
            seq_group_metadata.num_speculative_tokens = 0

    def _serialize_sampler_output_no_logprobs(
            self, execute_model_req: ExecuteModelRequest,
            sampler_output: SamplerOutput) -> List[SamplerOutput]:
        """
        Creates and returns a `SamplerOutput` with only the token IDs being
        serialized to CPU and populated in `CompletionSequenceGroupOutput`.
        All other parameters in `CompletionSequenceGroupOutput` related to log
        probabilities are skipped.

        Args:
            execute_model_req (ExecuteModelRequest): The model request that
                was executed.
            sampler_output (SamplerOutput): The output from the sampler with
                only GPU tensors populated.

        Returns:
            SamplerOutput: A new `SamplerOutput` instance containing a list of
            `CompletionSequenceGroupOutput` objects with only token IDs
            populated.
        """
        # Per-group flag: prompt logprobs were requested for a prompt chunk.
        seq_output_prompt_logprobs = [
            seq.is_prompt and seq.sampling_params.prompt_logprobs is not None
            and seq.sampling_params.prompt_logprobs > 0
            for seq in execute_model_req.seq_group_metadata_list
        ]
        # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
        sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where(
            # subtracting is faster than testing for equality
            sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \
            if any(seq_output_prompt_logprobs) else \
            sampler_output.sampled_token_ids).tolist()

        # Flatten (seq_id, seq_data) pairs across all groups.
        # NOTE(review): below this list is indexed by the *group* index `idx`,
        # which only lines up if each group carries exactly one sequence —
        # appears to be the assumption here; verify against callers.
        seq_data_entries = [
            (seq_id, seq_data) for sg in \
            execute_model_req.seq_group_metadata_list \
            for seq_id, seq_data in sg.seq_data.items()
        ]
        completion_seq_group_output_list: List[
            CompletionSequenceGroupOutput] = []
        output_index = 0
        # Make sure the non-terminal prefill chunks are still aligned with
        # their own empty output.
        for idx, seq_group_meta in enumerate(
                execute_model_req.seq_group_metadata_list):
            needs_prompt_logprobs = seq_output_prompt_logprobs[idx]
            seq_id, seq_data = seq_data_entries[idx]
            if needs_prompt_logprobs:
                prompt_token_ids = seq_data.get_prompt_token_ids()

                # Some of these sequences may belong to non-terminal chunks,
                # which may still have to report logprobs for prompts.
                # The first token of a sequence has no logprob, hence start=1
                # for the very first chunk.
                start = 1 if seq_data._num_computed_tokens == 0 \
                    else seq_data._num_computed_tokens
                end = (seq_data._num_computed_tokens + \
                       seq_group_meta.token_chunk_size)
                prompt_token_ids = prompt_token_ids[start:end]
                # Dummy logprob entries (rank -1, logprob 0.0, no top-k) —
                # logprob serialization is intentionally skipped here.
                prompt_logprobs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in prompt_token_ids
                ]
            else:
                prompt_logprobs = None

            # Since we can get chunks here, we dont always have a sampled token
            # (only on last chunk) but we still have to provide an output.
            if not seq_group_meta.do_sample:
                completion_seq_group_output_list.append(
                    CompletionSequenceGroupOutput(
                        samples=[], prompt_logprobs=prompt_logprobs))
                continue

            # Sequence with output.
            completion_seq_group_output_list.append(
                create_sequence_group_output(
                    token_id=sampled_token_ids_list[output_index][0],
                    token_id_logprob_rank=-1,
                    token_id_logprob=0.0,
                    seq_id=seq_id,
                    topk_token_ids=[],
                    topk_logprobs=[],
                    prompt_logprobs=prompt_logprobs))
            output_index += 1

        return [SamplerOutput(outputs=completion_seq_group_output_list)]

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation. The input is
        sent to the proposer and scorer model so that the KV cache is consistent
        between the two. When skip_proposer is True, the proposer model is
        not called, meaning that the kv-cache in proposer for requests is not
        updated, so they cannot enable spec decode in the rest decoding.
        """

        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Store hidden states from target model execution, BxD.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            # Only decodes and prefill terminal chunks need a hidden state.
            seq_group_meta_with_hidden = [
                sg for sg in execute_model_req.seq_group_metadata_list
                if sg.do_sample
            ]
            if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
                # Drop hidden_states with no prediction (eg non-terminal chunks)
                hidden_states = hidden_states[
                    torch.where(sampler_output.sampled_token_ids -
                                VLLM_INVALID_TOKEN_ID)[0]]
            # Seed or extend the cached hidden states for the next step.
            if self.previous_hidden_states is None and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states = HiddenStates(
                    hidden_states, seq_group_meta_with_hidden)
            elif self.previous_hidden_states and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states.update(hidden_states,
                                                   seq_group_meta_with_hidden)

        if not skip_proposer:
            # We prepare the prefill hidden states here so that there no
            # additional complexity in worker for spec_decode vs non_spec_decode
            # flow and execute_model doesn't need additional modifications.
            execute_model_req.previous_hidden_states = \
                prepare_prefill_hidden_states(
                    sampler_output.prefill_hidden_states)

            self.proposer_worker.execute_model(execute_model_req)

        # When logprobs are disabled, serialize only token IDs to CPU.
        sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
            execute_model_req=execute_model_req, sampler_output=sampler_output)
                                    if self._disable_logprobs else
                                    [sampler_output])

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return sampler_output_to_return

    def _run_non_driver_rank(self) -> bool:
        """Run proposer and verifier model in non-driver workers. This is used
        for both speculation cases (num_lookahead_slots>0) and non-speculation
        cases (e.g. prefill).

        Returns True if there are remaining sequences to process.
        """
        assert self.rank != self._driver_rank

        # Control metadata is broadcast from the driver rank each step; an
        # empty dict signals shutdown/no work.
        data = broadcast_tensor_dict(src=self._driver_rank)
        if not data:
            return False
        num_lookahead_slots = data["num_lookahead_slots"]

        # In case of prefill, scorer_worker has to be run before proposer so
        # that the hidden states can be propagated to proposer when needed.
        if data["no_spec"]:
            self.scorer_worker.execute_model()

        if not data["disable_all_speculation"]:
            # Even if num_lookahead_slots is zero, we want to run the
            # proposer model as it may have KV.
            #
            # We run the proposer once per lookahead slot. In the future we
            # should delegate how many times it runs to the proposer.
            for _ in range(max(num_lookahead_slots, 1)):
                self.proposer_worker.execute_model()

        if not data["no_spec"]:
            self.scorer_worker.execute_model()
            if data["run_spec_proposer_for_prefill"]:
                self.proposer_worker.execute_model()

        return True

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        When `enable_chunked_prefill` is set, scorer will batch decodes and
        prefills, while proposer will sync its KV-cache by running an extra
        forward on prefills.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.
        """
        # With prefill chunking, expect requests to have prompts first
        # so that backend gets prefill|decode.
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Pass last hidden states from target model to proposer
        execute_model_req.previous_hidden_states = self.previous_hidden_states
        self.previous_hidden_states = None

        with Timer() as proposal_timer:
            # Generate proposals using draft worker.
            proposals = self.proposer_worker.get_spec_proposals(
                execute_model_req, self._seq_with_bonus_token_in_last_step)

        if not self._allow_zero_draft_token_step and proposals.no_proposals:
            #TODO: Fix it #5814
            raise RuntimeError("Cannot handle cases where distributed draft "
                               "workers generate no tokens")

        # Hidden states were only needed for proposal generation; drop them
        # before scoring.
        execute_model_req.previous_hidden_states = None

        with Timer() as scoring_timer:
            proposal_scores = self.scorer.score_proposals(
                execute_model_req,
                proposals,
            )

        _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len(
            execute_model_req.seq_group_metadata_list, proposals.proposal_lens)
        # With prefill chunking enabled, `non_spec_seqs` contains prefills too:
        # discard decodes that have already been processed by proposer.
        non_spec_indices = [
            idx for idx in non_spec_indices
            if execute_model_req.seq_group_metadata_list[idx].is_prompt
        ]
        if len(non_spec_indices):
            all_hidden_states = proposal_scores.hidden_states
            if all_hidden_states is not None:
                prefill_hidden_states = all_hidden_states[non_spec_indices]
                execute_model_req.previous_hidden_states = \
                    prepare_prefill_hidden_states(prefill_hidden_states)
            # Sync proposer KV cache for prefills.
            prefill_req = execute_model_req.clone(non_spec_seqs)
            # TODO avoid sampling here?
            self.proposer_worker.execute_model(prefill_req)

        with Timer() as verification_timer:
            accepted_token_ids, target_logprobs = self._verify_tokens(
                execute_model_req.seq_group_metadata_list, proposal_scores,
                proposals, execute_model_req.num_lookahead_slots)

        # Per-stage wall-clock times; proposal time is normalized per
        # lookahead slot.
        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
                       scoring_timer.elapsed_time_ms,
                       verification_timer.elapsed_time_ms)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            prompt_logprobs=proposal_scores.prompt_logprobs
            if not self._disable_logprobs else None,
            k=execute_model_req.num_lookahead_slots,
            stage_times=stage_times)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        Returns a tuple of Tensors, one for the accepted token ids and one for
        the logprobs according to the scoring model.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        # Order after concatenating [spec | non_spec]; used below to undo the
        # split and restore the original batch order.
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, including bonus tokens.
        proposal_verifier_probs = proposal_scores.probs[spec_indices]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        # Seeded sequences are only forwarded to stochastic accept/reject
        # samplers (deterministic samplers take no generators).
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            sampler_extra_kwargs["seeded_seqs"] = {
                idx: self.generators[sgm.request_id]
                for idx, sgm in enumerate(seq_group_metadata_list)
                if sgm.sampling_params.seed is not None
            }

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor. Non-spec rows are padded with -1
        # past their single sampled token so all rows have k+1 slots.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        # B x K+1 x D
        hidden_states = proposal_scores.hidden_states
        if hidden_states is not None:
            # Only get terminal hidden states for next step
            terminal_metadata = [
                sg for sg in seq_group_metadata_list if sg.do_sample
            ]

            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[-1]
            # Index of the last accepted token per row: count non-(-1)
            # entries (after +1 shift, rejected slots become 0) minus one.
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
            # Drop non-terminal prefill chunks hidden states.
            # NOTE(review): rows with no accepted token end up at index
            # VLLM_INVALID_TOKEN_ID here and are filtered out — presumably
            # VLLM_INVALID_TOKEN_ID == -1 so this matches the padding above;
            # confirm against the constant's definition.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step, topk_indices_by_step) =\
                self._create_logprob_lists_from_tensors(
                    target_logprobs_by_step, accepted_token_ids_by_step,
                    self.scorer_worker.model_config.max_logprobs)

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)

        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize tensor to CPU Python list.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        # Non-terminal prefill chunks will end up here as rows with just -1s
        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
        # terminal chunks will only have one generated token at time 0.
        sampler_output_list: List[SamplerOutput] = []

        # Prefills are not multi-step (return at most 1 token), in order to
        # avoid padding or repetition to fit decodes, we separate them.
        for i, sg in enumerate(seq_group_metadata_list):
            if not sg.is_prompt:
                # Requests are ordered as prefills|decodes=>no more prefills.
                break
            num_logprobs = num_logprobs_per_seq[i]
            # Defaults describe "no token produced" (non-terminal chunk);
            # overwritten below for terminal chunks.
            seq_kwargs = dict(token_id=-1,
                              token_id_logprob_rank=0,
                              token_id_logprob=-float('inf'),
                              topk_token_ids=[-1] * num_logprobs,
                              topk_logprobs=[-float('inf')] * num_logprobs,
                              seq_id=seq_ids[i])
            # Terminal chunk, has token.
            if sg.do_sample:
                seq_kwargs.update(
                    dict(
                        token_id=accepted_token_ids[i][0].item(),
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            0][i],
                        token_id_logprob=accepted_token_id_logprobs_by_step[0]
                        [i],
                        topk_token_ids=topk_indices_by_step[0][i]
                        [:num_logprobs],
                        # output only so step is 0
                        topk_logprobs=topk_logprobs_by_step[0][i]
                        [:num_logprobs],
                    ))
            needs_plogs = (sg.sampling_params.prompt_logprobs
                           and sg.sampling_params.prompt_logprobs > 0)
            plogs = None
            if prompt_logprobs is not None:
                # Even non-terminal prompt chunks can have logprobs here.
                plogs = prompt_logprobs[i]
            elif needs_plogs:
                # Prompt logprobs are requested but `_disable_logprobs` is set.
                seq_data = next(iter(sg.seq_data.values()))
                # Get only the tokens in this chunk!
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_token_ids = prompt_token_ids[
                    seq_data.
                    _num_computed_tokens:seq_data._num_computed_tokens +
                    sg.token_chunk_size]

                is_first_chunk = seq_data._num_computed_tokens == 0
                # There's no prob generated for the first token in a sequence.
                if is_first_chunk:
                    prompt_token_ids = prompt_token_ids[1:]
                # Dummy logprob entries, mirroring
                # _serialize_sampler_output_no_logprobs.
                plogs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in prompt_token_ids
                ]
            seq_kwargs.update(dict(prompt_logprobs=plogs))

            sampler_output_list.append(
                SamplerOutput(
                    outputs=[create_sequence_group_output(
                        **seq_kwargs)]))  # type: ignore

        # Decodes, create one SamplerOutput per-step (at most K+1).
        for step_index in range(num_steps):
            # Stop once every decode row is exhausted (-1 padding only).
            if all(token_id == -1 for sg, token_id in zip(
                    seq_group_metadata_list,
                    accepted_token_ids_by_step[step_index])
                   if not sg.is_prompt):
                break

            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                seq_meta = seq_group_metadata_list[sequence_index]
                # Prompts already processed above.
                if seq_meta.is_prompt:
                    continue

                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)
        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            # Metrics are attached to the first output of the step.
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        # Log time spent in each stage periodically.
        # This is periodic because the rejection sampler emits metrics
        # periodically.
        self._maybe_log_stage_times(*stage_times)
        # First `n_prefills` entries will contain prefills SamplerOutput when
        # chunked prefill is enabled, the rest is decodes in multi-step format.
        return sampler_output_list

    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
                               scoring_time_ms: float,
                               verification_time_ms: float) -> None:
        """Log the speculative stage times. If stat logging is disabled, do
        nothing.
        """
        if self._disable_log_stats:
            return

        # Lazy %-style args: formatting is skipped if INFO is not enabled.
        logger.info(
            "SpecDecodeWorker stage times: "
            "average_time_per_proposal_tok_ms=%.02f "
            "scoring_time_ms=%.02f verification_time_ms=%.02f",
            average_time_per_proposal_tok_ms, scoring_time_ms,
            verification_time_ms)

    def _create_dummy_logprob_lists(
        self,
        batch_size: int,
        num_steps: int,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Creates and returns four dummy lists representing token probabilities
        and their ranks.

        This method initializes and returns:
        - The ranks of the accepted tokens, shaped (num_steps, batch_size)
        - The log probabilities of the accepted tokens,
          shaped (num_steps, batch_size)
        - The log probabilities of the top k tokens,
          shaped (num_steps, batch_size, num_top_k)
        - The token IDs of the top k tokens,
          shaped (num_steps, batch_size, num_top_k)

        Args:
            batch_size (int): The size of the batch.
            num_steps (int): The number of steps in the sequence.
            num_top_k (int): The number of top-k token log probabilities to
                return.

        Returns:
            A tuple containing four dummy lists as described above.
        """
        # Placeholder values: rank -1, logprob 0.0, and None top-k entries.
        accepted_token_id_ranks_by_step = [[-1] * batch_size
                                           for _ in range(num_steps)]
        accepted_token_id_logprobs_by_step = [[0.0] * batch_size
                                              for _ in range(num_steps)]
        topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        topk_indices_by_step: List[List[List[Optional[int]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        return (accepted_token_id_ranks_by_step,
                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
                topk_indices_by_step)

    def _create_logprob_lists_from_tensors(
        self,
        target_logprobs_by_step: torch.Tensor,
        accepted_token_ids_by_step: torch.Tensor,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Creates and returns four lists representing token probabilities and
        their ranks.

        This method initializes and returns four lists containing:
        - The ranks of the accepted tokens, shaped (num_steps, batch_size)
        - The log probabilities of the accepted tokens,
          shaped (num_steps, batch_size)
        - The log probabilities of the top k tokens,
          shaped (num_steps, batch_size, num_top_k)
        - The token IDs of the top k tokens,
          shaped (num_steps, batch_size, num_top_k)

        Args:
            target_logprobs_by_step (torch.Tensor): Tensor representing the
                log probabilities of the target model,
                shaped (num_steps, batch_size, vocab_size)
            accepted_token_ids_by_step (torch.Tensor): Tensor representing
                the accepted token_ids, shaped (num_steps, batch_size)
            num_top_k (int): The number of top-k token log probabilities to
                return.

        Returns:
            A tuple containing the lists as described above.
        """
        # Serialize all tensors to CPU Python lists.
        # Get the logprobs/rank of the accepted tokens.
        (accepted_token_id_ranks_by_step_tensor,
         accepted_token_id_logprobs_by_step_tensor
         ) = get_sampled_token_logprobs(
             logprob_tensor=target_logprobs_by_step,
             sampled_token_ids=accepted_token_ids_by_step,
         )
        # Get the top-k logprobs (which may or may not include the
        # logprob of the accepted token).
        (topk_logprobs_by_step_tensor,
         topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
             k=num_top_k,
             dim=-1,
         )
        accepted_token_id_ranks_by_step = (
            accepted_token_id_ranks_by_step_tensor.tolist())
        accepted_token_id_logprobs_by_step = (
            accepted_token_id_logprobs_by_step_tensor.tolist())
        topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
        topk_indices_by_step = topk_indices_by_step_tensor.tolist()
        return (accepted_token_id_ranks_by_step,
                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
                topk_indices_by_step)

    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
        """
        Removes the finished requests and their associated sequence ids from
        internal book keeping data structures.
        """
        for finished_request in execute_model_req.finished_requests_ids:
            for seq_id in self._request_id_seq_id_mapping[finished_request]:
                # discard() tolerates seq ids that were never added.
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            del self._request_id_seq_id_mapping[finished_request]

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """
        Updates the internal data structures which keep track of sequences
        which have been assigned bonus tokens in their last forward pass.
        """
        for seq_index, seq_id in enumerate(seq_ids):
            # A -1 in the final step means the bonus token was not accepted
            # for this sequence in the last pass.
            last_token_id = accepted_token_ids_by_step[-1][seq_index]
            if last_token_id == -1:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            else:
                self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)

    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        # Delegates to the scorer (target model) worker.
        return self.scorer_worker.rank

    @property
    def device(self):
        # Delegates to the scorer (target model) worker.
        return self.scorer_worker.device

    @property
    def _driver_rank(self) -> int:
        # Rank 0 drives the step; other ranks follow via broadcast.
        return 0

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

        This function is only used to compose workers within a SpecDecodeWorker.
        We leave composing a SpecDecodeWorker within a SpecDecodeWorker
        undefined for now, although it could be implemented in the future.
        See https://arxiv.org/abs/2308.04623.
        """
        raise NotImplementedError

    def start_profile(self):
        """Start profiling on the scorer worker, if it supports it."""
        if isinstance(self.scorer_worker, WorkerBase):
            self.scorer_worker.start_profile()

    def stop_profile(self):
        """Stop profiling on the scorer worker, if it supports it."""
        if isinstance(self.scorer_worker, WorkerBase):
            self.scorer_worker.stop_profile()


def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocate to the target model, this function calculates how many blocks
    should be given to the draft and target model.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone.
    """
    # Scale the target-only budget by the fraction of total bytes the scorer
    # occupies; the truncation keeps the combined usage within budget.
    new_num_gpu_blocks = int(
        total_num_gpu_blocks * scorer_cache_block_size_bytes /
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))

    return new_num_gpu_blocks


def prepare_prefill_hidden_states(
        prefill_hidden_states: torch.Tensor) -> HiddenStates:
    # For prefill step in proposer, we run the model for N-1 tokens
    # because Nth token will be processed in the first decode step. For
    # N-1 tokens, the input should be 0:N-1 hidden states which should
    # be concatanated with 1:N token (since output of scorer has to be
    # the input for proposer). Therefore, we shift the hidden states to
    # align n-1th hidden state with nth token.
+ return HiddenStates(prefill_hidden_states.roll( + shifts=1, dims=0)) if prefill_hidden_states is not None else None diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/util.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9c04680a6a7ab37196633eddc1a218876164a18b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/util.py @@ -0,0 +1,276 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from contextlib import contextmanager +from typing import Dict, List, Optional, Sequence, Tuple + +import torch + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SequenceGroupMetadata, + SequenceOutput) + +SeqId = int + + +def get_all_num_logprobs( + seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: + """Given a list of SequenceGroupMetadata, create a list of all num_logprobs. + + If the sampling params do not call for any logprobs, return 0 for that + sequence. + """ + + all_num_logprobs: List[int] = [] + for seq_group_metadata in seq_group_metadata_list: + num_logprobs = seq_group_metadata.sampling_params.logprobs + if num_logprobs is None: + num_logprobs = 0 + all_num_logprobs.append(num_logprobs) + + return all_num_logprobs + + +def get_sampled_token_logprobs( + # shape [num_steps, batch_size, vocab_size] + logprob_tensor: torch.Tensor, + sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size] +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the logprobs for the sampled tokens. Returns the ranks and logprobs. 
    """
    num_steps, batch_size, vocab_size = logprob_tensor.shape

    # Advanced indexing: pick, for every (step, batch) position, the logprob
    # of the token that was actually sampled there.
    selected_logprobs = logprob_tensor[
        torch.arange(num_steps).unsqueeze(1),
        torch.arange(batch_size),
        sampled_token_ids,
    ]
    # Rank = 1 + number of vocabulary entries with a strictly greater
    # logprob (rank 1 means the sampled token was the argmax).
    expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
        -1, -1, vocab_size)
    sampled_token_ids_ranks = (logprob_tensor
                               > expanded_selected_logprobs).sum(-1).add_(1)

    return sampled_token_ids_ranks, selected_logprobs


def create_logprobs_output(
    token_id: int,
    token_id_logprob_rank: int,
    token_id_logprob: float,
    topk_token_ids: List[Optional[int]],
    topk_logprobs: List[Optional[float]],
) -> Dict[int, Logprob]:
    """Create a Logprob Dict for a token given the sampling results.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
    """
    # vLLM logprobs always include the sampled token. In addition, the user may
    # request topk-logprobs (where top-k varies per user up to max_logprobs).
    logprobs: Dict[int, Logprob] = {
        token_id: Logprob(
            logprob=token_id_logprob,
            rank=token_id_logprob_rank,
        ),
    }
    # Top-k entries get rank = position + 1; None slots (dummy lists) are
    # skipped. Note a top-k entry equal to token_id overwrites the entry above.
    logprobs.update({
        topk_token_id: Logprob(
            logprob=topk_logprob if topk_logprob is not None else 0.0,
            rank=topk_index + 1,
        )
        for topk_index, (topk_token_id, topk_logprob) \
        in enumerate(zip(topk_token_ids, topk_logprobs)) \
        if topk_token_id is not None
    })

    return logprobs


def create_sequence_group_output(
    token_id: int,
    token_id_logprob_rank: int,
    token_id_logprob: float,
    seq_id: SeqId,
    topk_token_ids: List[Optional[int]],
    topk_logprobs: List[Optional[float]],
    prompt_logprobs: Optional[PromptLogprobs] = None,
) -> CompletionSequenceGroupOutput:
    """Create a SequenceGroupOutput given the sampling results.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        seq_id (int): The sequence id.
        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
    """

    logprobs = create_logprobs_output(
        token_id,
        token_id_logprob_rank,
        token_id_logprob,
        topk_token_ids,
        topk_logprobs,
    )

    return CompletionSequenceGroupOutput(
        samples=[
            SequenceOutput(parent_seq_id=seq_id,
                           output_token=token_id,
                           logprobs=logprobs)
        ],
        prompt_logprobs=prompt_logprobs,
    )


def split_batch_by_proposal_len(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_lens: List[int],
) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[
        List[SequenceGroupMetadata], List[int]]]:
    """Utility function that splits a batch based on whether the proposal len is
    zero or not. We should remove this once vLLM supports per-sequence proposal
    lens in a batch.

    Returns ((nonzero_groups, nonzero_indices), (zero_groups, zero_indices)),
    where indices refer to positions in the input list.
    """

    nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
    zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
    for i, (seq_group, proposal_len) in enumerate(
            zip(seq_group_metadata_list, proposal_lens)):
        # Truthiness of proposal_len: any nonzero length is "speculative".
        seq_groups, indices = nonzero_lists if proposal_len else zero_lists
        seq_groups.append(seq_group)
        indices.append(i)
    return nonzero_lists, zero_lists


def sampler_output_to_torch(
    sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Utility function which converts a list of SamplerOutput to tensors.

    sampler_transposed here is used as the indicator for whether
    we need do additional tensor transpose logic here.
    Returns:
        sampled_token_ids: torch.Tensor
            shape: [batch_size, len(sampler_output_list)]

        sampled_token_probs: torch.Tensor
            shape: [batch_size, len(sampler_output_list), vocab_size]
    """

    # shape: [batch_size, num_sampler_output, vocab_size]
    sampled_token_probs = torch.stack(
        [
            sampler_output.sampled_token_probs
            for sampler_output in sampler_output_list
        ],
        dim=0,
    )

    # shape: [batch_size, num_sampler_output, vocab_size]
    sampled_token_logprobs = torch.stack(
        [sampler_output.logprobs for sampler_output in sampler_output_list],
        dim=0,
    )

    # shape: [batch_size, num_sampler_output]
    sampled_token_ids = torch.stack(
        [
            sampler_output.sampled_token_ids.flatten()
            for sampler_output in sampler_output_list
        ],
        dim=0,
    )

    # Swap the leading (step) and batch dimensions when requested.
    if sampler_transposed:
        sampled_token_probs = sampled_token_probs.transpose(0, 1)
        sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
        sampled_token_ids = sampled_token_ids.transpose(0, 1)

    if sampler_output_list[0].hidden_states is not None:
        # shape: [batch_size, num_sampler_output, hidden_dim]
        sampled_hidden_states = torch.stack(
            [
                sampler_output.hidden_states
                for sampler_output in sampler_output_list
            ],
            dim=0,
        )

        if sampler_transposed:
            sampled_hidden_states = sampled_hidden_states.transpose(0, 1)
    else:
        sampled_hidden_states = None

    return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
            sampled_hidden_states)


def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
                              vocab_size: int, device: str) -> None:
    """Helper method which mocks out the GPU tensors in SamplerOutput with dummy
    values. This will be removed in PR 7/9.
    https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
    """
    values = [
        sampler_output.sampled_token_probs, sampler_output.sampled_token_ids
    ]
    # The two tensors must be either both set or both unset.
    assert all(v is None for v in values) or not any(v is None for v in values)
    if not any(v is None for v in values):
        # Do nothing if the tensors are already created (usually in unit tests).
        return

    # Softmax to ensure valid probs.
    sampler_output.sampled_token_probs = torch.nn.functional.softmax(
        torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device),
        dim=-1)

    # NOTE: random token ids are drawn from [10, 100), independent of
    # vocab_size — these are placeholder values only.
    sampler_output.sampled_token_ids = torch.randint(low=10,
                                                     high=100,
                                                     size=(batch_size, ),
                                                     dtype=torch.long,
                                                     device=device)


@contextmanager
def nvtx_range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes an NVTX range at the beginning
    of its scope, and pops it at the end. If extra arguments are given,
    they are passed as arguments to msg.format().

    If running with cuda graphs, you must enable nsys cuda graph profiling.

    Arguments:
        msg (string): message to associate with the range
    """
    if current_platform.is_cuda_alike():
        torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
        try:
            yield
        finally:
            # Always pop, even if the wrapped scope raises.
            torch.cuda.nvtx.range_pop()
    else:
        # Non-CUDA platforms: no-op passthrough.
        yield


class Timer:
    """Basic timer context manager for measuring CPU time.
    """

    def __enter__(self):
        # Wall-clock start; exposed for callers inspecting the timer.
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Records both seconds and milliseconds for convenience.
        self.end_time = time.time()
        self.elapsed_time_s = self.end_time - self.start_time
        self.elapsed_time_ms = self.elapsed_time_s * 1000