diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/__init__.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/api_server.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/api_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..96818507d589fca1f4d602df8a25111d82a4e5d2
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/api_server.py
@@ -0,0 +1,169 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+NOTE: This API server is used only for demonstrating usage of AsyncEngine
+and simple performance benchmarks. It is not intended for production use.
+For production use, we recommend using our OpenAI compatible server.
+We are also not going to accept PRs modifying this file, please
+change `vllm/entrypoints/openai/api_server.py` instead.
+"""
+import asyncio
+import json
+import ssl
+from argparse import Namespace
+from typing import Any, AsyncGenerator, Optional
+
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.utils import with_cancellation
+from vllm.logger import init_logger
+from vllm.sampling_params import SamplingParams
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
+from vllm.version import __version__ as VLLM_VERSION
+
+logger = init_logger("vllm.entrypoints.api_server")
+
+TIMEOUT_KEEP_ALIVE = 5 # seconds.
+app = FastAPI()
+engine = None
+
+
+@app.get("/health")
+async def health() -> Response:
+ """Health check."""
+ return Response(status_code=200)
+
+
+@app.post("/generate")
+async def generate(request: Request) -> Response:
+ """Generate completion for the request.
+
+ The request should be a JSON object with the following fields:
+ - prompt: the prompt to use for the generation.
+ - stream: whether to stream the results or not.
+ - other fields: the sampling parameters (See `SamplingParams` for details).
+ """
+ request_dict = await request.json()
+ return await _generate(request_dict, raw_request=request)
+
+
+@with_cancellation
+async def _generate(request_dict: dict, raw_request: Request) -> Response:
+ prompt = request_dict.pop("prompt")
+ stream = request_dict.pop("stream", False)
+ sampling_params = SamplingParams(**request_dict)
+ request_id = random_uuid()
+
+ assert engine is not None
+ results_generator = engine.generate(prompt, sampling_params, request_id)
+
+ # Streaming case
+ async def stream_results() -> AsyncGenerator[bytes, None]:
+ async for request_output in results_generator:
+ prompt = request_output.prompt
+ assert prompt is not None
+ text_outputs = [
+ prompt + output.text for output in request_output.outputs
+ ]
+ ret = {"text": text_outputs}
+ yield (json.dumps(ret) + "\n").encode("utf-8")
+
+ if stream:
+ return StreamingResponse(stream_results())
+
+ # Non-streaming case
+ final_output = None
+ try:
+ async for request_output in results_generator:
+ final_output = request_output
+ except asyncio.CancelledError:
+ return Response(status_code=499)
+
+ assert final_output is not None
+ prompt = final_output.prompt
+ assert prompt is not None
+ text_outputs = [prompt + output.text for output in final_output.outputs]
+ ret = {"text": text_outputs}
+ return JSONResponse(ret)
+
+
+def build_app(args: Namespace) -> FastAPI:
+ global app
+
+ app.root_path = args.root_path
+ return app
+
+
+async def init_app(
+ args: Namespace,
+ llm_engine: Optional[AsyncLLMEngine] = None,
+) -> FastAPI:
+ app = build_app(args)
+
+ global engine
+
+ engine_args = AsyncEngineArgs.from_cli_args(args)
+ engine = (llm_engine
+ if llm_engine is not None else AsyncLLMEngine.from_engine_args(
+ engine_args, usage_context=UsageContext.API_SERVER))
+
+ return app
+
+
+async def run_server(args: Namespace,
+ llm_engine: Optional[AsyncLLMEngine] = None,
+ **uvicorn_kwargs: Any) -> None:
+ logger.info("vLLM API server version %s", VLLM_VERSION)
+ logger.info("args: %s", args)
+
+ set_ulimit()
+
+ app = await init_app(args, llm_engine)
+ assert engine is not None
+
+ shutdown_task = await serve_http(
+ app,
+ host=args.host,
+ port=args.port,
+ log_level=args.log_level,
+ timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+ ssl_keyfile=args.ssl_keyfile,
+ ssl_certfile=args.ssl_certfile,
+ ssl_ca_certs=args.ssl_ca_certs,
+ ssl_cert_reqs=args.ssl_cert_reqs,
+ **uvicorn_kwargs,
+ )
+
+ await shutdown_task
+
+
+if __name__ == "__main__":
+ parser = FlexibleArgumentParser()
+ parser.add_argument("--host", type=str, default=None)
+ parser.add_argument("--port", type=int, default=8000)
+ parser.add_argument("--ssl-keyfile", type=str, default=None)
+ parser.add_argument("--ssl-certfile", type=str, default=None)
+ parser.add_argument("--ssl-ca-certs",
+ type=str,
+ default=None,
+ help="The CA certificates file")
+ parser.add_argument(
+ "--ssl-cert-reqs",
+ type=int,
+ default=int(ssl.CERT_NONE),
+ help="Whether client certificate is required (see stdlib ssl module's)"
+ )
+ parser.add_argument(
+ "--root-path",
+ type=str,
+ default=None,
+ help="FastAPI root_path when app is behind a path based routing proxy")
+ parser.add_argument("--log-level", type=str, default="debug")
+ parser = AsyncEngineArgs.add_cli_args(parser)
+ args = parser.parse_args()
+
+ asyncio.run(run_server(args))
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/chat_utils.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/chat_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f04902ae1c7678c736bcb49450339eeafcbb6b75
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/chat_utils.py
@@ -0,0 +1,1007 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import codecs
+import json
+from abc import ABC, abstractmethod
+from collections import defaultdict, deque
+from functools import cache, lru_cache, partial
+from pathlib import Path
+from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
+ Literal, Optional, Tuple, TypeVar, Union, cast)
+
+import jinja2.nodes
+import transformers.utils.chat_template_utils as hf_chat_utils
+# yapf conflicts with isort for this block
+# yapf: disable
+from openai.types.chat import (ChatCompletionAssistantMessageParam,
+ ChatCompletionContentPartImageParam,
+ ChatCompletionContentPartInputAudioParam)
+from openai.types.chat import (
+ ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
+from openai.types.chat import (ChatCompletionContentPartRefusalParam,
+ ChatCompletionContentPartTextParam)
+from openai.types.chat import (
+ ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
+from openai.types.chat import (ChatCompletionMessageToolCallParam,
+ ChatCompletionToolMessageParam)
+from openai.types.chat.chat_completion_content_part_input_audio_param import (
+ InputAudio)
+# yapf: enable
+# pydantic needs the TypedDict from typing_extensions
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing_extensions import Required, TypeAlias, TypedDict
+
+from vllm.config import ModelConfig
+from vllm.logger import init_logger
+from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.utils import MediaConnector
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+
+logger = init_logger(__name__)
+
+
+class AudioURL(TypedDict, total=False):
+ url: Required[str]
+ """
+ Either a URL of the audio or a data URL with base64 encoded audio data.
+ """
+
+
+class ChatCompletionContentPartAudioParam(TypedDict, total=False):
+ audio_url: Required[AudioURL]
+
+ type: Required[Literal["audio_url"]]
+ """The type of the content part."""
+
+
+class VideoURL(TypedDict, total=False):
+ url: Required[str]
+ """
+ Either a URL of the video or a data URL with base64 encoded video data.
+ """
+
+
+class ChatCompletionContentPartVideoParam(TypedDict, total=False):
+ video_url: Required[VideoURL]
+
+ type: Required[Literal["video_url"]]
+ """The type of the content part."""
+
+
+class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
+ """A simpler version of the param that only accepts a plain image_url.
+ This is supported by OpenAI API, although it is not documented.
+
+ Example:
+ {
+ "image_url": "https://example.com/image.jpg"
+ }
+ """
+ image_url: Required[str]
+
+
+class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
+ """A simpler version of the param that only accepts a plain audio_url.
+
+ Example:
+ {
+ "audio_url": "https://example.com/audio.mp3"
+ }
+ """
+ audio_url: Required[str]
+
+
+class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
+ """A simpler version of the param that only accepts a plain audio_url.
+
+ Example:
+ {
+ "video_url": "https://example.com/video.mp4"
+ }
+ """
+ video_url: Required[str]
+
+
+ChatCompletionContentPartParam: TypeAlias = Union[
+ OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
+ ChatCompletionContentPartInputAudioParam,
+ ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
+ CustomChatCompletionContentSimpleImageParam,
+ CustomChatCompletionContentSimpleAudioParam,
+ CustomChatCompletionContentSimpleVideoParam, str]
+
+
+class CustomChatCompletionMessageParam(TypedDict, total=False):
+ """Enables custom roles in the Chat Completion API."""
+ role: Required[str]
+ """The role of the message's author."""
+
+ content: Union[str, List[ChatCompletionContentPartParam]]
+ """The contents of the message."""
+
+ name: str
+ """An optional name for the participant.
+
+ Provides the model information to differentiate between participants of the
+ same role.
+ """
+
+ tool_call_id: Optional[str]
+ """Tool call that this message is responding to."""
+
+ tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
+ """The tool calls generated by the model, such as function calls."""
+
+
+ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
+ CustomChatCompletionMessageParam]
+
+
+# TODO: Make fields ReadOnly once mypy supports it
+class ConversationMessage(TypedDict, total=False):
+ role: Required[str]
+ """The role of the message's author."""
+
+ content: Union[Optional[str], List[Dict[str, str]]]
+ """The contents of the message"""
+
+ tool_call_id: Optional[str]
+ """Tool call that this message is responding to."""
+
+ name: Optional[str]
+ """The name of the function to call"""
+
+ tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
+ """The tool calls generated by the model, such as function calls."""
+
+
+# Passed in by user
+ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
+
+# Used internally
+_ChatTemplateContentFormat = Literal["string", "openai"]
+
+
+def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
+ if isinstance(node, jinja2.nodes.Name):
+ return node.ctx == "load" and node.name == varname
+
+ return False
+
+
+def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
+ if isinstance(node, jinja2.nodes.Getitem):
+ return (_is_var_access(node.node, varname)
+ and isinstance(node.arg, jinja2.nodes.Const)
+ and node.arg.value == key)
+
+ if isinstance(node, jinja2.nodes.Getattr):
+ return _is_var_access(node.node, varname) and node.attr == key
+
+ return False
+
+
+def _is_var_or_elems_access(
+ node: jinja2.nodes.Node,
+ varname: str,
+ key: Optional[str] = None,
+) -> bool:
+ if isinstance(node, jinja2.nodes.Filter):
+ return (node.node is not None
+ and _is_var_or_elems_access(node.node, varname, key))
+ if isinstance(node, jinja2.nodes.Test):
+ return _is_var_or_elems_access(node.node, varname, key)
+
+ if (isinstance(node, jinja2.nodes.Getitem)
+ and isinstance(node.arg, jinja2.nodes.Slice)):
+ return _is_var_or_elems_access(node.node, varname, key)
+
+ # yapf: disable
+ return (
+ _is_attr_access(node, varname, key) if key
+ else _is_var_access(node, varname)
+ ) # yapf: enable
+
+
+def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
+ # Global variable that is implicitly defined at the root
+ yield root, varname
+
+ # Iterative BFS
+ related_varnames = deque([varname])
+ while related_varnames:
+ related_varname = related_varnames.popleft()
+
+ for assign_ast in root.find_all(jinja2.nodes.Assign):
+ lhs = assign_ast.target
+ rhs = assign_ast.node
+
+ if _is_var_or_elems_access(rhs, related_varname):
+ assert isinstance(lhs, jinja2.nodes.Name)
+ yield assign_ast, lhs.name
+
+ # Avoid infinite looping for self-assignment
+ if lhs.name != related_varname:
+ related_varnames.append(lhs.name)
+
+
+# NOTE: The proper way to handle this is to build a CFG so that we can handle
+# the scope in which each variable is defined, but that is too complicated
+def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
+ messages_varnames = [
+ varname
+ for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
+ ]
+
+ # Search for {%- for message in messages -%} loops
+ for loop_ast in root.find_all(jinja2.nodes.For):
+ loop_iter = loop_ast.iter
+ loop_target = loop_ast.target
+
+ for varname in messages_varnames:
+ if _is_var_or_elems_access(loop_iter, varname):
+ assert isinstance(loop_target, jinja2.nodes.Name)
+ yield loop_ast, loop_target.name
+ break
+
+
+def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
+ message_varnames = [
+ varname for _, varname in _iter_nodes_assign_messages_item(root)
+ ]
+
+ # Search for {%- for content in message['content'] -%} loops
+ for loop_ast in root.find_all(jinja2.nodes.For):
+ loop_iter = loop_ast.iter
+ loop_target = loop_ast.target
+
+ for varname in message_varnames:
+ if _is_var_or_elems_access(loop_iter, varname, "content"):
+ assert isinstance(loop_target, jinja2.nodes.Name)
+ yield loop_ast, loop_target.name
+ break
+
+
+def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
+ try:
+ jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
+ return jinja_compiled.environment.parse(chat_template)
+ except Exception:
+ logger.exception("Error when compiling Jinja template")
+ return None
+
+
+def _detect_content_format(
+ chat_template: str,
+ *,
+ default: _ChatTemplateContentFormat,
+) -> _ChatTemplateContentFormat:
+ jinja_ast = _try_extract_ast(chat_template)
+ if jinja_ast is None:
+ return default
+
+ try:
+ next(_iter_nodes_assign_content_item(jinja_ast))
+ except StopIteration:
+ return "string"
+ except Exception:
+ logger.exception("Error when parsing AST of Jinja template")
+ return default
+ else:
+ return "openai"
+
+
+def _resolve_chat_template_content_format(
+ chat_template: Optional[str],
+ given_format: ChatTemplateContentFormatOption,
+ tokenizer: AnyTokenizer,
+) -> _ChatTemplateContentFormat:
+ if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
+ tokenizer_chat_template = tokenizer.chat_template
+ else:
+ tokenizer_chat_template = None
+
+ jinja_text: Optional[str]
+ if isinstance(tokenizer_chat_template, str) and chat_template is None:
+ jinja_text = tokenizer_chat_template
+ elif (isinstance(tokenizer_chat_template, dict)
+ and chat_template in tokenizer_chat_template):
+ jinja_text = tokenizer_chat_template[chat_template]
+ else:
+ jinja_text = load_chat_template(chat_template, is_literal=True)
+
+ detected_format = ("string" if jinja_text is None else
+ _detect_content_format(jinja_text, default="string"))
+
+ return detected_format if given_format == "auto" else given_format
+
+
+@lru_cache
+def resolve_chat_template_content_format(
+ chat_template: Optional[str],
+ given_format: ChatTemplateContentFormatOption,
+ tokenizer: AnyTokenizer,
+) -> _ChatTemplateContentFormat:
+ detected_format = _resolve_chat_template_content_format(
+ chat_template,
+ given_format,
+ tokenizer,
+ )
+
+ logger.info(
+ "Detected the chat template content format to be '%s'. "
+ "You can set `--chat-template-content-format` to override this.",
+ detected_format,
+ )
+
+ if given_format != "auto" and given_format != detected_format:
+ logger.warning(
+ "You specified `--chat-template-content-format %s` "
+ "which is different from the detected format '%s'. "
+ "If our automatic detection is incorrect, please consider "
+ "opening a GitHub issue so that we can improve it: "
+ "https://github.com/vllm-project/vllm/issues/new/choose",
+ given_format,
+ detected_format,
+ )
+
+ return detected_format
+
+
+ModalityStr = Literal["image", "audio", "video"]
+_T = TypeVar("_T")
+
+
+class BaseMultiModalItemTracker(ABC, Generic[_T]):
+ """
+ Tracks multi-modal items in a given request and ensures that the number
+ of multi-modal items in a given request does not exceed the configured
+ maximum per prompt.
+ """
+
+ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
+ super().__init__()
+
+ self._model_config = model_config
+ self._tokenizer = tokenizer
+ self._allowed_items = (model_config.multimodal_config.limit_per_prompt
+ if model_config.multimodal_config else {})
+
+ self._items_by_modality = defaultdict[str, list[_T]](list)
+
+ @property
+ def model_config(self) -> ModelConfig:
+ return self._model_config
+
+ @property
+ def allowed_local_media_path(self):
+ return self._model_config.allowed_local_media_path
+
+ @staticmethod
+ @cache
+ def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
+ return tokenizer.decode(token_index)
+
+ def _placeholder_str(self, modality: ModalityStr,
+ current_count: int) -> Optional[str]:
+ # TODO: Let user specify how to insert image tokens into prompt
+ # (similar to chat template)
+ hf_config = self._model_config.hf_config
+ model_type = hf_config.model_type
+
+ if modality == "image":
+ if model_type == "phi3_v":
+ # Workaround since this token is not defined in the tokenizer
+ return f"<|image_{current_count}|>"
+ if model_type in ("minicpmo", "minicpmv"):
+ return "(./)"
+ if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
+ "pixtral"):
+ # These models do not use image tokens in the prompt
+ return None
+ if model_type == "qwen":
+ return f"Picture {current_count}:
"
+ if model_type.startswith("llava"):
+ return self._cached_token_str(self._tokenizer,
+ hf_config.image_token_index)
+ if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
+ "NVLM_D", "h2ovl_chat"):
+ return ""
+ if model_type == "mllama":
+ return "<|image|>"
+ if model_type in ("qwen2_vl", "qwen2_5_vl"):
+ return "<|vision_start|><|image_pad|><|vision_end|>"
+ if model_type == "molmo":
+ return ""
+ if model_type == "idefics3":
+ return ""
+ if model_type == "aria":
+ return "<|fim_prefix|><|img|><|fim_suffix|>"
+
+ raise TypeError(f"Unknown {modality} model type: {model_type}")
+ elif modality == "audio":
+ if model_type == "ultravox":
+ return "<|audio|>"
+ if model_type == "qwen2_audio":
+ return (f"Audio {current_count}: "
+ f"<|audio_bos|><|AUDIO|><|audio_eos|>")
+ if model_type == "minicpmo":
+ return "()"
+ raise TypeError(f"Unknown model type: {model_type}")
+ elif modality == "video":
+ if model_type in ("qwen2_vl", "qwen2_5_vl"):
+ return "<|vision_start|><|video_pad|><|vision_end|>"
+ if model_type in ("minicpmo", "minicpmv"):
+ return "()"
+ if model_type.startswith("llava"):
+ return self._cached_token_str(self._tokenizer,
+ hf_config.video_token_index)
+ raise TypeError(f"Unknown {modality} model type: {model_type}")
+ else:
+ raise TypeError(f"Unknown modality: {modality}")
+
+ def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
+ """
+ Add a multi-modal item to the current prompt and returns the
+ placeholder string to use, if any.
+ """
+ allowed_count = self._allowed_items.get(modality, 1)
+ current_count = len(self._items_by_modality[modality]) + 1
+ if current_count > allowed_count:
+ raise ValueError(
+ f"At most {allowed_count} {modality}(s) may be provided in "
+ "one request.")
+
+ self._items_by_modality[modality].append(item)
+
+ return self._placeholder_str(modality, current_count)
+
+ @abstractmethod
+ def create_parser(self) -> "BaseMultiModalContentParser":
+ raise NotImplementedError
+
+
+class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
+
+ def all_mm_data(self) -> Optional[MultiModalDataDict]:
+ if self._items_by_modality:
+ return dict(self._items_by_modality)
+
+ return None
+
+ def create_parser(self) -> "BaseMultiModalContentParser":
+ return MultiModalContentParser(self)
+
+
+class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
+
+ async def all_mm_data(self) -> Optional[MultiModalDataDict]:
+ if self._items_by_modality:
+ return {
+ modality: await asyncio.gather(*items)
+ for modality, items in self._items_by_modality.items()
+ }
+
+ return None
+
+ def create_parser(self) -> "BaseMultiModalContentParser":
+ return AsyncMultiModalContentParser(self)
+
+
+class BaseMultiModalContentParser(ABC):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ # multimodal placeholder_string : count
+ self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
+
+ def _add_placeholder(self, placeholder: Optional[str]):
+ if placeholder:
+ self._placeholder_counts[placeholder] += 1
+
+ def mm_placeholder_counts(self) -> Dict[str, int]:
+ return dict(self._placeholder_counts)
+
+ @abstractmethod
+ def parse_image(self, image_url: str) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def parse_audio(self, audio_url: str) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def parse_input_audio(self, input_audio: InputAudio) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def parse_video(self, video_url: str) -> None:
+ raise NotImplementedError
+
+
+class MultiModalContentParser(BaseMultiModalContentParser):
+
+ def __init__(self, tracker: MultiModalItemTracker) -> None:
+ super().__init__()
+
+ self._tracker = tracker
+
+ self._connector = MediaConnector(
+ allowed_local_media_path=tracker.allowed_local_media_path,
+ )
+
+ def parse_image(self, image_url: str) -> None:
+ image = self._connector.fetch_image(image_url)
+
+ placeholder = self._tracker.add("image", image)
+ self._add_placeholder(placeholder)
+
+ def parse_audio(self, audio_url: str) -> None:
+ audio = self._connector.fetch_audio(audio_url)
+
+ placeholder = self._tracker.add("audio", audio)
+ self._add_placeholder(placeholder)
+
+ def parse_input_audio(self, input_audio: InputAudio) -> None:
+ audio_data = input_audio.get("data", "")
+ audio_format = input_audio.get("format", "")
+ audio_url = f"data:audio/{audio_format};base64,{audio_data}"
+
+ return self.parse_audio(audio_url)
+
+ def parse_video(self, video_url: str) -> None:
+ video = self._connector.fetch_video(video_url)
+
+ placeholder = self._tracker.add("video", video)
+ self._add_placeholder(placeholder)
+
+
+class AsyncMultiModalContentParser(BaseMultiModalContentParser):
+
+ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
+ super().__init__()
+
+ self._tracker = tracker
+ self._connector = MediaConnector(
+ allowed_local_media_path=tracker.allowed_local_media_path,
+ )
+
+ def parse_image(self, image_url: str) -> None:
+ image_coro = self._connector.fetch_image_async(image_url)
+
+ placeholder = self._tracker.add("image", image_coro)
+ self._add_placeholder(placeholder)
+
+ def parse_audio(self, audio_url: str) -> None:
+ audio_coro = self._connector.fetch_audio_async(audio_url)
+
+ placeholder = self._tracker.add("audio", audio_coro)
+ self._add_placeholder(placeholder)
+
+ def parse_input_audio(self, input_audio: InputAudio) -> None:
+ audio_data = input_audio.get("data", "")
+ audio_format = input_audio.get("format", "")
+ audio_url = f"data:audio/{audio_format};base64,{audio_data}"
+
+ return self.parse_audio(audio_url)
+
+ def parse_video(self, video_url: str) -> None:
+ video = self._connector.fetch_video_async(video_url)
+
+ placeholder = self._tracker.add("video", video)
+ self._add_placeholder(placeholder)
+
+
+def validate_chat_template(chat_template: Optional[Union[Path, str]]):
+ """Raises if the provided chat template appears invalid."""
+ if chat_template is None:
+ return
+
+ elif isinstance(chat_template, Path) and not chat_template.exists():
+ raise FileNotFoundError(
+ "the supplied chat template path doesn't exist")
+
+ elif isinstance(chat_template, str):
+ JINJA_CHARS = "{}\n"
+ if not any(c in chat_template
+ for c in JINJA_CHARS) and not Path(chat_template).exists():
+ raise ValueError(
+ f"The supplied chat template string ({chat_template}) "
+ f"appears path-like, but doesn't exist!")
+
+ else:
+ raise TypeError(
+ f"{type(chat_template)} is not a valid chat template type")
+
+
+def load_chat_template(
+ chat_template: Optional[Union[Path, str]],
+ *,
+ is_literal: bool = False,
+) -> Optional[str]:
+ if chat_template is None:
+ return None
+
+ if is_literal:
+ if isinstance(chat_template, Path):
+ raise TypeError("chat_template is expected to be read directly "
+ "from its value")
+
+ return codecs.decode(chat_template, "unicode_escape")
+
+ try:
+ with open(chat_template) as f:
+ return f.read()
+ except OSError as e:
+ if isinstance(chat_template, Path):
+ raise
+
+ JINJA_CHARS = "{}\n"
+ if not any(c in chat_template for c in JINJA_CHARS):
+ msg = (f"The supplied chat template ({chat_template}) "
+ f"looks like a file path, but it failed to be "
+ f"opened. Reason: {e}")
+ raise ValueError(msg) from e
+
+ # If opening a file fails, set chat template to be args to
+ # ensure we decode so our escape are interpreted correctly
+ return load_chat_template(chat_template, is_literal=True)
+
+
+# TODO: Let user specify how to insert multimodal tokens into prompt
+# (similar to chat template)
+def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
+ text_prompt: str) -> str:
+ """Combine multimodal prompts for a multimodal language model."""
+
+ # Look through the text prompt to check for missing placeholders
+ missing_placeholders: List[str] = []
+ for placeholder in placeholder_counts:
+
+ # For any existing placeholder in the text prompt, we leave it as is
+ placeholder_counts[placeholder] -= text_prompt.count(placeholder)
+
+ if placeholder_counts[placeholder] < 0:
+ raise ValueError(
+ f"Found more '{placeholder}' placeholders in input prompt than "
+ "actual multimodal data items.")
+
+ missing_placeholders.extend([placeholder] *
+ placeholder_counts[placeholder])
+
+ # NOTE: For now we always add missing placeholders at the front of
+ # the prompt. This may change to be customizable in the future.
+ return "\n".join(missing_placeholders + [text_prompt])
+
+
+# No need to validate using Pydantic again
+_TextParser = partial(cast, ChatCompletionContentPartTextParam)
+_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
+_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
+_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
+_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
+_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
+
+_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
+
+# Define a mapping from part types to their corresponding parsing functions.
+MM_PARSER_MAP: Dict[
+ str,
+ Callable[[ChatCompletionContentPartParam], _ContentPart],
+] = {
+ "text":
+ lambda part: _TextParser(part).get("text", ""),
+ "image_url":
+ lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
+ "audio_url":
+ lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
+ "input_audio":
+ lambda part: _InputAudioParser(part).get("input_audio", {}),
+ "refusal":
+ lambda part: _RefusalParser(part).get("refusal", ""),
+ "video_url":
+ lambda part: _VideoParser(part).get("video_url", {}).get("url", ""),
+}
+
+
+def _parse_chat_message_content_mm_part(
+ part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
+ """
+ Parses a given multi-modal content part based on its type.
+
+ Args:
+ part: A dict containing the content part, with a potential 'type' field.
+
+ Returns:
+ A tuple (part_type, content) where:
+ - part_type: Type of the part (e.g., 'text', 'image_url').
+ - content: Parsed content (e.g., text, image URL).
+
+ Raises:
+ ValueError: If the 'type' field is missing and no direct URL is found.
+ """
+ assert isinstance(
+ part, dict) # This is needed to avoid mypy errors: part.get() from str
+ part_type = part.get("type", None)
+
+ if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
+ content = MM_PARSER_MAP[part_type](part)
+
+ # Special case for 'image_url.detail'
+ # We only support 'auto', which is the default
+ if part_type == "image_url" and part.get("detail", "auto") != "auto":
+ logger.warning("'image_url.detail' is currently not supported "
+ "and will be ignored.")
+
+ return part_type, content
+
+ # Handle missing 'type' but provided direct URL fields.
+ # 'type' is required field by pydantic
+ if part_type is None:
+ if part.get("image_url") is not None:
+ image_params = cast(CustomChatCompletionContentSimpleImageParam,
+ part)
+ return "image_url", image_params.get("image_url", "")
+ if part.get("audio_url") is not None:
+ audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
+ part)
+ return "audio_url", audio_params.get("audio_url", "")
+ if part.get("input_audio") is not None:
+ input_audio_params = cast(Dict[str, str], part)
+ return "input_audio", input_audio_params
+ if part.get("video_url") is not None:
+ video_params = cast(CustomChatCompletionContentSimpleVideoParam,
+ part)
+ return "video_url", video_params.get("video_url", "")
+ # Raise an error if no 'type' or direct URL is found.
+ raise ValueError("Missing 'type' field in multimodal part.")
+
+ if not isinstance(part_type, str):
+ raise ValueError("Invalid 'type' field in multimodal part.")
+ return part_type, "unknown part_type content"
+
+
+VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
+ "audio_url", "input_audio", "video_url")
+
+
+def _parse_chat_message_content_parts(
+ role: str,
+ parts: Iterable[ChatCompletionContentPartParam],
+ mm_tracker: BaseMultiModalItemTracker,
+ *,
+ wrap_dicts: bool,
+) -> List[ConversationMessage]:
+ content = list[_ContentPart]()
+
+ mm_parser = mm_tracker.create_parser()
+
+ for part in parts:
+ parse_res = _parse_chat_message_content_part(
+ part,
+ mm_parser,
+ wrap_dicts=wrap_dicts,
+ )
+ if parse_res:
+ content.append(parse_res)
+
+ if wrap_dicts:
+ # Parsing wraps images and texts as interleaved dictionaries
+ return [ConversationMessage(role=role,
+ content=content)] # type: ignore
+ texts = cast(List[str], content)
+ text_prompt = "\n".join(texts)
+ mm_placeholder_counts = mm_parser.mm_placeholder_counts()
+ if mm_placeholder_counts:
+ text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
+ text_prompt)
+ return [ConversationMessage(role=role, content=text_prompt)]
+
+
+def _parse_chat_message_content_part(
+ part: ChatCompletionContentPartParam,
+ mm_parser: BaseMultiModalContentParser,
+ *,
+ wrap_dicts: bool,
+) -> Optional[_ContentPart]:
+ """Parses a single part of a conversation. If wrap_dicts is True,
+ structured dictionary pieces for texts and images will be
+ wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
+ {"type": "image"}, respectively. Otherwise multimodal data will be
+ handled by mm_parser, and texts will be returned as strings to be joined
+ with multimodal placeholders.
+ """
+ if isinstance(part, str): # Handle plain text parts
+ return part
+
+ # Handle structured dictionary parts
+ part_type, content = _parse_chat_message_content_mm_part(part)
+
+ # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
+ # content is empty, log a warning and skip
+ if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
+ logger.warning(
+ "Skipping multimodal part (type: '%s')"
+ "with empty / unparsable content.", part_type)
+ return None
+
+ if part_type in ("text", "refusal"):
+ str_content = cast(str, content)
+ if wrap_dicts:
+ return {'type': 'text', 'text': str_content}
+ else:
+ return str_content
+
+ if part_type == "image_url":
+ str_content = cast(str, content)
+ mm_parser.parse_image(str_content)
+ return {'type': 'image'} if wrap_dicts else None
+
+ if part_type == "audio_url":
+ str_content = cast(str, content)
+ mm_parser.parse_audio(str_content)
+ return {'type': 'audio'} if wrap_dicts else None
+
+ if part_type == "input_audio":
+ dict_content = cast(InputAudio, content)
+ mm_parser.parse_input_audio(dict_content)
+ return {'type': 'audio'} if wrap_dicts else None
+
+ if part_type == "video_url":
+ str_content = cast(str, content)
+ mm_parser.parse_video(str_content)
+ return {'type': 'video'} if wrap_dicts else None
+
+ raise NotImplementedError(f"Unknown part type: {part_type}")
+
+
# No need to validate using Pydantic again
# (typing.cast is a runtime no-op: these partials only narrow the static
# type of a message dict to the role-specific TypedDict).
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
+
+
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
) -> List[ConversationMessage]:
    """Parse one chat message into conversation messages, normalizing a
    missing or plain-string content field into a list of content parts and
    carrying over role-specific fields (tool_calls, tool_call_id, name)."""
    role = message["role"]
    raw_content = message.get("content")

    # Normalize content into a list of parts.
    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [
            ChatCompletionContentPartTextParam(type="text", text=raw_content)
        ]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
    )

    for result_msg in result:
        if role == "assistant":
            parsed_msg = _AssistantParser(message)
            if "tool_calls" in parsed_msg:
                result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
        elif role == "tool":
            parsed_msg = _ToolParser(message)
            if "tool_call_id" in parsed_msg:
                result_msg["tool_call_id"] = parsed_msg["tool_call_id"]

        name = message.get("name")
        if isinstance(name, str):
            result_msg["name"] = name

    return result
+
+
+def _postprocess_messages(messages: List[ConversationMessage]) -> None:
+ # per the Transformers docs & maintainers, tool call arguments in
+ # assistant-role messages with tool_calls need to be dicts not JSON str -
+ # this is how tool-use chat templates will expect them moving forwards
+ # so, for messages that have tool_calls, parse the string (which we get
+ # from openAI format) to dict
+ for message in messages:
+ if (message["role"] == "assistant" and "tool_calls" in message
+ and isinstance(message["tool_calls"], list)):
+
+ for item in message["tool_calls"]:
+ item["function"]["arguments"] = json.loads(
+ item["function"]["arguments"])
+
+
def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
    """Parse chat messages into a flat conversation plus the multimodal
    data they reference, resolved eagerly (synchronously)."""
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    conversation: List[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(msg, mm_tracker, content_format))

    # Tool-call arguments arrive as JSON strings; templates want dicts.
    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()
+
+
def parse_chat_messages_futures(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
    """Async variant of :func:`parse_chat_messages`: the conversation is
    returned immediately, while the multimodal data is an awaitable."""
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

    conversation: List[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(msg, mm_tracker, content_format))

    # Tool-call arguments arrive as JSON strings; templates want dicts.
    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()
+
+
def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: List[ConversationMessage],
    chat_template: Optional[str],
    *,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Render *conversation* through the tokenizer's HF chat template.

    Raises:
        ValueError: if no explicit ``chat_template`` was given and the
            tokenizer does not define one either.
    """
    has_template = (chat_template is not None
                    or tokenizer.chat_template is not None)
    if not has_template:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one.")

    return tokenizer.apply_chat_template(
        conversation=conversation,  # type: ignore[arg-type]
        chat_template=chat_template,
        tokenize=tokenize,
        **kwargs,
    )
+
+
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: List[ChatCompletionMessageParam],
    chat_template: Optional[str] = None,
    **kwargs: Any,
) -> List[int]:
    """Tokenize *messages* using the Mistral tokenizer's built-in template.

    The Mistral tokenizer carries its own prompting rules, so the HF-style
    knobs below are not honored and only trigger a one-time warning.
    """
    if chat_template is not None:
        logger.warning_once(
            "'chat_template' cannot be overridden for mistral tokenizer.")

    for opt in ("add_generation_prompt", "continue_final_message"):
        if opt in kwargs:
            logger.warning_once(
                f"'{opt}' is not supported for mistral tokenizer, "
                "so it will be ignored.")

    return tokenizer.apply_chat_template(
        messages=messages,
        **kwargs,
    )
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/launcher.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/launcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..351a39525fa621870b7d60192478abcd6a5746f2
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/launcher.py
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import signal
+from http import HTTPStatus
+from typing import Any
+
+import uvicorn
+from fastapi import FastAPI, Request, Response
+
+from vllm import envs
+from vllm.engine.async_llm_engine import AsyncEngineDeadError
+from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.logger import init_logger
+from vllm.utils import find_process_using_port
+
+logger = init_logger(__name__)
+
+
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
    """Run *app* under uvicorn until the server exits or is cancelled.

    Returns an awaitable that the caller should await to complete (or skip)
    server shutdown: a real ``server.shutdown()`` after cancellation, or a
    no-op coroutine when the server returned on its own.
    """
    logger.info("Available routes are:")
    for route in app.routes:
        methods = getattr(route, "methods", None)
        path = getattr(route, "path", None)

        # Some route types do not expose methods/path; skip those.
        if methods is None or path is None:
            continue

        logger.info("Route: %s, Methods: %s", path, ', '.join(methods))

    config = uvicorn.Config(app, **uvicorn_kwargs)
    server = uvicorn.Server(config)
    _add_shutdown_handlers(app, server)

    loop = asyncio.get_running_loop()

    server_task = loop.create_task(server.serve())

    def signal_handler() -> None:
        # prevents the uvicorn signal handler to exit early
        server_task.cancel()

    async def dummy_shutdown() -> None:
        pass

    # Install our own SIGINT/SIGTERM handlers so the signal cancels the
    # serve task (handled below) instead of uvicorn exiting on its own.
    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)

    try:
        await server_task
        # Server finished normally: nothing left to shut down.
        return dummy_shutdown()
    except asyncio.CancelledError:
        # NOTE(review): assumes "port" is always present in uvicorn_kwargs;
        # a missing key would raise KeyError here — confirm all callers.
        port = uvicorn_kwargs["port"]
        process = find_process_using_port(port)
        if process is not None:
            logger.debug(
                "port %s is used by process %s launched with command:\n%s",
                port, process, " ".join(process.cmdline()))
        logger.info("Shutting down FastAPI HTTP server.")
        return server.shutdown()
+
+
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    """Adds handlers for fatal errors that should crash the server.

    Each handler returns a 500 for the triggering request and, unless
    VLLM_KEEP_ALIVE_ON_ENGINE_DEATH is set, asks uvicorn to exit. We cannot
    await the shutdown inside a handler because the handler must first
    return to close the connection for the current request (see
    https://github.com/encode/uvicorn/discussions/1103).
    """

    async def runtime_error_handler(request: Request, __):
        """On generic runtime error, check to see if the engine has died.
        It probably has, in which case the server will no longer be able to
        handle requests. Trigger a graceful shutdown with a SIGTERM."""
        engine = request.app.state.engine_client
        if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
                and not engine.is_running):
            logger.fatal("AsyncLLMEngine has failed, terminating server "
                         "process")
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    def make_dead_engine_handler(engine_name: str):
        """Build a handler that kills the server because the named engine
        is already dead and will not handle any further requests."""

        async def dead_engine_handler(_, __):
            if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
                logger.fatal("%s is already dead, terminating server "
                             "process", engine_name)
                server.should_exit = True

            return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

        return dead_engine_handler

    # Non-decorator registration of the three fatal-error handlers.
    app.exception_handler(RuntimeError)(runtime_error_handler)
    app.exception_handler(AsyncEngineDeadError)(
        make_dead_engine_handler("AsyncLLMEngine"))
    app.exception_handler(MQEngineDeadError)(
        make_dead_engine_handler("MQLLMEngine"))
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/llm.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d071a0b3cfc5d313bc0ef0055d861eca461acb6a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/llm.py
@@ -0,0 +1,1414 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import itertools
+import warnings
+from contextlib import contextmanager
+from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
+ Tuple, Type, Union, cast, overload)
+
+import cloudpickle
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+from typing_extensions import TypeVar, deprecated
+
+from vllm import envs
+from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
+ BeamSearchSequence, get_beam_search_score)
+from vllm.config import CompilationConfig
+from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
+ TaskOption)
+from vllm.engine.llm_engine import LLMEngine
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+ ChatTemplateContentFormatOption,
+ apply_hf_chat_template,
+ apply_mistral_chat_template,
+ parse_chat_messages,
+ resolve_chat_template_content_format)
+from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
+from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.model_executor.guided_decoding.guided_fields import (
+ GuidedDecodingRequest, LLMGuidedOptions)
+from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
+ PoolingRequestOutput, RequestOutput,
+ ScoringRequestOutput)
+from vllm.pooling_params import PoolingParams
+from vllm.prompt_adapter.request import PromptAdapterRequest
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+ RequestOutputKind, SamplingParams)
+from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
+ get_cached_tokenizer)
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
+
+logger = init_logger(__name__)
+
+_R = TypeVar("_R", default=Any)
+
+
+class LLM:
+ """An LLM for generating texts from given prompts and sampling parameters.
+
+ This class includes a tokenizer, a language model (possibly distributed
+ across multiple GPUs), and GPU memory space allocated for intermediate
+ states (aka KV cache). Given a batch of prompts and sampling parameters,
+ this class generates texts from the model, using an intelligent batching
+ mechanism and efficient memory management.
+
+ Args:
+ model: The name or path of a HuggingFace Transformers model.
+ tokenizer: The name or path of a HuggingFace Transformers tokenizer.
+ tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
+ if available, and "slow" will always use the slow tokenizer.
+ skip_tokenizer_init: If true, skip initialization of tokenizer and
+ detokenizer. Expect valid prompt_token_ids and None for prompt
+ from the input.
+ trust_remote_code: Trust remote code (e.g., from HuggingFace) when
+ downloading the model and tokenizer.
+ allowed_local_media_path: Allowing API requests to read local images
+ or videos from directories specified by the server file system.
+ This is a security risk. Should only be enabled in trusted
+ environments.
+ tensor_parallel_size: The number of GPUs to use for distributed
+ execution with tensor parallelism.
+ dtype: The data type for the model weights and activations. Currently,
+ we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
+ the `torch_dtype` attribute specified in the model config file.
+ However, if the `torch_dtype` in the config is `float32`, we will
+ use `float16` instead.
+ quantization: The method used to quantize the model weights. Currently,
+ we support "awq", "gptq", and "fp8" (experimental).
+ If None, we first check the `quantization_config` attribute in the
+ model config file. If that is None, we assume the model weights are
+ not quantized and use `dtype` to determine the data type of
+ the weights.
+ revision: The specific model version to use. It can be a branch name,
+ a tag name, or a commit id.
+ tokenizer_revision: The specific tokenizer version to use. It can be a
+ branch name, a tag name, or a commit id.
+ seed: The seed to initialize the random number generator for sampling.
+ gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
+ reserve for the model weights, activations, and KV cache. Higher
+ values will increase the KV cache size and thus improve the model's
+ throughput. However, if the value is too high, it may cause out-of-
+ memory (OOM) errors.
+ swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
+ This can be used for temporarily storing the states of the requests
+ when their `best_of` sampling parameters are larger than 1. If all
+ requests will have `best_of=1`, you can safely set this to 0.
+ Otherwise, too small values may cause out-of-memory (OOM) errors.
+ cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
+ the model weights. This virtually increases the GPU memory space
+ you can use to hold the model weights, at the cost of CPU-GPU data
+ transfer for every forward pass.
+ enforce_eager: Whether to enforce eager execution. If True, we will
+ disable CUDA graph and always execute the model in eager mode.
+ If False, we will use CUDA graph and eager execution in hybrid.
+ max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
+ When a sequence has context length larger than this, we fall back
+ to eager mode. Additionally for encoder-decoder models, if the
+ sequence length of the encoder input is larger than this, we fall
+ back to the eager mode.
+ disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
+ disable_async_output_proc: Disable async output processing.
+ This may result in lower performance.
+ hf_overrides: If a dictionary, contains arguments to be forwarded to the
+ HuggingFace config. If a callable, it is called to update the
+ HuggingFace config.
+ compilation_config: Either an integer or a dictionary. If it is an
+ integer, it is used as the level of compilation optimization. If it
+ is a dictionary, it can specify the full compilation configuration.
+ **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
+ :ref:`engine-args`)
+
+ Note:
+ This class is intended to be used for offline inference. For online
+ serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
+ """
+
+ DEPRECATE_LEGACY: ClassVar[bool] = True
+ """A flag to toggle whether to deprecate the legacy generate/encode API."""
+
+ DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
+ """
+ A flag to toggle whether to deprecate positional arguments in
+ :meth:`LLM.__init__`.
+ """
+
+ @classmethod
+ @contextmanager
+ def deprecate_legacy_api(cls):
+ cls.DEPRECATE_LEGACY = True
+
+ yield
+
+ cls.DEPRECATE_LEGACY = False
+
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        '''
        LLM constructor.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.
        '''

        # Offline use defaults to quiet stats unless the caller opts in.
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # Normalize compilation_config: ints and dicts go through the CLI
        # parser; an already-built CompilationConfig is passed through.
        if compilation_config is not None:
            if isinstance(compilation_config, (int, dict)):
                compilation_config_instance = CompilationConfig.from_cli(
                    str(compilation_config))
            else:
                compilation_config_instance = compilation_config
        else:
            compilation_config_instance = None

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )
        # Logic to switch between engines is done at runtime instead of import
        # to avoid import order issues
        self.engine_class = self.get_engine_class()
        self.llm_engine = self.engine_class.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        # Monotonic counter used to assign unique request ids.
        self.request_counter = Counter()
+
+ @staticmethod
+ def get_engine_class() -> Type[LLMEngine]:
+ if envs.VLLM_USE_V1:
+ # Lazy import: the v1 package isn't distributed
+ from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
+ return V1LLMEngine # type: ignore
+ return LLMEngine
+
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by this LLM's tokenizer group."""
        return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
+
+ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
+ tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
+
+ # While CachedTokenizer is dynamic, have no choice but
+ # compare class name. Misjudgment will arise from
+ # user-defined tokenizer started with 'Cached'
+ if tokenizer.__class__.__name__.startswith("Cached"):
+ tokenizer_group.tokenizer = tokenizer
+ else:
+ tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
+
+ def get_default_sampling_params(self) -> SamplingParams:
+ diff_sampling_param = (
+ self.llm_engine.model_config.get_diff_sampling_param())
+ if diff_sampling_param:
+ return SamplingParams.from_optional(**diff_sampling_param)
+ return SamplingParams()
+
    # NOTE: the @overload stubs below only describe the accepted call shapes
    # of generate() for type checkers; the "LEGACY" variants accept
    # 'prompt_token_ids' and are flagged via @deprecated. The implementation
    # is the single non-overload definition that follows them.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...
+
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        # Reject non-generative models early, with an actionable hint when
        # the model could be re-initialized with `--task generate`.
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        # LEGACY path: fold (prompts, prompt_token_ids) into PromptType form.
        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        # A plain dict is interpreted as LLMGuidedOptions; exactly one
        # guided-decoding option may be specified at a time.
        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        # Drain the engine; outputs come back in submission order.
        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
+
+ def collective_rpc(self,
+ method: Union[str, Callable[..., _R]],
+ timeout: Optional[float] = None,
+ args: Tuple = (),
+ kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
+ """
+ Execute an RPC call on all workers.
+
+ Args:
+ method: Name of the worker method to execute, or a callable that
+ is serialized and sent to all workers to execute.
+
+ If the method is a callable, it should accept an additional
+ `self` argument, in addition to the arguments passed in `args`
+ and `kwargs`. The `self` argument will be the worker object.
+ timeout: Maximum time in seconds to wait for execution. Raises a
+ :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
+ args: Positional arguments to pass to the worker method.
+ kwargs: Keyword arguments to pass to the worker method.
+
+ Returns:
+ A list containing the results from each worker.
+
+ Note:
+ It is recommended to use this API to only pass control messages,
+ and set up data-plane communication to pass data.
+ """
+ executor = self.llm_engine.model_executor
+ return executor.collective_rpc(method, timeout, args, kwargs)
+
+ def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
+ """
+ Run a function directly on the model inside each worker,
+ returning the result for each of them.
+ """
+ executor = self.llm_engine.model_executor
+ return executor.apply_model(func)
+
    def beam_search(
        self,
        prompts: List[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # NOTE: this closure reads `tokenizer`, which is bound a few lines
        # below — fine at call time, but keep the assignment before any use.
        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        # One BeamSearchInstance per input prompt; text prompts are encoded.
        for prompt in prompts:
            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten all live beams across instances into one batch, and
            # remember each instance's [start, end) slice of that batch.
            all_beams: List[BeamSearchSequence] = list(
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            # No live beams anywhere: every sequence has completed.
            if len(all_beams) == 0:
                break

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            # NOTE(review): the per-step `logprobs` dict (not
                            # `logprob_obj`) is appended to the beam's logprobs
                            # history — confirm this is the intended record.
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                # Keep only the top `beam_width` continuations per instance.
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Unfinished beams also count as candidates at the end.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs
+
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the :meth:`generate` method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Tool (e.g. function) definitions forwarded to the chat
                template, if the template supports them.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is List[...]
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        # Decide whether message content is rendered as a plain string or an
        # OpenAI-style list of parts ("auto" resolves from the template).
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            # Mistral tokenizers use their own template machinery and return
            # token ids; the HF path may return either text or token ids.
            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            # Wrap as a TokensPrompt or TextPrompt depending on what the
            # template application produced.
            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
+
    # Overload: modern API — prompts passed positionally, everything else
    # keyword-only.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...
+
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...
+
    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...
+
    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...
+
    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...
+
    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...
+
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``prompts`` parameter.
        """
        # Fail fast with a helpful hint when the model was initialized for a
        # non-pooling runner (e.g. text generation).
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        # LEGACY path: normalize (prompts, prompt_token_ids) pairs into the
        # modern PromptType representation.
        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
+
+ def embed(
+ self,
+ prompts: Union[PromptType, Sequence[PromptType]],
+ /,
+ *,
+ use_tqdm: bool = True,
+ lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+ ) -> List[EmbeddingRequestOutput]:
+ """
+ Generate an embedding vector for each prompt.
+
+ This class automatically batches the given prompts, considering
+ the memory constraint. For the best performance, put all of your prompts
+ into a single list and pass it to this method.
+
+ Args:
+ prompts: The prompts to the LLM. You may pass a sequence of prompts
+ for batch inference. See :class:`~vllm.inputs.PromptType`
+ for more details about the format of each prompts.
+ use_tqdm: Whether to use tqdm to display the progress bar.
+ lora_request: LoRA request to use for generation, if any.
+ prompt_adapter_request: Prompt Adapter request to use for
+ generation, if any.
+
+ Returns:
+ A list of ``EmbeddingRequestOutput`` objects containing the
+ embedding vectors in the same order as the input prompts.
+ """
+ if self.llm_engine.model_config.task != "embed":
+ raise ValueError(
+ "Embedding API is only enabled for `--task embed`")
+
+ items = self.encode(prompts,
+ use_tqdm=use_tqdm,
+ lora_request=lora_request,
+ prompt_adapter_request=prompt_adapter_request)
+
+ return [EmbeddingRequestOutput.from_base(item) for item in items]
+
+ def classify(
+ self,
+ prompts: Union[PromptType, Sequence[PromptType]],
+ /,
+ *,
+ use_tqdm: bool = True,
+ lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+ ) -> List[ClassificationRequestOutput]:
+ """
+ Generate class logits for each prompt.
+
+ This class automatically batches the given prompts, considering
+ the memory constraint. For the best performance, put all of your prompts
+ into a single list and pass it to this method.
+
+ Args:
+ prompts: The prompts to the LLM. You may pass a sequence of prompts
+ for batch inference. See :class:`~vllm.inputs.PromptType`
+ for more details about the format of each prompts.
+ use_tqdm: Whether to use tqdm to display the progress bar.
+ lora_request: LoRA request to use for generation, if any.
+ prompt_adapter_request: Prompt Adapter request to use for
+ generation, if any.
+
+ Returns:
+ A list of ``ClassificationRequestOutput`` objects containing the
+ embedding vectors in the same order as the input prompts.
+ """
+ if self.llm_engine.model_config.task != "classify":
+ raise ValueError(
+ "Classification API is only enabled for `--task classify`")
+
+ items = self.encode(prompts,
+ use_tqdm=use_tqdm,
+ lora_request=lora_request,
+ prompt_adapter_request=prompt_adapter_request)
+
+ return [ClassificationRequestOutput.from_base(item) for item in items]
+
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: List[Union[str, TextPrompt, TokensPrompt]],
        text_2: List[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:
        """Score (text_1[i], text_2[i]) pairs with an embedding model.

        Both input lists are embedded in one batched :meth:`encode` call and
        each pair is scored by cosine similarity of the pooled embeddings.

        NOTE(review): ``truncate_prompt_tokens`` is accepted but not used on
        this path — confirm whether truncation should apply here as it does
        in :meth:`_cross_encoding_score`.
        """

        # One batched encode for both sides, then split the results back.
        encoded_output = self.encode(
            text_1 + text_2,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)
        encoded_output_1 = encoded_output[0:len(text_1)]
        encoded_output_2 = encoded_output[len(text_1):]

        # 1 -> N case: replicate the single text_1 embedding for every text_2.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        output_pairs = [(t1, t2)
                        for t1, t2 in zip(encoded_output_1, encoded_output_2)]

        scores = []
        # Cosine similarity along dim 0 of the pooled embedding vectors.
        scorer = torch.nn.CosineSimilarity(0)

        for embed_1, embed_2 in output_pairs:
            pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data)

            # Join the two prompts' token ids, separated by the pad token when
            # the tokenizer defines one.
            if (pad_token_id := getattr(tokenizer, "pad_token_id",
                                        None)) is not None:
                tokens = embed_1.prompt_token_ids + [
                    pad_token_id
                ] + embed_2.prompt_token_ids
            else:
                tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids

            scores.append(
                PoolingRequestOutput(
                    request_id=f"{embed_1.request_id}_{embed_2.request_id}",
                    outputs=pair_score,
                    prompt_token_ids=tokens,
                    finished=True))

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]
+
    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: List[Union[str, TextPrompt, TokensPrompt]],
        text_2: List[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:
        """Score (text_1[i], text_2[i]) pairs with a cross-encoder model.

        Each pair is jointly tokenized (``text``/``text_pair``) into a single
        prompt, pooled by the engine, and returned as a score.

        Raises:
            ValueError: if ``tokenizer`` is a MistralTokenizer, which does not
                support the ``text``/``text_pair`` call signature used here.
        """

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is only enabled for `--task embed or score`")

        # 1 -> N case: replicate the single text_1 entry for every text_2.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]

        pooling_params = PoolingParams()

        # Optional truncation of each jointly-tokenized pair.
        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []

        for q, t in input_pairs:
            # Joint tokenization so the model sees both segments (and their
            # token_type_ids, when the tokenizer produces them).
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]
+
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``(text_1, text_2)``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: If set, truncate each tokenized pair to
                at most this many tokens. Only applied on the cross-encoder
                path; the embedding path currently ignores it.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.
        """
        # Fail fast with a helpful hint when the model was initialized for a
        # non-pooling runner (e.g. text generation).
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt):
            # Normalize any SingletonPrompt into a plain string, decoding
            # token-id prompts and rejecting multi-modal ones.
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            assert type(prompt) is str
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        text_1 = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        text_2 = [ensure_str(t) for t in text_2]

        # Only 1:1, 1:N and N:N pairings are valid, and neither side may be
        # empty.
        if len(text_1) > 1 and len(text_1) != len(text_2):
            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
        if len(text_1) == 0:
            raise ValueError("At least one text element must be given")
        if len(text_2) == 0:
            raise ValueError("At least one text_pair element must be given")

        # Dispatch on model type: joint tokenization for cross-encoders,
        # cosine similarity of embeddings otherwise.
        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, text_1, text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(tokenizer, text_1, text_2,
                                         truncate_prompt_tokens, use_tqdm,
                                         lora_request, prompt_adapter_request)
+
    def start_profile(self) -> None:
        """Start profiling; delegates to the underlying engine."""
        self.llm_engine.start_profile()
+
    def stop_profile(self) -> None:
        """Stop profiling; delegates to the underlying engine."""
        self.llm_engine.stop_profile()
+
    def reset_prefix_cache(self) -> bool:
        """Reset the engine's prefix cache; returns the engine's result."""
        return self.llm_engine.reset_prefix_cache()
+
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        :param level: The sleep level. Level 1 sleep will offload the model
            weights and discard the kv cache. The content of kv cache is
            forgotten. Level 1 sleep is good for sleeping and waking up the
            engine to run the same model again. The model weights are backed
            up in CPU memory. Please make sure there's enough CPU memory to
            store the model weights. Level 2 sleep will discard both the model
            weights and the kv cache. The content of both the model weights
            and kv cache is forgotten. Level 2 sleep is good for sleeping and
            waking up the engine to run a different model or update the model,
            where previous model weights are not needed. It reduces CPU memory
            pressure.
        """
        # The kv cache content is forgotten while sleeping, so drop the
        # prefix-cache state first to keep it consistent.
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)
+
    def wake_up(self):
        """
        Wake up the engine from sleep mode. See the :meth:`sleep` method
        for more details."""
        # Delegates entirely to the engine.
        self.llm_engine.wake_up()
+
+ # LEGACY
+ def _convert_v1_inputs(
+ self,
+ prompts: Optional[Union[str, List[str]]],
+ prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
+ ):
+ # skip_tokenizer_init is now checked in engine
+
+ if prompts is not None:
+ prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
+ if prompt_token_ids is not None:
+ prompt_token_ids = [
+ p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
+ ]
+
+ num_requests = None
+ if prompts is not None:
+ num_requests = len(prompts)
+ if prompt_token_ids is not None:
+ if (num_requests is not None
+ and num_requests != len(prompt_token_ids)):
+ raise ValueError("The lengths of prompts and prompt_token_ids "
+ "must be the same.")
+
+ num_requests = len(prompt_token_ids)
+ if num_requests is None:
+ raise ValueError("Either prompts or prompt_token_ids must be "
+ "provided.")
+
+ parsed_prompts: List[PromptType] = []
+ for i in range(num_requests):
+ item: PromptType
+
+ if prompts is not None:
+ item = TextPrompt(prompt=prompts[i])
+ elif prompt_token_ids is not None:
+ item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
+ else:
+ raise AssertionError
+
+ parsed_prompts.append(item)
+
+ return parsed_prompts
+
    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate batched arguments and enqueue one engine request per
        prompt.

        ``params``, ``lora_request`` and ``priority`` may each be a single
        value (applied to every prompt) or a per-prompt sequence.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        # Per-prompt sequences must match the number of prompts.
        num_requests = len(prompts)
        if isinstance(params, list) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      list) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, list) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
+
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under a fresh request id."""
        # Request ids are monotonically increasing integers; _run_engine
        # restores submission order by sorting on int(request_id).
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
+
+ def _add_guided_params(
+ self,
+ params: SamplingParams,
+ guided_options: Optional[GuidedDecodingRequest] = None):
+ if guided_options is None:
+ return params
+
+ if params.guided_decoding is not None:
+ raise ValueError("Cannot set both guided_options_request and"
+ "params.guided_decoding.")
+
+ params.guided_decoding = GuidedDecodingParams(
+ json=guided_options.guided_json,
+ regex=guided_options.guided_regex,
+ choice=guided_options.guided_choice,
+ grammar=guided_options.guided_grammar,
+ json_object=guided_options.guided_json_object,
+ backend=guided_options.guided_decoding_backend,
+ whitespace_pattern=guided_options.guided_whitespace_pattern)
+ return params
+
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until all queued requests finish, collecting the
        finished outputs (optionally with a tqdm progress bar showing
        estimated input/output token throughput)."""
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            # (pooling outputs carry no generated tokens).
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids)
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            out_spd = (total_out_toks /
                                       pbar.format_dict["elapsed"])
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(1)

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/logger.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..e82b6ba6c7bae3c0496f395324b4a405dc34c435
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/logger.py
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List, Optional, Union
+
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.pooling_params import PoolingParams
+from vllm.prompt_adapter.request import PromptAdapterRequest
+from vllm.sampling_params import BeamSearchParams, SamplingParams
+
+logger = init_logger(__name__)
+
+
class RequestLogger:
    """Logs incoming request inputs at INFO level, optionally truncating
    long prompts and token-id lists."""

    def __init__(self, *, max_log_len: Optional[int]) -> None:
        super().__init__()

        # Maximum number of prompt characters / token ids to include in a
        # log line; None disables truncation.
        self.max_log_len = max_log_len

    def log_inputs(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]],
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Log one request's prompt, params, and adapter info.

        Only the local copies of ``prompt``/``prompt_token_ids`` are
        truncated; the caller's objects are not modified.
        """
        max_log_len = self.max_log_len
        if max_log_len is not None:
            if prompt is not None:
                prompt = prompt[:max_log_len]

            if prompt_token_ids is not None:
                prompt_token_ids = prompt_token_ids[:max_log_len]

        # Lazy %-style args keep formatting cost off the hot path when the
        # INFO level is disabled.
        logger.info(
            "Received request %s: prompt: %r, "
            "params: %s, prompt_token_ids: %s, "
            "lora_request: %s, prompt_adapter_request: %s.", request_id,
            prompt, params, prompt_token_ids, lora_request,
            prompt_adapter_request)
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__init__.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..330e65ab7feb7a2db9dcc3f2366ed207683ad9c7
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/api_server.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/api_server.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a7310b71eb9a9c9b0ff15bc85ec7b115c541aa57
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/api_server.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/cli_args.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/cli_args.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d6937ebc725460735b31517653dfb3e2eb47f6f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/cli_args.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/logits_processors.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/logits_processors.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..babb3d685ce86bd72d7050fa1e121b4b98d87822
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/logits_processors.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/protocol.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/protocol.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86c6e91c74ca0b2036ecafd74e24e84b464ad9ff
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/protocol.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/run_batch.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/run_batch.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..690ed28b842ebb3c176133d30fe6c68820ae134a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/run_batch.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_chat.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_chat.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3aaec581e7664e9c51a05cd41fdc879b3e8552c
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_chat.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_completion.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_completion.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39e1468d7d795a72d041e35cde4d263e9644bbe2
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_completion.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_embedding.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_embedding.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19f5b2291b1df5efe239aaecb32ff8b5be6eb592
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_embedding.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_engine.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_engine.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..752568888dc2ceb5a4efebed1b24d59a264ae57f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_engine.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_models.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_models.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c34c10a91572ae85f625681d18bd15958c5af810
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_models.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_pooling.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_pooling.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1316222cb6225825dc5fe78f8e50f22643ad396
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_pooling.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_rerank.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_rerank.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79ba6c4a759b007ccf8216c54332bf6c9269b2a7
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_rerank.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_score.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_score.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c42fd11416eda0363a2ddf6904ca5be5e45498f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_score.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_tokenization.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_tokenization.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..302849af071e2523826c59f33850bab7ba10c55f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_tokenization.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/api_server.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/api_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8f54d6c78042dd944bd5410dd3a12ded13fa877
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/api_server.py
@@ -0,0 +1,911 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import atexit
+import gc
+import importlib
+import inspect
+import multiprocessing
+import os
+import re
+import signal
+import socket
+import sys
+import tempfile
+import uuid
+from argparse import Namespace
+from contextlib import asynccontextmanager
+from functools import partial
+from http import HTTPStatus
+from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
+
+import uvloop
+from fastapi import APIRouter, FastAPI, HTTPException, Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+from starlette.datastructures import State
+from starlette.routing import Mount
+from typing_extensions import assert_never
+
+import vllm.envs as envs
+from vllm.config import ModelConfig
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore
+from vllm.engine.multiprocessing.client import MQLLMEngineClient
+from vllm.engine.multiprocessing.engine import run_mp_engine
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.cli_args import (make_arg_parser,
+ validate_parsed_serve_args)
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ ChatCompletionResponse,
+ CompletionRequest,
+ CompletionResponse,
+ DetokenizeRequest,
+ DetokenizeResponse,
+ EmbeddingChatRequest,
+ EmbeddingCompletionRequest,
+ EmbeddingRequest,
+ EmbeddingResponse,
+ EmbeddingResponseData,
+ ErrorResponse,
+ LoadLoraAdapterRequest,
+ PoolingChatRequest,
+ PoolingCompletionRequest,
+ PoolingRequest, PoolingResponse,
+ RerankRequest, RerankResponse,
+ ScoreRequest, ScoreResponse,
+ TokenizeRequest,
+ TokenizeResponse,
+ UnloadLoraAdapterRequest)
+from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
+# yapf: enable
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+ OpenAIServingModels)
+from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
+from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
+from vllm.entrypoints.openai.serving_score import OpenAIServingScores
+from vllm.entrypoints.openai.serving_tokenization import (
+ OpenAIServingTokenization)
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+from vllm.entrypoints.utils import with_cancellation
+from vllm.logger import init_logger
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
+ is_valid_ipv6_address, set_ulimit)
+from vllm.version import __version__ as VLLM_VERSION
+
+TIMEOUT_KEEP_ALIVE = 5 # seconds
+
+prometheus_multiproc_dir: tempfile.TemporaryDirectory
+
+# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
+logger = init_logger('vllm.entrypoints.openai.api_server')
+
+_running_tasks: Set[asyncio.Task] = set()
+
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: start the periodic stats-logging task and
    tune GC for the server's lifetime; tear both down on shutdown."""
    try:
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                # Emit engine stats every 10 seconds for the server lifetime.
                while True:
                    await asyncio.sleep(10.)
                    await engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            # Keep a strong reference so the task isn't garbage-collected,
            # and drop it automatically once the task completes.
            _running_tasks.add(task)
            task.add_done_callback(_running_tasks.remove)
        else:
            task = None

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        gc.collect()
        gc.freeze()
        try:
            yield
        finally:
            if task is not None:
                task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
+
+
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[EngineClient]:
    """Yield an EngineClient built from parsed CLI args.

    Thin wrapper over build_async_engine_client_from_engine_args that
    converts the argparse namespace into AsyncEngineArgs first; the inner
    context manager owns all setup/teardown.
    """
    engine_args = AsyncEngineArgs.from_cli_args(args)
    async with build_async_engine_client_from_engine_args(
            engine_args, args.disable_frontend_multiprocessing) as client:
        yield client
+
+
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[EngineClient]:
    """
    Create EngineClient, either:
    - in-process using the AsyncLLMEngine Directly
    - multiprocess using AsyncLLMEngine RPC

    Returns the Client or None if the creation failed.
    """

    # AsyncLLMEngine.
    # In-process path: used when the config is unsupported by the MQ engine,
    # when V1 is enabled, or when the user disabled frontend multiprocessing.
    if (MQLLMEngineClient.is_unsupported_config(engine_args)
            or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):

        engine_client: Optional[EngineClient] = None
        try:
            engine_client = AsyncLLMEngine.from_engine_args(
                engine_args=engine_args,
                usage_context=UsageContext.OPENAI_API_SERVER)
            yield engine_client
        finally:
            # hasattr guard: not every EngineClient implements shutdown().
            if engine_client and hasattr(engine_client, "shutdown"):
                engine_client.shutdown()

    # MQLLMEngine.
    # Multiprocess path: spawn the engine in its own process, talk over ZMQ.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            #   cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between vLLM runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and vLLM will properly handle cleanup.")

        # Select random path for IPC.
        ipc_path = get_open_zmq_ipc_path()
        logger.debug("Multiprocessing frontend to use %s for IPC Path.",
                     ipc_path)

        # Start RPCServer in separate process (holds the LLMEngine).
        # the current process might have CUDA context,
        # so we need to spawn a new process
        context = multiprocessing.get_context("spawn")

        # The Process can raise an exception during startup, which may
        # not actually result in an exitcode being reported. As a result
        # we use a shared variable to communicate the information.
        engine_alive = multiprocessing.Value('b', True, lock=False)
        engine_process = context.Process(target=run_mp_engine,
                                         args=(engine_args,
                                               UsageContext.OPENAI_API_SERVER,
                                               ipc_path, engine_alive))
        engine_process.start()
        engine_pid = engine_process.pid
        assert engine_pid is not None, "Engine process failed to start."
        logger.info("Started engine process with PID %d", engine_pid)

        def _cleanup_ipc_path():
            # Remove the Unix socket file backing the "ipc://" endpoint.
            socket_path = ipc_path.replace("ipc://", "")
            if os.path.exists(socket_path):
                os.remove(socket_path)

        # Ensure we clean up the local IPC socket file on exit.
        atexit.register(_cleanup_ipc_path)

        # Build RPCClient, which conforms to EngineClient Protocol.
        engine_config = engine_args.create_engine_config()
        # Client construction can block, so run it off the event loop.
        build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
                               engine_pid)
        mq_engine_client = await asyncio.get_running_loop().run_in_executor(
            None, build_client)
        try:
            # Retry setup() until the engine answers; bail out with a
            # RuntimeError if the engine process died in the meantime.
            while True:
                try:
                    await mq_engine_client.setup()
                    break
                except TimeoutError:
                    if (not engine_process.is_alive()
                            or not engine_alive.value):
                        raise RuntimeError(
                            "Engine process failed to start. See stack "
                            "trace for the root cause.") from None

            yield mq_engine_client  # type: ignore[misc]
        finally:
            # Ensure rpc server process was terminated
            engine_process.terminate()

            # Close all open connections to the backend
            mq_engine_client.close()

            # Wait for engine process to join
            engine_process.join(4)
            if engine_process.exitcode is None:
                # Kill if taking longer than 5 seconds to stop
                engine_process.kill()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(engine_process.pid)
+
+
+router = APIRouter()
+
+
def mount_metrics(app: FastAPI):
    """Mount the Prometheus /metrics ASGI app onto *app*.

    Uses the multiprocess collector when PROMETHEUS_MULTIPROC_DIR is set
    (engine running in a separate process), otherwise the default registry.
    """
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                     prometheus_multiproc_dir_path)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    # FIX: the original pattern "^/metrics(?P.*)$" is invalid regex syntax
    # ((?P must introduce a named group) and raises re.error at startup;
    # restore the named group the Mount's path_regex expects.
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
+
+
def base(request: Request) -> OpenAIServing:
    """Fallback handler used only for building error responses."""
    # Reuse the existing instance
    return tokenization(request)


def models(request: Request) -> OpenAIServingModels:
    """Accessor for the model registry stored on app state."""
    return request.app.state.openai_serving_models


def chat(request: Request) -> Optional[OpenAIServingChat]:
    """Chat Completions handler; None when the model is not generative."""
    return request.app.state.openai_serving_chat


def completion(request: Request) -> Optional[OpenAIServingCompletion]:
    """Completions handler; None when the model is not generative."""
    return request.app.state.openai_serving_completion


def pooling(request: Request) -> Optional[OpenAIServingPooling]:
    """Pooling handler; None when the model is not a pooling model."""
    return request.app.state.openai_serving_pooling


def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
    """Embeddings handler; None when the model task is not "embed"."""
    return request.app.state.openai_serving_embedding


def score(request: Request) -> Optional[OpenAIServingScores]:
    """Score handler; None when the model task is not "score"."""
    return request.app.state.openai_serving_scores


def rerank(request: Request) -> Optional[JinaAIServingRerank]:
    """Rerank handler; None when the model task is not "score"."""
    return request.app.state.jinaai_serving_reranking


def tokenization(request: Request) -> OpenAIServingTokenization:
    """Tokenization handler; always present (set in init_app_state)."""
    return request.app.state.openai_serving_tokenization


def engine_client(request: Request) -> EngineClient:
    """Accessor for the engine client stored on app state."""
    return request.app.state.engine_client
+
+
@router.get("/health")
async def health(raw_request: Request) -> Response:
    """Health check."""
    client = engine_client(raw_request)
    await client.check_health()
    return Response(status_code=200)
+
+
@router.api_route("/ping", methods=["GET", "POST"])
async def ping(raw_request: Request) -> Response:
    """Ping check. Endpoint required for SageMaker"""
    # Delegates straight to /health so both report identical status.
    return await health(raw_request)
+
+
@router.post("/tokenize")
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
    """Tokenize the given input and return the token ids."""
    handler = tokenization(raw_request)

    result = await handler.create_tokenize(request, raw_request)
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert_never(result)
+
+
@router.post("/detokenize")
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    """Convert token ids back into text."""
    handler = tokenization(raw_request)

    result = await handler.create_detokenize(request, raw_request)
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert_never(result)
+
+
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    """OpenAI-compatible model listing endpoint."""
    handler = models(raw_request)
    available = await handler.show_available_models()
    return JSONResponse(content=available.model_dump())
+
+
@router.get("/version")
async def show_version():
    """Report the running vLLM version."""
    return JSONResponse(content={"version": VLLM_VERSION})
+
+
@router.post("/v1/chat/completions")
@with_cancellation
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """OpenAI-compatible Chat Completions endpoint."""
    handler = chat(raw_request)
    if handler is None:
        # Model is not generative; surface a uniform error response.
        return base(raw_request).create_error_response(
            message="The model does not support Chat Completions API")

    result = await handler.create_chat_completion(request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, ChatCompletionResponse):
        return JSONResponse(content=result.model_dump())

    # Otherwise the handler returned an async generator of SSE chunks.
    return StreamingResponse(content=result, media_type="text/event-stream")
+
+
@router.post("/v1/completions")
@with_cancellation
async def create_completion(request: CompletionRequest, raw_request: Request):
    """OpenAI-compatible (legacy) Completions endpoint."""
    handler = completion(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Completions API")

    result = await handler.create_completion(request, raw_request)
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, CompletionResponse):
        return JSONResponse(content=result.model_dump())

    # Streaming request: pass the generator through as SSE.
    return StreamingResponse(content=result, media_type="text/event-stream")
+
+
@router.post("/v1/embeddings")
@with_cancellation
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """OpenAI-compatible Embeddings endpoint.

    Falls back to the Pooling handler (reshaping its response into an
    EmbeddingResponse) when no dedicated embedding handler is configured.
    """
    handler = embedding(raw_request)
    if handler is not None:
        result: Union[ErrorResponse, EmbeddingResponse] = (
            await handler.create_embedding(request, raw_request))
    else:
        fallback_handler = pooling(raw_request)
        if fallback_handler is None:
            return base(raw_request).create_error_response(
                message="The model does not support Embeddings API")

        logger.warning(
            "Embeddings API will become exclusive to embedding models "
            "in a future release. To return the hidden states directly, "
            "use the Pooling API (`/pooling`) instead.")

        res = await fallback_handler.create_pooling(request, raw_request)

        if isinstance(res, PoolingResponse):
            # Re-shape the pooling payload into the embeddings schema.
            result = EmbeddingResponse(
                id=res.id,
                object=res.object,
                created=res.created,
                model=res.model,
                data=[
                    EmbeddingResponseData(
                        index=item.index,
                        embedding=item.data,  # type: ignore
                    ) for item in res.data
                ],
                usage=res.usage,
            )
        else:
            result = res

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
+
+
@router.post("/pooling")
@with_cancellation
async def create_pooling(request: PoolingRequest, raw_request: Request):
    """Return raw pooled hidden states for the given input."""
    handler = pooling(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Pooling API")

    result = await handler.create_pooling(request, raw_request)
    if isinstance(result, PoolingResponse):
        return JSONResponse(content=result.model_dump())
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert_never(result)
+
+
@router.post("/score")
@with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request):
    """Score a pair (or pairs) of texts with a cross-encoder model."""
    handler = score(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Score API")

    result = await handler.create_score(request, raw_request)
    if isinstance(result, ScoreResponse):
        return JSONResponse(content=result.model_dump())
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert_never(result)
+
+
@router.post("/v1/score")
@with_cancellation
async def create_score_v1(request: ScoreRequest, raw_request: Request):
    """Deprecated alias for /score kept for backwards compatibility."""
    logger.warning(
        "To indicate that Score API is not part of standard OpenAI API, we "
        "have moved it to `/score`. Please update your client accordingly.")
    return await create_score(request, raw_request)
+
+
@router.post("/rerank")
@with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request):
    """Rerank documents against a query (JinaAI-compatible API)."""
    handler = rerank(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Rerank (Score) API")

    result = await handler.do_rerank(request, raw_request)
    if isinstance(result, RerankResponse):
        return JSONResponse(content=result.model_dump())
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert_never(result)
+
+
@router.post("/v1/rerank")
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
    """Deprecated alias for /rerank kept for backwards compatibility."""
    # FIX: the original adjacent string literals rendered as
    # "...update your clientaccordingly." (missing space between fragments).
    logger.warning_once(
        "To indicate that the rerank API is not part of the standard OpenAI"
        " API, we have located it at `/rerank`. Please update your client"
        " accordingly. (Note: Conforms to JinaAI rerank API)")

    return await do_rerank(request, raw_request)
+
+
@router.post("/v2/rerank")
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
    """Alias for /rerank (Cohere v2-style path)."""
    return await do_rerank(request, raw_request)
+
+
# Dispatch table for /invocations: maps the model's task to
# (request model, handler) pairs, keyed by whether the body carries
# "messages" (chat-style) or not ("default").
TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
    "generate": {
        "messages": (ChatCompletionRequest, create_chat_completion),
        "default": (CompletionRequest, create_completion),
    },
    "embed": {
        "messages": (EmbeddingChatRequest, create_embedding),
        "default": (EmbeddingCompletionRequest, create_embedding),
    },
    "score": {
        "default": (RerankRequest, do_rerank)
    },
    "rerank": {
        "default": (RerankRequest, do_rerank)
    },
    "reward": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
    "classify": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
}
+
# Development-only endpoints: registered only when VLLM_SERVER_DEV_MODE is
# enabled at import time.
if envs.VLLM_SERVER_DEV_MODE:

    @router.post("/reset_prefix_cache")
    async def reset_prefix_cache(raw_request: Request):
        """
        Reset the prefix cache. Note that we currently do not check if the
        prefix cache is successfully reset in the API server.
        """
        logger.info("Resetting prefix cache...")
        await engine_client(raw_request).reset_prefix_cache()
        return Response(status_code=200)
+
+
@router.post("/invocations")
async def invocations(raw_request: Request):
    """
    For SageMaker, routes requests to other handlers based on model `task`.
    """
    body = await raw_request.json()
    task = raw_request.app.state.task

    try:
        handler_config = TASK_HANDLERS[task]
    except KeyError:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported task: '{task}' for '/invocations'. "
            f"Expected one of {set(TASK_HANDLERS.keys())}") from None

    # Chat-style payloads carry "messages"; everything else uses "default".
    variant = "messages" if "messages" in body else "default"
    request_model, handler = handler_config[variant]

    # this is required since we lose the FastAPI automatic casting
    request = request_model.model_validate(body)
    return await handler(request, raw_request)
+
+
# Torch profiler endpoints: registered only when VLLM_TORCH_PROFILER_DIR is
# set at import time. Intended for local debugging only.
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile(raw_request: Request):
        """Start torch profiling on the engine."""
        logger.info("Starting profiler...")
        await engine_client(raw_request).start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile(raw_request: Request):
        """Stop torch profiling and flush traces."""
        logger.info("Stopping profiler...")
        await engine_client(raw_request).stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)
+
+
# Runtime LoRA management endpoints: registered only when
# VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled at import time.
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest,
                                raw_request: Request):
        """Register a LoRA adapter with the serving models registry."""
        handler = models(raw_request)
        response = await handler.load_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                                  raw_request: Request):
        """Remove a previously loaded LoRA adapter."""
        handler = models(raw_request)
        response = await handler.unload_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        return Response(status_code=200, content=response)
+
+
def build_app(args: Namespace) -> FastAPI:
    """Construct the FastAPI application: routes, /metrics mount, CORS,
    optional auth / request-id middleware, and user-supplied middleware.

    NOTE(review): assumes `args` carries all fields produced by
    make_arg_parser (disable_fastapi_docs, root_path, api_key, middleware,
    ...) — confirm against cli_args.py.
    """
    if args.disable_fastapi_docs:
        # Disable the OpenAPI schema and interactive docs endpoints.
        app = FastAPI(openapi_url=None,
                      docs_url=None,
                      redoc_url=None,
                      lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        # Normalize FastAPI validation failures into the OpenAI error shape.
        err = ErrorResponse(message=str(exc),
                            type="BadRequestError",
                            code=HTTPStatus.BAD_REQUEST)
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # Let CORS preflight requests through unauthenticated.
            if request.method == "OPTIONS":
                return await call_next(request)
            url_path = request.url.path
            # Strip the configured root_path prefix before matching.
            if app.root_path and url_path.startswith(app.root_path):
                url_path = url_path[len(app.root_path):]
            # Only the OpenAI-compatible /v1 endpoints are protected.
            if not url_path.startswith("/v1"):
                return await call_next(request)
            if request.headers.get("Authorization") != "Bearer " + token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    if args.enable_request_id_headers:
        logger.warning(
            "CAUTION: Enabling X-Request-Id headers in the API Server. "
            "This can harm performance at high QPS.")

        @app.middleware("http")
        async def add_request_id(request: Request, call_next):
            # Echo the caller's X-Request-Id, or mint one for correlation.
            request_id = request.headers.get(
                "X-Request-Id") or uuid.uuid4().hex
            response = await call_next(request)
            response.headers["X-Request-Id"] = request_id
            return response

    for middleware in args.middleware:
        # Each --middleware value is a dotted path to a class or coroutine.
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app
+
+
async def init_app_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,
    args: Namespace,
) -> None:
    """Populate the FastAPI app state with all per-task serving handlers.

    Handlers that do not apply to the model's runner type / task are set to
    None; the route functions check for None and return an error response.
    """
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    # Every served name aliases the same underlying model path.
    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats

    resolved_chat_template = load_chat_template(args.chat_template)
    logger.info("Using supplied chat template:\n%s", resolved_chat_template)

    # The models registry must exist before any handler that references it.
    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
    )
    await state.openai_serving_models.init_static_loras()
    # Chat/Completions only make sense for generative runners.
    state.openai_serving_chat = OpenAIServingChat(
        engine_client,
        model_config,
        state.openai_serving_models,
        args.response_role,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        enable_reasoning=args.enable_reasoning,
        reasoning_parser=args.reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    ) if model_config.runner_type == "generate" else None
    state.openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    ) if model_config.runner_type == "generate" else None
    # Pooling covers any pooling runner; embedding/score are task-specific.
    state.openai_serving_pooling = OpenAIServingPooling(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    ) if model_config.runner_type == "pooling" else None
    state.openai_serving_embedding = OpenAIServingEmbedding(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    ) if model_config.task == "embed" else None
    state.openai_serving_scores = OpenAIServingScores(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger
    ) if model_config.task == "score" else None
    state.jinaai_serving_reranking = JinaAIServingRerank(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger
    ) if model_config.task == "score" else None
    # Tokenization is always available regardless of task.
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    )
    # Consumed by /invocations to pick the right handler.
    state.task = model_config.task
+
+
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
    """Create and bind a TCP socket for the (host, port) pair *addr*.

    Chooses AF_INET6 when the host is an IPv6 literal, AF_INET otherwise,
    and sets SO_REUSEADDR so restarts don't hit TIME_WAIT.
    """
    host = addr[0]
    family = (socket.AF_INET6
              if is_valid_ipv6_address(host) else socket.AF_INET)

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(addr)
    return sock
+
+
async def run_server(args, **uvicorn_kwargs) -> None:
    """Validate args, bind the server socket, build the engine and app,
    then serve HTTP until shutdown."""
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Import a user-supplied tool parser plugin before validating the
    # chosen parser name against the registry.
    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    valid_tool_parses = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice \
        and args.tool_call_parser not in valid_tool_parses:
        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                       f"(chose from {{ {','.join(valid_tool_parses)} }})")

    valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
    if args.enable_reasoning \
        and args.reasoning_parser not in valid_reasoning_parses:
        raise KeyError(
            f"invalid reasoning parser: {args.reasoning_parser} "
            f"(chose from {{ {','.join(valid_reasoning_parses)} }})")

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    sock_addr = (args.host or "", args.port)
    sock = create_server_socket(sock_addr)

    # workaround to avoid footguns where uvicorn drops requests with too
    # many concurrent requests active
    set_ulimit()

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_async_engine_client(args) as engine_client:
        app = build_app(args)

        model_config = await engine_client.get_model_config()
        await init_app_state(engine_client, model_config, app.state, args)

        shutdown_task = await serve_http(
            app,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            # Workaround to work on macOS
            fd=sock.fileno() if sys.platform.startswith("darwin") else None,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task

    sock.close()
+
+
if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    # Parse and validate CLI args, then run the server under uvloop.
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/cli_args.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/cli_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..3054958f3c8abc1e618c21a7fe260169efa6f23f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/cli_args.py
@@ -0,0 +1,305 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+This file contains the command line arguments for the vLLM's
+OpenAI-compatible server. It is kept in a separate file for documentation
+purposes.
+"""
+
+import argparse
+import json
+import ssl
+from typing import List, Optional, Sequence, Union, get_args
+
+from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
+ validate_chat_template)
+from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
+from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
+ PromptAdapterPath)
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+from vllm.utils import FlexibleArgumentParser
+
+
class LoRAParserAction(argparse.Action):
    """Argparse action parsing ``--lora-modules`` values.

    Accepts both the legacy ``name=path`` format and a JSON object whose
    keys are the fields of ``LoRAModulePath``.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        if values is None:
            values = []
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")

        lora_list: List[LoRAModulePath] = []
        for item in values:
            if item in [None, '']:  # Skip if item is None or empty string
                continue
            if '=' in item and ',' not in item:  # Old format: name=path
                # Split only on the first '=' so that paths containing '='
                # (e.g. presigned URLs) do not raise an unpack error.
                name, path = item.split('=', 1)
                lora_list.append(LoRAModulePath(name, path))
            else:  # Assume JSON format
                try:
                    lora_dict = json.loads(item)
                    lora = LoRAModulePath(**lora_dict)
                    lora_list.append(lora)
                except json.JSONDecodeError:
                    parser.error(
                        f"Invalid JSON format for --lora-modules: {item}")
                except TypeError as e:
                    parser.error(
                        f"Invalid fields for --lora-modules: {item} - {str(e)}"
                    )
        setattr(namespace, self.dest, lora_list)
+
+
class PromptAdapterParserAction(argparse.Action):
    """Argparse action parsing ``--prompt-adapters`` ``name=path`` values."""

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        if values is None:
            values = []
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")

        adapter_list: List[PromptAdapterPath] = []
        for item in values:
            # Split only on the first '=' so paths containing '=' still parse.
            name, path = item.split('=', 1)
            adapter_list.append(PromptAdapterPath(name, path))
        setattr(namespace, self.dest, adapter_list)
+
+
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Register all OpenAI-compatible server CLI options on *parser*.

    Also appends the async engine arguments. Returns the same parser so
    calls can be chained.
    """
    parser.add_argument("--host",
                        type=nullable_str,
                        default=None,
                        help="Host name.")
    parser.add_argument("--port", type=int, default=8000, help="Port number.")
    parser.add_argument(
        "--uvicorn-log-level",
        type=str,
        default="info",
        choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
        help="Log level for uvicorn.")
    parser.add_argument("--allow-credentials",
                        action="store_true",
                        help="Allow credentials.")
    parser.add_argument("--allowed-origins",
                        type=json.loads,
                        default=["*"],
                        help="Allowed origins.")
    parser.add_argument("--allowed-methods",
                        type=json.loads,
                        default=["*"],
                        help="Allowed methods.")
    parser.add_argument("--allowed-headers",
                        type=json.loads,
                        default=["*"],
                        help="Allowed headers.")
    parser.add_argument("--api-key",
                        type=nullable_str,
                        default=None,
                        help="If provided, the server will require this key "
                        "to be presented in the header.")
    parser.add_argument(
        "--lora-modules",
        type=nullable_str,
        default=None,
        nargs='+',
        action=LoRAParserAction,
        # NOTE: trailing space added so the concatenated help text does not
        # read "formator JSON format".
        help="LoRA module configurations in either 'name=path' format "
        "or JSON format. "
        "Example (old format): ``'name=path'`` "
        "Example (new format): "
        "``{\"name\": \"name\", \"path\": \"lora_path\", "
        "\"base_model_name\": \"id\"}``")
    parser.add_argument(
        "--prompt-adapters",
        type=nullable_str,
        default=None,
        nargs='+',
        action=PromptAdapterParserAction,
        help="Prompt adapter configurations in the format name=path. "
        "Multiple adapters can be specified.")
    parser.add_argument("--chat-template",
                        type=nullable_str,
                        default=None,
                        help="The file path to the chat template, "
                        "or the template in single-line form "
                        "for the specified model.")
    parser.add_argument(
        '--chat-template-content-format',
        type=str,
        default="auto",
        choices=get_args(ChatTemplateContentFormatOption),
        help='The format to render message content within a chat template.'
        '\n\n'
        '* "string" will render the content as a string. '
        'Example: ``"Hello World"``\n'
        '* "openai" will render the content as a list of dictionaries, '
        'similar to OpenAI schema. '
        'Example: ``[{"type": "text", "text": "Hello world!"}]``')
    parser.add_argument("--response-role",
                        type=nullable_str,
                        default="assistant",
                        help="The role name to return if "
                        "``request.add_generation_prompt=true``.")
    parser.add_argument("--ssl-keyfile",
                        type=nullable_str,
                        default=None,
                        help="The file path to the SSL key file.")
    parser.add_argument("--ssl-certfile",
                        type=nullable_str,
                        default=None,
                        help="The file path to the SSL cert file.")
    parser.add_argument("--ssl-ca-certs",
                        type=nullable_str,
                        default=None,
                        help="The CA certificates file.")
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)."
    )
    parser.add_argument(
        "--root-path",
        type=nullable_str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy."
    )
    parser.add_argument(
        "--middleware",
        type=nullable_str,
        action="append",
        default=[],
        help="Additional ASGI middleware to apply to the app. "
        "We accept multiple --middleware arguments. "
        "The value should be an import path. "
        "If a function is provided, vLLM will add it to the server "
        "using ``@app.middleware('http')``. "
        "If a class is provided, vLLM will add it to the server "
        "using ``app.add_middleware()``. ")
    parser.add_argument(
        "--return-tokens-as-token-ids",
        action="store_true",
        help="When ``--max-logprobs`` is specified, represents single tokens "
        " as strings of the form 'token_id:{token_id}' so that tokens "
        "that are not JSON-encodable can be identified.")
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        help="If specified, will run the OpenAI frontend server in the same "
        "process as the model serving engine.")
    parser.add_argument(
        "--enable-request-id-headers",
        action="store_true",
        help="If specified, API server will add X-Request-Id header to "
        "responses. Caution: this hurts performance at high QPS.")
    parser.add_argument(
        "--enable-auto-tool-choice",
        action="store_true",
        default=False,
        help="Enable auto tool choice for supported models. Use "
        "``--tool-call-parser`` to specify which parser to use.")
    parser.add_argument(
        "--enable-reasoning",
        action="store_true",
        default=False,
        help="Whether to enable reasoning_content for the model. "
        "If enabled, the model will be able to generate reasoning content.")

    # Valid choices are shown via metavar rather than `choices` so plugins
    # registered later are not rejected by argparse.
    valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
    parser.add_argument(
        "--reasoning-parser",
        type=str,
        metavar="{" + ",".join(valid_reasoning_parsers) + "}",
        default=None,
        help=
        "Select the reasoning parser depending on the model that you're using."
        " This is used to parse the reasoning content into OpenAI API "
        "format. Required for ``--enable-reasoning``.")

    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
    parser.add_argument(
        "--tool-call-parser",
        type=str,
        metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
        "--tool-parser-plugin",
        default=None,
        help=
        "Select the tool call parser depending on the model that you're using."
        " This is used to parse the model-generated tool call into OpenAI API "
        "format. Required for ``--enable-auto-tool-choice``.")

    parser.add_argument(
        "--tool-parser-plugin",
        type=str,
        default="",
        help=
        "Specify the tool parser plugin used to parse the model-generated tool"
        " into OpenAI API format; the name registered in this plugin can be "
        "used in ``--tool-call-parser``.")

    # Engine arguments (model, parallelism, scheduling, ...) come last.
    parser = AsyncEngineArgs.add_cli_args(parser)

    parser.add_argument('--max-log-len',
                        type=int,
                        default=None,
                        help='Max number of prompt characters or prompt '
                        'ID numbers being printed in log.'
                        '\n\nDefault: Unlimited')

    parser.add_argument(
        "--disable-fastapi-docs",
        action='store_true',
        default=False,
        help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."
    )
    parser.add_argument(
        "--enable-prompt-tokens-details",
        action='store_true',
        default=False,
        help="If set to True, enable prompt_tokens_details in usage.")

    return parser
+
+
def validate_parsed_serve_args(args: argparse.Namespace):
    """Quick checks for model serve args that raise prior to loading."""
    # Only the `serve` subcommand (or no subcommand at all) is validated.
    if getattr(args, "subparser", "serve") != "serve":
        return

    # A syntactically broken chat template should fail fast.
    validate_chat_template(args.chat_template)

    # Auto tool choice is meaningless without a parser for tool calls.
    if args.enable_auto_tool_choice and not args.tool_call_parser:
        raise TypeError("Error: --enable-auto-tool-choice requires "
                        "--tool-call-parser")

    # Reasoning output likewise needs a parser to extract it.
    if args.enable_reasoning and not args.reasoning_parser:
        raise TypeError("Error: --enable-reasoning requires "
                        "--reasoning-parser")

    # Ref https://api-docs.deepseek.com/guides/reasoning_model
    # tool call and reasoning cannot be enabled at the same time.
    if args.enable_reasoning and args.enable_auto_tool_choice:
        raise TypeError("Error: --enable-auto-tool-choice and "
                        "--enable-reasoning cannot be enabled at the same time")
+
+
def create_parser_for_docs() -> FlexibleArgumentParser:
    """Build the fully populated CLI parser used by the docs generator."""
    docs_parser = FlexibleArgumentParser(
        prog="-m vllm.entrypoints.openai.api_server")
    return make_arg_parser(docs_parser)
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/logits_processors.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/logits_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..41e5eef40eaf82ce3afa59523a2ed24796338abe
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/logits_processors.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from functools import lru_cache, partial
+from typing import Dict, FrozenSet, Iterable, List, Optional, Union
+
+import torch
+
+from vllm.sampling_params import LogitsProcessor
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+
class AllowedTokenIdsLogitsProcessor:
    """Logits processor for constraining generated tokens to a
    specific set of token ids."""

    def __init__(self, allowed_ids: Iterable[int]):
        # Held only until the boolean mask is built on the first call.
        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
        self.mask: Optional[torch.Tensor] = None

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        if self.mask is None:
            # Lazily build a mask that is True for every *disallowed* token,
            # sized and placed according to the first logits tensor seen.
            mask = torch.ones((logits.shape[-1], ),
                              dtype=torch.bool,
                              device=logits.device)
            mask[self.allowed_ids] = False
            self.mask = mask
            self.allowed_ids = None
        # In-place fill: disallowed positions become -inf.
        logits.masked_fill_(self.mask, float("-inf"))
        return logits
+
+
@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
    allowed_token_ids: FrozenSet[int],
    vocab_size: int,
) -> LogitsProcessor:
    """Cached constructor for an allowed-token-ids processor.

    Caching is safe because the ids argument is a hashable frozenset.
    """
    if not allowed_token_ids:
        raise ValueError("Empty allowed_token_ids provided")
    if any(tid < 0 or tid >= vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains "
                         "out-of-vocab token id")
    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
+
+
def logit_bias_logits_processor(
    logit_bias: Dict[int, float],
    token_ids: List[int],
    logits: torch.Tensor,
) -> torch.Tensor:
    """Apply a per-token additive bias to *logits* in place and return it."""
    for tid, bias_value in logit_bias.items():
        logits[tid] = logits[tid] + bias_value
    return logits
+
+
def get_logits_processors(
    logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
    allowed_token_ids: Optional[List[int]],
    tokenizer: AnyTokenizer,
) -> List[LogitsProcessor]:
    """Build the logits processors implied by OpenAI-style request fields.

    Args:
        logit_bias: Mapping of token id (int, or string holding an int) to
            bias; biases are clamped to [-100, 100] per the OpenAI API spec.
        allowed_token_ids: If given, restricts sampling to these token ids.
        tokenizer: Used only for its vocabulary size (``len(tokenizer)``).

    Raises:
        ValueError: If a logit_bias key is not an integer, or any token id
            is outside the tokenizer vocabulary.
    """
    logits_processors: List[LogitsProcessor] = []
    if logit_bias:
        try:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            clamped_logit_bias: Dict[int, float] = {
                int(token_id): min(100.0, max(-100.0, bias))
                for token_id, bias in logit_bias.items()
            }
        except ValueError as exc:
            raise ValueError(
                "Found token_id in logit_bias that is not "
                "an integer or string representing an integer") from exc

        # Check if token_id is within the vocab size
        # (only the keys are needed here, not the bias values).
        for token_id in clamped_logit_bias:
            if token_id < 0 or token_id >= len(tokenizer):
                raise ValueError(f"token_id {token_id} in logit_bias contains "
                                 "out-of-vocab token id")

        logits_processors.append(
            partial(logit_bias_logits_processor, clamped_logit_bias))

    if allowed_token_ids is not None:
        logits_processors.append(
            _get_allowed_token_ids_logits_processor(
                frozenset(allowed_token_ids), len(tokenizer)))

    return logits_processors
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/protocol.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/protocol.py
new file mode 100644
index 0000000000000000000000000000000000000000..83b841826231ef17c2426f73adfda256942f41ce
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/protocol.py
@@ -0,0 +1,1428 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+import re
+import time
+from argparse import Namespace
+from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
+
+import torch
+from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
+ ValidationInfo, field_validator, model_validator)
+from typing_extensions import Annotated
+
+from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+from vllm.logger import init_logger
+from vllm.pooling_params import PoolingParams
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+ RequestOutputKind, SamplingParams)
+from vllm.sequence import Logprob
+from vllm.utils import random_uuid, resolve_obj_by_qualname
+
+logger = init_logger(__name__)
+
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]

try:
    from sphinx.ext.autodoc.mock import _MockModule

    # During docs builds torch may be replaced by a Sphinx mock module; fall
    # back to the hard-coded 64-bit limits in that case.
    if isinstance(torch, _MockModule):
        _LONG_INFO = _MOCK_LONG_INFO
    else:
        _LONG_INFO = torch.iinfo(torch.long)
except ModuleNotFoundError:
    # sphinx is not installed (normal runtime) — use the real torch limits.
    _LONG_INFO = torch.iinfo(torch.long)

# Sanity check: the literals above must match torch's actual int64 range.
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
+
+
class OpenAIBaseModel(BaseModel):
    """Base model for the OpenAI-compatible request/response schemas.

    Extra fields are accepted (as the OpenAI API allows) but logged as
    ignored, so clients can spot mistyped parameter names.
    """

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Cache class field names
    field_names: ClassVar[Optional[Set[str]]] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        # Run normal validation first; only dict payloads are inspected.
        result = handler(data)
        if not isinstance(data, dict):
            return result
        field_names = cls.field_names
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, 'alias', None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        if any(k not in field_names for k in data):
            logger.warning(
                "The following fields were present in the request "
                "but ignored: %s",
                data.keys() - field_names)
        return result
+
+
class ErrorResponse(OpenAIBaseModel):
    """OpenAI-style error payload returned for failed requests."""
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    """Permission entry attached to a model card (OpenAI compatibility)."""
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    """Description of a single served model."""
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    """List of model cards (OpenAI ``list`` object shape)."""
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class PromptTokenUsageInfo(OpenAIBaseModel):
    """Detailed breakdown of prompt-token usage."""
    cached_tokens: Optional[int] = None


class UsageInfo(OpenAIBaseModel):
    """Token accounting reported alongside a completion response."""
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None


class RequestResponseMetadata(BaseModel):
    """Per-request bookkeeping: id plus final usage once known."""
    request_id: str
    final_usage_info: Optional[UsageInfo] = None
+
+
class JsonSchemaResponseFormat(OpenAIBaseModel):
    """JSON-schema spec used with ``response_format.type == "json_schema"``."""
    name: str
    description: Optional[str] = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
    strict: Optional[bool] = None


class ResponseFormat(OpenAIBaseModel):
    """Requested output format for a completion."""
    # type must be "json_schema", "json_object" or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None


class StreamOptions(OpenAIBaseModel):
    """Options controlling streamed responses."""
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = False


class FunctionDefinition(OpenAIBaseModel):
    """Schema of a callable function exposed to the model as a tool."""
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    """A single tool entry in a chat completion request."""
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    """Reference to a function by name (used in named tool choice)."""
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    """``tool_choice`` value forcing a specific named function."""
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


class LogitsProcessorConstructor(BaseModel):
    """Constructor spec for a logits processor: qualname plus arguments."""
    qualname: str
    args: Optional[List[Any]] = None
    kwargs: Optional[Dict[str, Any]] = None


# A request may pass either bare qualified names or constructor specs.
LogitsProcessors = List[Union[str, LogitsProcessorConstructor]]
+
+
def get_logits_processors(processors: Optional[LogitsProcessors],
                          pattern: Optional[str]) -> Optional[List[Any]]:
    """Resolve request-supplied logits processors against the server policy.

    Args:
        processors: Qualified names and/or constructor specs from the request.
        pattern: Regex from ``--logits-processor-pattern``; a qualname must
            match it to be allowed. ``None`` disables the feature entirely.

    Returns:
        Instantiated logits processors, or ``None`` if none were requested.

    Raises:
        ValueError: If processors were requested while disabled, a qualname
            does not match *pattern*, or resolution/construction fails.
    """
    if processors and pattern:
        logits_processors = []
        for processor in processors:
            qualname = processor if isinstance(processor,
                                               str) else processor.qualname
            if not re.match(pattern, qualname):
                raise ValueError(
                    f"Logits processor '{qualname}' is not allowed by this "
                    "server. See --logits-processor-pattern engine argument "
                    "for more information.")
            try:
                logits_processor = resolve_obj_by_qualname(qualname)
            except Exception as e:
                raise ValueError(
                    f"Logits processor '{qualname}' could not be resolved: {e}"
                ) from e
            # Constructor specs are called with their args to produce the
            # actual processor; bare qualnames are used as-is.
            if isinstance(processor, LogitsProcessorConstructor):
                logits_processor = logits_processor(*processor.args or [],
                                                    **processor.kwargs or {})
            logits_processors.append(logits_processor)
        return logits_processors
    elif processors:
        # Fixed typo in user-facing message: "argugment" -> "argument".
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information.")
    return None
+
+
+class ChatCompletionRequest(OpenAIBaseModel):
+ # Ordered by official OpenAI API documentation
+ # https://platform.openai.com/docs/api-reference/chat/create
+ messages: List[ChatCompletionMessageParam]
+ model: str
+ frequency_penalty: Optional[float] = 0.0
+ logit_bias: Optional[Dict[str, float]] = None
+ logprobs: Optional[bool] = False
+ top_logprobs: Optional[int] = 0
+ # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+ max_tokens: Optional[int] = Field(
+ default=None,
+ deprecated=
+ 'max_tokens is deprecated in favor of the max_completion_tokens field')
+ max_completion_tokens: Optional[int] = None
+ n: Optional[int] = 1
+ presence_penalty: Optional[float] = 0.0
+ response_format: Optional[ResponseFormat] = None
+ seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
+ stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+ stream: Optional[bool] = False
+ stream_options: Optional[StreamOptions] = None
+ temperature: Optional[float] = None
+ top_p: Optional[float] = None
+ tools: Optional[List[ChatCompletionToolsParam]] = None
+ tool_choice: Optional[Union[Literal["none"], Literal["auto"],
+ ChatCompletionNamedToolChoiceParam]] = "none"
+
+ # NOTE this will be ignored by VLLM -- the model determines the behavior
+ parallel_tool_calls: Optional[bool] = False
+ user: Optional[str] = None
+
+ # doc: begin-chat-completion-sampling-params
+ best_of: Optional[int] = None
+ use_beam_search: bool = False
+ top_k: Optional[int] = None
+ min_p: Optional[float] = None
+ repetition_penalty: Optional[float] = None
+ length_penalty: float = 1.0
+ stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+ include_stop_str_in_output: bool = False
+ ignore_eos: bool = False
+ min_tokens: int = 0
+ skip_special_tokens: bool = True
+ spaces_between_special_tokens: bool = True
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+ prompt_logprobs: Optional[int] = None
+ # doc: end-chat-completion-sampling-params
+
+ # doc: begin-chat-completion-extra-params
+ echo: bool = Field(
+ default=False,
+ description=(
+ "If true, the new message will be prepended with the last message "
+ "if they belong to the same role."),
+ )
+ add_generation_prompt: bool = Field(
+ default=True,
+ description=
+ ("If true, the generation prompt will be added to the chat template. "
+ "This is a parameter used by chat template in tokenizer config of the "
+ "model."),
+ )
+ continue_final_message: bool = Field(
+ default=False,
+ description=
+ ("If this is set, the chat will be formatted so that the final "
+ "message in the chat is open-ended, without any EOS tokens. The "
+ "model will continue this message rather than starting a new one. "
+ "This allows you to \"prefill\" part of the model's response for it. "
+ "Cannot be used at the same time as `add_generation_prompt`."),
+ )
+ add_special_tokens: bool = Field(
+ default=False,
+ description=(
+ "If true, special tokens (e.g. BOS) will be added to the prompt "
+ "on top of what is added by the chat template. "
+ "For most models, the chat template takes care of adding the "
+ "special tokens so this should be set to false (as is the "
+ "default)."),
+ )
+ documents: Optional[List[Dict[str, str]]] = Field(
+ default=None,
+ description=
+ ("A list of dicts representing documents that will be accessible to "
+ "the model if it is performing RAG (retrieval-augmented generation)."
+ " If the template does not support RAG, this argument will have no "
+ "effect. We recommend that each document should be a dict containing "
+ "\"title\" and \"text\" keys."),
+ )
+ chat_template: Optional[str] = Field(
+ default=None,
+ description=(
+ "A Jinja template to use for this conversion. "
+ "As of transformers v4.44, default chat template is no longer "
+ "allowed, so you must provide a chat template if the tokenizer "
+ "does not define one."),
+ )
+ chat_template_kwargs: Optional[Dict[str, Any]] = Field(
+ default=None,
+ description=("Additional kwargs to pass to the template renderer. "
+ "Will be accessible by the chat template."),
+ )
+ guided_json: Optional[Union[str, dict, BaseModel]] = Field(
+ default=None,
+ description=("If specified, the output will follow the JSON schema."),
+ )
+ guided_regex: Optional[str] = Field(
+ default=None,
+ description=(
+ "If specified, the output will follow the regex pattern."),
+ )
+ guided_choice: Optional[List[str]] = Field(
+ default=None,
+ description=(
+ "If specified, the output will be exactly one of the choices."),
+ )
+ guided_grammar: Optional[str] = Field(
+ default=None,
+ description=(
+ "If specified, the output will follow the context free grammar."),
+ )
+ guided_decoding_backend: Optional[str] = Field(
+ default=None,
+ description=(
+ "If specified, will override the default guided decoding backend "
+ "of the server for this specific request. If set, must be either "
+ "'outlines' / 'lm-format-enforcer'"))
+ guided_whitespace_pattern: Optional[str] = Field(
+ default=None,
+ description=(
+ "If specified, will override the default whitespace pattern "
+ "for guided json decoding."))
+ priority: int = Field(
+ default=0,
+ description=(
+ "The priority of the request (lower means earlier handling; "
+ "default: 0). Any priority other than 0 will raise an error "
+ "if the served model does not use priority scheduling."))
+ request_id: str = Field(
+ default_factory=lambda: f"{random_uuid()}",
+ description=(
+ "The request_id related to this request. If the caller does "
+ "not set it, a random_uuid will be generated. This id is used "
+ "through out the inference process and return in response."))
+ logits_processors: Optional[LogitsProcessors] = Field(
+ default=None,
+ description=(
+ "A list of either qualified names of logits processors, or "
+ "constructor objects, to apply when sampling. A constructor is "
+ "a JSON object with a required 'qualname' field specifying the "
+ "qualified name of the processor class/factory, and optional "
+ "'args' and 'kwargs' fields containing positional and keyword "
+ "arguments. For example: {'qualname': "
+ "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
+ "{'param': 'value'}}."))
+
+ # doc: end-chat-completion-extra-params
+
    # Default sampling parameters for chat completion requests
    # (used only when neither the request nor the server supplies a value).
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": -1,
        "min_p": 0.0,
    }
+
+ def to_beam_search_params(
+ self,
+ default_max_tokens: int,
+ default_sampling_params: Optional[dict] = None
+ ) -> BeamSearchParams:
+ # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+ max_tokens = self.max_completion_tokens or self.max_tokens
+
+ if default_sampling_params is None:
+ default_sampling_params = {}
+ n = self.n if self.n is not None else 1
+
+ # Use minimum of context window, user request & server limit.
+ max_tokens = min(
+ val for val in (default_max_tokens, max_tokens,
+ default_sampling_params.get("max_tokens", None))
+ if val is not None)
+
+ if (temperature := self.temperature) is None:
+ temperature = default_sampling_params.get(
+ "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+
+ return BeamSearchParams(
+ beam_width=n,
+ max_tokens=max_tokens,
+ ignore_eos=self.ignore_eos,
+ temperature=temperature,
+ length_penalty=self.length_penalty,
+ include_stop_str_in_output=self.include_stop_str_in_output)
+
    def to_sampling_params(
            self,
            default_max_tokens: int,
            logits_processor_pattern: Optional[str],
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        """Translate this request into engine-level ``SamplingParams``.

        Args:
            default_max_tokens: Upper bound derived from the model context
                window for this prompt.
            logits_processor_pattern: Server regex restricting which request
                logits processors may be resolved.
            default_sampling_params: Server-side defaults used for fields the
                request leaves unset.
        """
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        max_tokens = self.max_completion_tokens or self.max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Use minimum of context window, user request & server limit.
        max_tokens = min(
            val for val in (default_max_tokens, max_tokens,
                            default_sampling_params.get("max_tokens", None))
            if val is not None)

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])

        # Echoing the prompt implies prompt logprobs at top_logprobs depth
        # unless the caller set prompt_logprobs explicitly.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        guided_json_object = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                # NOTE: mutates self.guided_json (and possibly the backend)
                # so response_format flows through the guided-decoding path.
                json_schema = self.response_format.json_schema
                assert json_schema is not None
                self.guided_json = json_schema.json_schema
                if self.guided_decoding_backend is None:
                    self.guided_decoding_backend = "xgrammar"

        # A named tool choice takes precedence over request-level guided_json.
        guided_decoding = GuidedDecodingParams.from_optional(
            json=self._get_guided_json_from_tool() or self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern)

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(self.logits_processors,
                                                    logits_processor_pattern),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA if self.stream \
                else RequestOutputKind.FINAL_ONLY,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias)
+
+ def _get_guided_json_from_tool(
+ self) -> Optional[Union[str, dict, BaseModel]]:
+ # user has chosen to not use any tool
+ if self.tool_choice == "none" or self.tools is None:
+ return None
+
+ # user has chosen to use a named tool
+ if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
+ tool_name = self.tool_choice.function.name
+ tools = {tool.function.name: tool.function for tool in self.tools}
+ if tool_name not in tools:
+ raise ValueError(
+ f"Tool '{tool_name}' has not been passed in `tools`.")
+ tool = tools[tool_name]
+ return tool.parameters
+
+ return None
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_stream_options(cls, data):
+ if data.get("stream_options") and not data.get("stream"):
+ raise ValueError(
+ "Stream options can only be defined when `stream=True`.")
+
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_logprobs(cls, data):
+ if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
+ if data.get("stream") and prompt_logprobs > 0:
+ raise ValueError(
+ "`prompt_logprobs` are not available when `stream=True`.")
+
+ if prompt_logprobs < 0:
+ raise ValueError("`prompt_logprobs` must be a positive value.")
+
+ if (top_logprobs := data.get("top_logprobs")) is not None:
+ if top_logprobs < 0:
+ raise ValueError("`top_logprobs` must be a positive value.")
+
+ if not data.get("logprobs"):
+ raise ValueError(
+ "when using `top_logprobs`, `logprobs` must be set to true."
+ )
+
+ return data
+
    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Ensure at most one guided-decoding mode is requested."""
        # A ValueError passed through as the payload is re-raised as-is
        # (presumably set by an earlier stage — confirm against callers).
        if isinstance(data, ValueError):
            raise data

        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        # NOTE(review): this branch is unreachable — `guide_count > 1` already
        # raised above; the guard was presumably meant to be `>= 1`. Confirm
        # intent before changing behavior.
        if guide_count > 1 and data.get("tool_choice",
                                        "none") not in ("none", "auto"):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data
+
    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Normalize and validate the `tools` / `tool_choice` combination.

        Runs before field validation on the raw request dict. Mutates `data`
        in place (defaults `tool_choice`, may drop `tools`) and returns it.
        """

        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is "none" -- ignore tools if present
        if "tool_choice" in data and data["tool_choice"] == "none":
            # ensure that no tools are present
            data.pop("tools", None)
            return data

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data:

            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto"
            if data["tool_choice"] != "auto" and not isinstance(
                    data["tool_choice"], dict):
                raise ValueError(
                    "`tool_choice` must either be a named tool, \"auto\", "
                    "or \"none\".")

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            if isinstance(data["tool_choice"], dict):
                valid_tool = False
                specified_function = data["tool_choice"].get("function")
                if not specified_function:
                    raise ValueError(
                        "Expected field `function` in `tool_choice`."
                        " Correct usage: `{\"type\": \"function\","
                        " \"function\": {\"name\": \"my_function\"}}`")
                specified_function_name = specified_function.get("name")
                if not specified_function_name:
                    raise ValueError(
                        "Expected field `name` in `function` in `tool_choice`."
                        "Correct usage: `{\"type\": \"function\", "
                        "\"function\": {\"name\": \"my_function\"}}`")
                # linear scan is fine: tool lists are short in practice
                for tool in data["tools"]:
                    if tool["function"]["name"] == specified_function_name:
                        valid_tool = True
                        break
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`")
        return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_generation_prompt(cls, data):
+ if data.get("continue_final_message") and data.get(
+ "add_generation_prompt"):
+ raise ValueError("Cannot set both `continue_final_message` and "
+ "`add_generation_prompt` to True.")
+ return data
+
+
class CompletionRequest(OpenAIBaseModel):
    """Request body for the OpenAI-compatible `/v1/completions` endpoint.

    Field declaration order mirrors the official OpenAI API; additional
    vLLM-specific sampling and guided-decoding parameters follow.
    """
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: bool = False
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    length_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
         "{'type': 'text' } is supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description="If specified, the output will follow the JSON schema.",
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."))

    # doc: end-completion-extra-params

    # Default sampling parameters for completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": -1,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        default_max_tokens: int,
        default_sampling_params: Optional[dict] = None
    ) -> BeamSearchParams:
        """Build BeamSearchParams for this request.

        Args:
            default_max_tokens: server-side cap derived from the model's
                context window for this prompt.
            default_sampling_params: optional server-configured defaults
                (e.g. from generation_config) consulted before the
                hard-coded fallbacks.
        """
        max_tokens = self.max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        # Use minimum of context window, user request & server limit.
        max_tokens = min(
            val for val in (default_max_tokens, max_tokens,
                            default_sampling_params.get("max_tokens", None))
            if val is not None)

        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get("temperature", 1.0)

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output)

    def to_sampling_params(
            self,
            default_max_tokens: int,
            logits_processor_pattern: Optional[str],
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        """Build SamplingParams for this request.

        Unset sampling fields fall back first to `default_sampling_params`
        and then to `_DEFAULT_SAMPLING_PARAMS`.

        Args:
            default_max_tokens: server-side cap derived from the model's
                context window for this prompt.
            logits_processor_pattern: server-side allow-list pattern for
                loadable logits processors.
            default_sampling_params: optional server-configured defaults.
        """
        max_tokens = self.max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Use minimum of context window, user request & server limit.
        max_tokens = min(
            val for val in (default_max_tokens, max_tokens,
                            default_sampling_params.get("max_tokens", None))
            if val is not None)

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])

        # echoing with logprobs implies prompt logprobs are wanted too
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        # max_tokens == 0 with echo: caller only wants the prompt back,
        # so generate a single (discarded) token below.
        echo_without_generation = self.echo and self.max_tokens == 0

        guided_json_object = None
        if (self.response_format is not None
                and self.response_format.type == "json_object"):
            guided_json_object = True

        guided_decoding = GuidedDecodingParams.from_optional(
            json=self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern)

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(self.logits_processors,
                                                    logits_processor_pattern),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA if self.stream \
                else RequestOutputKind.FINAL_ONLY,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids)

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Ensure at most one guided-decoding mode is requested."""
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` / `logprobs` values."""
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` unless `stream=True`."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data
+
+
class EmbeddingCompletionRequest(OpenAIBaseModel):
    """Prompt-style request body for the `/v1/embeddings` endpoint."""
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    # doc: begin-embedding-extra-params
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-embedding-extra-params

    def to_pooling_params(self):
        """Build PoolingParams, forwarding model-specific extras."""
        return PoolingParams(additional_data=self.additional_data)
+
+
class EmbeddingChatRequest(OpenAIBaseModel):
    """Chat-style request body for the `/v1/embeddings` endpoint."""
    model: str
    messages: List[ChatCompletionMessageParam]

    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-chat-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-chat-embedding-pooling-params

    # doc: begin-chat-embedding-extra-params
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))
    # doc: end-chat-embedding-extra-params

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Disallow `continue_final_message` with `add_generation_prompt`.

        NOTE(review): neither key is a declared field of this model —
        presumably they may arrive via extra request data; confirm whether
        this validator is reachable here or copied from the chat request.
        """
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data

    def to_pooling_params(self):
        """Build PoolingParams, forwarding model-specific extras."""
        return PoolingParams(additional_data=self.additional_data)
+
+
# Either prompt-style or chat-style embedding request.
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]

# Pooling requests reuse the embedding request schemas unchanged.
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
+
+
class ScoreRequest(OpenAIBaseModel):
    """Request body for the `/v1/score` (cross-encoder scoring) endpoint."""
    model: str
    text_1: Union[List[str], str]
    text_2: Union[List[str], str]
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-score-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-score-pooling-params

    # doc: begin-score-extra-params
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-score-extra-params

    def to_pooling_params(self):
        """Build PoolingParams, forwarding model-specific extras."""
        return PoolingParams(additional_data=self.additional_data)
+
+
class RerankRequest(OpenAIBaseModel):
    """Request body for the rerank endpoint."""
    model: str
    query: str
    documents: List[str]
    # FIX: was `Field(default_factory=lambda: 0)` — a needless lambda for an
    # immutable int; a plain default is equivalent and clearer.
    # Semantics of 0 are defined by the serving layer — presumably "no
    # limit"; confirm against the rerank handler.
    top_n: int = 0
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-rerank-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-rerank-pooling-params

    # doc: begin-rerank-extra-params
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-rerank-extra-params

    def to_pooling_params(self):
        """Build PoolingParams, forwarding model-specific extras."""
        return PoolingParams(additional_data=self.additional_data)
+
+
class RerankDocument(BaseModel):
    """The document text echoed back inside a rerank result."""
    text: str
+
+
class RerankResult(BaseModel):
    """One scored document in a rerank response."""
    # index of the document in the original request list
    index: int
    document: RerankDocument
    relevance_score: float
+
+
class RerankUsage(BaseModel):
    """Token accounting for a rerank request."""
    total_tokens: int
+
+
class RerankResponse(OpenAIBaseModel):
    """Top-level response body for the rerank endpoint."""
    id: str
    model: str
    usage: RerankUsage
    results: List[RerankResult]
+
+
class CompletionLogProbs(OpenAIBaseModel):
    """Per-token logprob data for a completion choice (OpenAI schema)."""
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)
+
+
class CompletionResponseChoice(OpenAIBaseModel):
    """One generated choice in a non-streaming completion response."""
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
+
+
class CompletionResponse(OpenAIBaseModel):
    """Non-streaming response body for `/v1/completions`."""
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo
+
+
class CompletionResponseStreamChoice(OpenAIBaseModel):
    """One incremental choice chunk in a streaming completion response."""
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
+
+
class CompletionStreamResponse(OpenAIBaseModel):
    """One SSE chunk of a streaming `/v1/completions` response."""
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    # only populated on the final chunk when stream_options request it
    usage: Optional[UsageInfo] = Field(default=None)
+
+
class EmbeddingResponseData(OpenAIBaseModel):
    """One embedding vector (floats, or base64 string) in a response."""
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]
+
+
class EmbeddingResponse(OpenAIBaseModel):
    """Response body for `/v1/embeddings`."""
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo
+
+
class PoolingResponseData(OpenAIBaseModel):
    """One pooled-output entry in a pooling response."""
    index: int
    object: str = "pooling"
    data: Union[List[List[float]], List[float], str]
+
+
class PoolingResponse(OpenAIBaseModel):
    """Response body for the pooling endpoint."""
    id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[PoolingResponseData]
    usage: UsageInfo
+
+
class ScoreResponseData(OpenAIBaseModel):
    """One similarity score in a score response."""
    index: int
    object: str = "score"
    score: float
+
+
class ScoreResponse(OpenAIBaseModel):
    """Response body for the `/v1/score` endpoint."""
    # NOTE(review): id prefix is "embd-" here while PoolingResponse uses
    # "pool-" — looks like a copy-paste from EmbeddingResponse; confirm
    # whether any client depends on this prefix before changing it.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[ScoreResponseData]
    usage: UsageInfo
+
+
class FunctionCall(OpenAIBaseModel):
    """A concrete function invocation: name plus JSON-encoded arguments."""
    name: str
    arguments: str
+
+
class ToolCall(OpenAIBaseModel):
    """A complete tool call emitted in a chat completion response."""
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall
+
+
class DeltaFunctionCall(BaseModel):
    """A partial function call for streaming; all fields optional."""
    name: Optional[str] = None
    arguments: Optional[str] = None
+
+
# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    """A streamed tool-call fragment, keyed by `index` across chunks."""
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    index: int
    function: Optional[DeltaFunctionCall] = None
+
+
class ExtractedToolCallInformation(BaseModel):
    """Result of parsing tool calls out of raw model output."""
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: List[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: Optional[str] = None
+
+
class ChatMessage(OpenAIBaseModel):
    """A full assistant message in a non-streaming chat response."""
    role: str
    reasoning_content: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)
+
+
class ChatCompletionLogProb(OpenAIBaseModel):
    """Logprob info for a single token (OpenAI chat schema)."""
    token: str
    # sentinel default; real values are filled in by the serving layer
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None
+
+
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    """A token's logprob plus the top alternative tokens at that position."""
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
+
+
class ChatCompletionLogProbs(OpenAIBaseModel):
    """Container for per-token logprobs of a chat completion choice."""
    content: Optional[List[ChatCompletionLogProbsContent]] = None
+
+
class ChatCompletionResponseChoice(OpenAIBaseModel):
    """One choice in a non-streaming chat completion response."""
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: Optional[Union[int, str]] = None
+
+
class ChatCompletionResponse(OpenAIBaseModel):
    """Non-streaming response body for `/v1/chat/completions`."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
+
+
class DeltaMessage(OpenAIBaseModel):
    """An incremental message fragment in a streaming chat response."""
    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: List[DeltaToolCall] = Field(default_factory=list)
+
+
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    """One choice delta in a streaming chat completion chunk."""
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None
+
+
class ChatCompletionStreamResponse(OpenAIBaseModel):
    """One SSE chunk of a streaming `/v1/chat/completions` response."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    # only populated on the final chunk when stream_options request it
    usage: Optional[UsageInfo] = Field(default=None)
+
+
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]

    @field_validator('body', mode='plain')
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        """Pick the concrete request model for `body` based on `url`.

        NOTE(review): relies on `url` being declared (and validated) before
        `body` so it is present in `info.data`; if `url` itself failed
        validation this raises KeyError rather than a validation error.
        """
        # Use url to disambiguate models
        url = info.data['url']
        if url == "/v1/chat/completions":
            return ChatCompletionRequest.model_validate(value)
        if url == "/v1/embeddings":
            return TypeAdapter(EmbeddingRequest).validate_python(value)
        if url == "/v1/score":
            return ScoreRequest.model_validate(value)
        # unknown url: let the union try each model in turn
        return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
                                 ScoreRequest]).validate_python(value)
+
+
class BatchResponseData(OpenAIBaseModel):
    """The `response` object of a batch output line."""
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
                         ScoreResponse]] = None
+
+
class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]
+
+
class TokenizeCompletionRequest(OpenAIBaseModel):
    """Tokenize a raw prompt string."""
    model: str
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
+
+
class TokenizeChatRequest(OpenAIBaseModel):
    """Tokenize a chat conversation after applying the chat template."""
    model: str
    messages: List[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Disallow `continue_final_message` with `add_generation_prompt`."""
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data
+
+
+TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
+
+
class TokenizeResponse(OpenAIBaseModel):
    """Response body for the tokenize endpoint."""
    count: int
    max_model_len: int
    tokens: List[int]
+
+
class DetokenizeRequest(OpenAIBaseModel):
    """Request body for the detokenize endpoint."""
    model: str
    tokens: List[int]
+
+
class DetokenizeResponse(OpenAIBaseModel):
    """Response body for the detokenize endpoint."""
    prompt: str
+
+
class LoadLoraAdapterRequest(BaseModel):
    """Request body for dynamically loading a LoRA adapter."""
    lora_name: str
    lora_path: str
+
+
class UnloadLoraAdapterRequest(BaseModel):
    """Request body for unloading a LoRA adapter, by name or numeric id."""
    lora_name: str
    lora_int_id: Optional[int] = Field(default=None)
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/run_batch.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/run_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..675d3cdcf97155073c07d8afbdd5e5878b23ba78
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/run_batch.py
@@ -0,0 +1,342 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+from http import HTTPStatus
+from io import StringIO
+from typing import Awaitable, Callable, List, Optional
+
+import aiohttp
+import torch
+from prometheus_client import start_http_server
+from tqdm import tqdm
+
+from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.logger import RequestLogger, logger
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (BatchRequestInput,
+ BatchRequestOutput,
+ BatchResponseData,
+ ChatCompletionResponse,
+ EmbeddingResponse, ErrorResponse,
+ ScoreResponse)
+# yapf: enable
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+ OpenAIServingModels)
+from vllm.entrypoints.openai.serving_score import OpenAIServingScores
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser, random_uuid
+from vllm.version import __version__ as VLLM_VERSION
+
+
+def parse_args():
+    """Build and parse the batch runner's CLI arguments.
+
+    Covers input/output locations, logging, Prometheus metrics flags,
+    plus every engine flag contributed by `AsyncEngineArgs.add_cli_args`.
+    """
+    parser = FlexibleArgumentParser(
+        description="vLLM OpenAI-Compatible batch runner.")
+    parser.add_argument(
+        "-i",
+        "--input-file",
+        required=True,
+        type=str,
+        help=
+        "The path or url to a single input file. Currently supports local file "
+        "paths, or the http protocol (http or https). If a URL is specified, "
+        "the file should be available via HTTP GET.")
+    parser.add_argument(
+        "-o",
+        "--output-file",
+        required=True,
+        type=str,
+        help="The path or url to a single output file. Currently supports "
+        "local file paths, or web (http or https) urls. If a URL is specified,"
+        " the file should be available via HTTP PUT.")
+    parser.add_argument("--response-role",
+                        type=nullable_str,
+                        default="assistant",
+                        help="The role name to return if "
+                        "`request.add_generation_prompt=True`.")
+
+    # Engine flags (model, parallelism, etc.) come from vLLM itself.
+    parser = AsyncEngineArgs.add_cli_args(parser)
+
+    parser.add_argument('--max-log-len',
+                        type=int,
+                        default=None,
+                        help='Max number of prompt characters or prompt '
+                        'ID numbers being printed in log.'
+                        '\n\nDefault: Unlimited')
+
+    parser.add_argument("--enable-metrics",
+                        action="store_true",
+                        help="Enable Prometheus metrics")
+    parser.add_argument(
+        "--url",
+        type=str,
+        default="0.0.0.0",
+        help="URL to the Prometheus metrics server "
+        "(only needed if enable-metrics is set).",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8000,
+        help="Port number for the Prometheus metrics server "
+        "(only needed if enable-metrics is set).",
+    )
+    parser.add_argument(
+        "--enable-prompt-tokens-details",
+        action='store_true',
+        default=False,
+        help="If set to True, enable prompt_tokens_details in usage.")
+
+    return parser.parse_args()
+
+
+# Explicitly use a pure-text bar format that ends each update with a
+# newline. This loses the in-place animation of the progress bar, but
+# avoids garbled output under ray or multiprocessing, which wrap each
+# line of output with some prefix.
+_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
+
+
+class BatchProgressTracker:
+    """Tracks submitted vs. completed batch requests and renders a tqdm
+    progress bar. In distributed runs only rank 0 draws the bar.
+    """
+
+    def __init__(self):
+        # Total submitted requests; the bar total is frozen at the value
+        # observed when pbar() is called.
+        self._total = 0
+        self._pbar: Optional[tqdm] = None
+
+    def submitted(self):
+        """Record that one request has been submitted."""
+        self._total += 1
+
+    def completed(self):
+        """Advance the progress bar by one finished request."""
+        if self._pbar:
+            self._pbar.update()
+
+    def pbar(self) -> tqdm:
+        """Create and return the tqdm progress bar for this batch."""
+        # Suppress the bar on non-zero ranks when torch.distributed is up.
+        enable_tqdm = not torch.distributed.is_initialized(
+        ) or torch.distributed.get_rank() == 0
+        self._pbar = tqdm(total=self._total,
+                          unit="req",
+                          desc="Running batch",
+                          mininterval=5,
+                          disable=not enable_tqdm,
+                          bar_format=_BAR_FORMAT)
+        return self._pbar
+
+
+async def read_file(path_or_url: str) -> str:
+ if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
+ async with aiohttp.ClientSession() as session, \
+ session.get(path_or_url) as resp:
+ return await resp.text()
+ else:
+ with open(path_or_url, encoding="utf-8") as f:
+ return f.read()
+
+
+async def write_file(path_or_url: str, data: str) -> None:
+ if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
+ async with aiohttp.ClientSession() as session, \
+ session.put(path_or_url, data=data.encode("utf-8")):
+ pass
+ else:
+ # We should make this async, but as long as this is always run as a
+ # standalone program, blocking the event loop won't effect performance
+ # in this particular case.
+ with open(path_or_url, "w", encoding="utf-8") as f:
+ f.write(data)
+
+
+def make_error_request_output(request: BatchRequestInput,
+ error_msg: str) -> BatchRequestOutput:
+ batch_output = BatchRequestOutput(
+ id=f"vllm-{random_uuid()}",
+ custom_id=request.custom_id,
+ response=BatchResponseData(
+ status_code=HTTPStatus.BAD_REQUEST,
+ request_id=f"vllm-batch-{random_uuid()}",
+ ),
+ error=error_msg,
+ )
+ return batch_output
+
+
+async def make_async_error_request_output(
+        request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
+    """Awaitable wrapper around make_error_request_output so error outputs
+    can be gathered alongside real in-flight request futures."""
+    return make_error_request_output(request, error_msg)
+
+
+async def run_request(serving_engine_func: Callable,
+                      request: BatchRequestInput,
+                      tracker: BatchProgressTracker) -> BatchRequestOutput:
+    """Execute one batch request via *serving_engine_func* and wrap the
+    result (success, error response, or unexpected stream) in a
+    BatchRequestOutput; always marks the request completed on *tracker*.
+    """
+    response = await serving_engine_func(request.body)
+
+    if isinstance(response,
+                  (ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
+        batch_output = BatchRequestOutput(
+            id=f"vllm-{random_uuid()}",
+            custom_id=request.custom_id,
+            response=BatchResponseData(
+                body=response, request_id=f"vllm-batch-{random_uuid()}"),
+            error=None,
+        )
+    elif isinstance(response, ErrorResponse):
+        # Propagate the handler's error payload and status code.
+        batch_output = BatchRequestOutput(
+            id=f"vllm-{random_uuid()}",
+            custom_id=request.custom_id,
+            response=BatchResponseData(
+                status_code=response.code,
+                request_id=f"vllm-batch-{random_uuid()}"),
+            error=response,
+        )
+    else:
+        # Anything else (e.g. a streaming generator) means the request
+        # asked for stream mode, which the batch runner does not support.
+        batch_output = make_error_request_output(
+            request, error_msg="Request must not be sent in stream mode")
+
+    tracker.completed()
+    return batch_output
+
+
+async def main(args):
+    """Run the batch job: read JSONL requests, dispatch each line to the
+    matching serving handler, gather all results, and write JSONL output.
+    """
+    if args.served_model_name is not None:
+        served_model_names = args.served_model_name
+    else:
+        served_model_names = [args.model]
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(
+        engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
+
+    model_config = await engine.get_model_config()
+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]
+
+    if args.disable_log_requests:
+        request_logger = None
+    else:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+
+    # Create the openai serving objects.
+    # Each handler is only constructed when the model supports the
+    # corresponding task; unsupported APIs produce per-request errors.
+    openai_serving_models = OpenAIServingModels(
+        engine_client=engine,
+        model_config=model_config,
+        base_model_paths=base_model_paths,
+        lora_modules=None,
+        prompt_adapters=None,
+    )
+    openai_serving_chat = OpenAIServingChat(
+        engine,
+        model_config,
+        openai_serving_models,
+        args.response_role,
+        request_logger=request_logger,
+        chat_template=None,
+        chat_template_content_format="auto",
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+    ) if model_config.runner_type == "generate" else None
+    openai_serving_embedding = OpenAIServingEmbedding(
+        engine,
+        model_config,
+        openai_serving_models,
+        request_logger=request_logger,
+        chat_template=None,
+        chat_template_content_format="auto",
+    ) if model_config.task == "embed" else None
+    openai_serving_scores = (OpenAIServingScores(
+        engine,
+        model_config,
+        openai_serving_models,
+        request_logger=request_logger,
+    ) if model_config.task == "score" else None)
+
+    tracker = BatchProgressTracker()
+    logger.info("Reading batch from %s...", args.input_file)
+
+    # Submit all requests in the file to the engine "concurrently".
+    response_futures: List[Awaitable[BatchRequestOutput]] = []
+    for request_json in (await read_file(args.input_file)).strip().split("\n"):
+        # Skip empty lines.
+        request_json = request_json.strip()
+        if not request_json:
+            continue
+
+        request = BatchRequestInput.model_validate_json(request_json)
+
+        # Determine the type of request and run it.
+        if request.url == "/v1/chat/completions":
+            handler_fn = (None if openai_serving_chat is None else
+                          openai_serving_chat.create_chat_completion)
+            if handler_fn is None:
+                response_futures.append(
+                    make_async_error_request_output(
+                        request,
+                        error_msg=
+                        "The model does not support Chat Completions API",
+                    ))
+                continue
+
+            response_futures.append(run_request(handler_fn, request, tracker))
+            tracker.submitted()
+        elif request.url == "/v1/embeddings":
+            handler_fn = (None if openai_serving_embedding is None else
+                          openai_serving_embedding.create_embedding)
+            if handler_fn is None:
+                response_futures.append(
+                    make_async_error_request_output(
+                        request,
+                        error_msg="The model does not support Embeddings API",
+                    ))
+                continue
+
+            response_futures.append(run_request(handler_fn, request, tracker))
+            tracker.submitted()
+        elif request.url == "/v1/score":
+            handler_fn = (None if openai_serving_scores is None else
+                          openai_serving_scores.create_score)
+            if handler_fn is None:
+                response_futures.append(
+                    make_async_error_request_output(
+                        request,
+                        error_msg="The model does not support Scores API",
+                    ))
+                continue
+
+            response_futures.append(run_request(handler_fn, request, tracker))
+            tracker.submitted()
+        else:
+            response_futures.append(
+                make_async_error_request_output(
+                    request,
+                    error_msg=
+                    "Only /v1/chat/completions, /v1/embeddings, and /v1/score "
+                    "are supported in the batch endpoint.",
+                ))
+
+    # Await everything under the progress bar, then serialize one JSON
+    # object per line, matching the OpenAI batch output format.
+    with tracker.pbar():
+        responses = await asyncio.gather(*response_futures)
+
+    output_buffer = StringIO()
+    for response in responses:
+        print(response.model_dump_json(), file=output_buffer)
+
+    output_buffer.seek(0)
+    await write_file(args.output_file, output_buffer.read().strip())
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    logger.info("vLLM batch processing API version %s", VLLM_VERSION)
+    logger.info("args: %s", args)
+
+    # Start the Prometheus metrics server. LLMEngine uses the Prometheus client
+    # to publish metrics at the /metrics endpoint.
+    if args.enable_metrics:
+        logger.info("Prometheus metrics enabled")
+        # NOTE: --url is passed as the bind address despite its name.
+        start_http_server(port=args.port, addr=args.url)
+    else:
+        logger.info("Prometheus metrics disabled")
+
+    asyncio.run(main(args))
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_chat.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_chat.py
new file mode 100644
index 0000000000000000000000000000000000000000..107220d548afc0adf506fb24c6b1a10ea2a3b232
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_chat.py
@@ -0,0 +1,955 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import json
+import time
+from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
+ Optional)
+from typing import Sequence as GenericSequence
+from typing import Union
+
+from fastapi import Request
+
+from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
+ ConversationMessage)
+from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionLogProb, ChatCompletionLogProbs,
+ ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
+ ChatCompletionRequest, ChatCompletionResponse,
+ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+ ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
+ RequestResponseMetadata, ToolCall, UsageInfo)
+from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
+ ReasoningParserManager)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.logger import init_logger
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sequence import Logprob
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
+
+logger = init_logger(__name__)
+
+
+class OpenAIServingChat(OpenAIServing):
+
+    def __init__(
+        self,
+        engine_client: EngineClient,
+        model_config: ModelConfig,
+        models: OpenAIServingModels,
+        response_role: str,
+        *,
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+        chat_template_content_format: ChatTemplateContentFormatOption,
+        return_tokens_as_token_ids: bool = False,
+        enable_reasoning: bool = False,
+        reasoning_parser: Optional[str] = None,
+        enable_auto_tools: bool = False,
+        tool_parser: Optional[str] = None,
+        enable_prompt_tokens_details: bool = False,
+    ) -> None:
+        """Set up the chat-completions handler.
+
+        Raises:
+            TypeError: if reasoning or auto tool parsing is enabled but
+                the named parser is not registered with its manager.
+        """
+        super().__init__(engine_client=engine_client,
+                         model_config=model_config,
+                         models=models,
+                         request_logger=request_logger,
+                         return_tokens_as_token_ids=return_tokens_as_token_ids)
+
+        self.response_role = response_role
+        self.chat_template = chat_template
+        self.chat_template_content_format: Final = chat_template_content_format
+
+        # set up tool use
+        self.enable_auto_tools: bool = enable_auto_tools
+        if self.enable_auto_tools:
+            logger.info(
+                "\"auto\" tool choice has been enabled please note that while"
+                " the parallel_tool_calls client option is preset for "
+                "compatibility reasons, it will be ignored.")
+
+        self.enable_reasoning: bool = enable_reasoning
+        # Factory producing a per-request ReasoningParser from a tokenizer,
+        # resolved by name through ReasoningParserManager.
+        self.reasoning_parser: Optional[Callable[[AnyTokenizer],
+                                                 ReasoningParser]] = None
+        if self.enable_reasoning:
+            try:
+                self.reasoning_parser = (
+                    ReasoningParserManager.get_reasoning_parser(
+                        reasoning_parser))
+            except Exception as e:
+                raise TypeError("Error: --enable-reasoning requires "
+                                f"reasoning_parser:'{reasoning_parser}' "
+                                "which has not been registered") from e
+        # Factory producing a per-request ToolParser from a tokenizer.
+        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
+        if self.enable_auto_tools:
+            try:
+                if (tool_parser == "pythonic" and
+                        model_config.model.startswith("meta-llama/Llama-3.2")):
+                    logger.warning(
+                        "Llama3.2 models may struggle to emit valid pythonic"
+                        " tool calls")
+                self.tool_parser = ToolParserManager.get_tool_parser(
+                    tool_parser)
+            except Exception as e:
+                raise TypeError("Error: --enable-auto-tool-choice requires "
+                                f"tool_parser:'{tool_parser}' which has not "
+                                "been registered") from e
+
+        self.enable_prompt_tokens_details = enable_prompt_tokens_details
+        diff_sampling_param = self.model_config.get_diff_sampling_param()
+        if diff_sampling_param:
+            logger.info("Overwriting default chat sampling param with: %s",
+                        diff_sampling_param)
+
+    async def create_chat_completion(
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Optional[Request] = None,
+    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
+               ErrorResponse]:
+        """
+        Chat Completion API similar to OpenAI's API.
+
+        See https://platform.openai.com/docs/api-reference/chat/create
+        for the API specification. This API mimics the OpenAI
+        Chat Completion API.
+
+        Returns an async SSE generator when streaming was requested, a
+        ChatCompletionResponse otherwise, or an ErrorResponse when
+        validation or preprocessing fails.
+        """
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            logger.error("Error with model %s", error_check_ret)
+            return error_check_ret
+
+        # If the engine is dead, raise the engine's DEAD_ERROR.
+        # This is required for the streaming case, where we return a
+        # success status before we actually start generating text :).
+        if self.engine_client.errored:
+            raise self.engine_client.dead_error
+
+        try:
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            model_name = self.models.model_name(lora_request)
+
+            tokenizer = await self.engine_client.get_tokenizer(lora_request)
+
+            tool_parser = self.tool_parser
+
+            # validation for OpenAI tools
+            # tool_choice = "required" is not supported
+            if request.tool_choice == "required":
+                return self.create_error_response(
+                    "tool_choice = \"required\" is not supported!")
+
+            # because of issues with pydantic we need to potentially
+            # re-serialize the tool_calls field of the request
+            # for more info: see comment in `maybe_serialize_tool_calls`
+            if isinstance(tokenizer, MistralTokenizer):
+                maybe_serialize_tool_calls(request)
+
+            if (request.tool_choice == "auto" and
+                    not (self.enable_auto_tools and tool_parser is not None)
+                    and not isinstance(tokenizer, MistralTokenizer)):
+                # for hf tokenizers, "auto" tools requires
+                # --enable-auto-tool-choice and --tool-call-parser
+                return self.create_error_response(
+                    "\"auto\" tool choice requires "
+                    "--enable-auto-tool-choice and --tool-call-parser to be set"
+                )
+
+            tool_dicts = None if request.tools is None else [
+                tool.model_dump() for tool in request.tools
+            ]
+
+            # Render the chat template into engine-ready prompts.
+            (
+                conversation,
+                request_prompts,
+                engine_prompts,
+            ) = await self._preprocess_chat(
+                request,
+                tokenizer,
+                request.messages,
+                chat_template=request.chat_template or self.chat_template,
+                chat_template_content_format=self.chat_template_content_format,
+                add_generation_prompt=request.add_generation_prompt,
+                continue_final_message=request.continue_final_message,
+                tool_dicts=tool_dicts,
+                documents=request.documents,
+                chat_template_kwargs=request.chat_template_kwargs,
+                tool_parser=tool_parser,
+                truncate_prompt_tokens=request.truncate_prompt_tokens,
+                add_special_tokens=request.add_special_tokens,
+            )
+        except ValueError as e:
+            logger.exception("Error in preprocessing prompt inputs")
+            return self.create_error_response(str(e))
+
+        request_id = "chatcmpl-" \
+            f"{self._base_request_id(raw_request, request.request_id)}"
+
+        request_metadata = RequestResponseMetadata(request_id=request_id)
+        if raw_request:
+            raw_request.state.request_metadata = request_metadata
+
+        # Schedule the request and get the result generator.
+        generators: List[AsyncGenerator[RequestOutput, None]] = []
+        try:
+            for i, engine_prompt in enumerate(engine_prompts):
+                sampling_params: Union[SamplingParams, BeamSearchParams]
+                # Default max_tokens fills the rest of the context window.
+                default_max_tokens = self.max_model_len - len(
+                    engine_prompt["prompt_token_ids"])
+                # Build default sampling params
+                default_sampling_params = (
+                    self.model_config.get_diff_sampling_param())
+                if request.use_beam_search:
+                    sampling_params = request.to_beam_search_params(
+                        default_max_tokens, default_sampling_params)
+                else:
+                    sampling_params = request.to_sampling_params(
+                        default_max_tokens,
+                        self.model_config.logits_processor_pattern,
+                        default_sampling_params)
+
+                self._log_inputs(request_id,
+                                 request_prompts[i],
+                                 params=sampling_params,
+                                 lora_request=lora_request,
+                                 prompt_adapter_request=prompt_adapter_request)
+
+                trace_headers = (None if raw_request is None else await
+                                 self._get_trace_headers(raw_request.headers))
+
+                if isinstance(sampling_params, BeamSearchParams):
+                    generator = self.engine_client.beam_search(
+                        prompt=engine_prompt,
+                        request_id=request_id,
+                        params=sampling_params,
+                    )
+                else:
+                    generator = self.engine_client.generate(
+                        engine_prompt,
+                        sampling_params,
+                        request_id,
+                        lora_request=lora_request,
+                        trace_headers=trace_headers,
+                        prompt_adapter_request=prompt_adapter_request,
+                        priority=request.priority,
+                    )
+
+                generators.append(generator)
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))
+
+        # Chat preprocessing yields exactly one engine prompt.
+        assert len(generators) == 1
+        result_generator, = generators
+
+        # Streaming response
+        if request.stream:
+            return self.chat_completion_stream_generator(
+                request, result_generator, request_id, model_name,
+                conversation, tokenizer, request_metadata)
+
+        try:
+            return await self.chat_completion_full_generator(
+                request, result_generator, request_id, model_name,
+                conversation, tokenizer, request_metadata)
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))
+
+ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
+ if request.add_generation_prompt:
+ return self.response_role
+ return request.messages[-1]["role"]
+
+ async def chat_completion_stream_generator(
+ self,
+ request: ChatCompletionRequest,
+ result_generator: AsyncIterator[RequestOutput],
+ request_id: str,
+ model_name: str,
+ conversation: List[ConversationMessage],
+ tokenizer: AnyTokenizer,
+ request_metadata: RequestResponseMetadata,
+ ) -> AsyncGenerator[str, None]:
+ created_time = int(time.time())
+ chunk_object_type: Final = "chat.completion.chunk"
+ first_iteration = True
+
+ # Send response for each token for each request.n (index)
+ num_choices = 1 if request.n is None else request.n
+ previous_num_tokens = [0] * num_choices
+ finish_reason_sent = [False] * num_choices
+ num_prompt_tokens = 0
+ num_cached_tokens = None
+
+ if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
+ tool_choice_function_name = request.tool_choice.function.name
+ else:
+ tool_choice_function_name = None
+
+ # Determine whether tools are in use with "auto" tool choice
+ tool_choice_auto = (
+ not tool_choice_function_name
+ and self._should_stream_with_auto_tool_parsing(request))
+
+ should_stream_with_reasoning_parsing = (
+ self._should_stream_with_reasoning_parsing(request))
+
+ all_previous_token_ids: Optional[List[List[int]]]
+
+ # Only one of these will be used, thus previous_texts and
+ # all_previous_token_ids will not be used twice in the same iteration.
+ if tool_choice_auto or should_stream_with_reasoning_parsing:
+ # These are only required in "auto" tool choice case
+ previous_texts = [""] * num_choices
+ all_previous_token_ids = [[]] * num_choices
+ else:
+ previous_texts, all_previous_token_ids = None, None
+
+ try:
+ # There is no need to check if the reasoning_parser is None
+ # because the should_stream_with_reasoning_parsing check
+ # already ensures that the reasoning_parser is not None.
+ # but the pre-commit hook requires it.
+ if should_stream_with_reasoning_parsing and \
+ self.reasoning_parser is not None:
+ reasoning_parser = self.reasoning_parser(tokenizer)
+ except RuntimeError as e:
+ logger.exception("Error in reasoning parser creation.")
+ data = self.create_streaming_error_response(str(e))
+ yield f"data: {data}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+
+ # Prepare the tool parser if it's needed
+ try:
+ if tool_choice_auto and self.tool_parser:
+ tool_parsers: List[Optional[ToolParser]] = [
+ self.tool_parser(tokenizer)
+ ] * num_choices
+ else:
+ tool_parsers = [None] * num_choices
+ except Exception as e:
+ logger.exception("Error in tool parser creation.")
+ data = self.create_streaming_error_response(str(e))
+ yield f"data: {data}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+
+ stream_options = request.stream_options
+ if stream_options:
+ include_usage = stream_options.include_usage
+ include_continuous_usage = include_usage and \
+ stream_options.continuous_usage_stats
+ else:
+ include_usage, include_continuous_usage = False, False
+
+ try:
+ async for res in result_generator:
+ if res.prompt_token_ids is not None:
+ num_prompt_tokens = len(res.prompt_token_ids)
+ if res.encoder_prompt_token_ids is not None:
+ num_prompt_tokens += len(res.encoder_prompt_token_ids)
+
+ # We need to do it here, because if there are exceptions in
+ # the result_generator, it needs to be sent as the FIRST
+ # response (by the try...catch).
+ if first_iteration:
+ num_cached_tokens = res.num_cached_tokens
+ # Send first response for each request.n (index) with
+ # the role
+ role = self.get_chat_request_role(request)
+
+ # NOTE num_choices defaults to 1 so this usually executes
+ # once per request
+ for i in range(num_choices):
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(
+ role=role,
+ content="",
+ ),
+ logprobs=None,
+ finish_reason=None)
+ chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[choice_data],
+ model=model_name)
+
+ # if continuous usage stats are requested, add it
+ if include_continuous_usage:
+ chunk.usage = UsageInfo(
+ prompt_tokens=num_prompt_tokens,
+ completion_tokens=0,
+ total_tokens=num_prompt_tokens)
+
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield f"data: {data}\n\n"
+
+ # Send response to echo the input portion of the
+ # last message
+ if request.echo:
+ last_msg_content: Union[str, List[Dict[str, str]]] = ""
+ if conversation and "content" in conversation[
+ -1] and conversation[-1].get("role") == role:
+ last_msg_content = conversation[-1]["content"] or ""
+
+ if last_msg_content:
+ for i in range(num_choices):
+ choice_data = (
+ ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(
+ content=last_msg_content),
+ logprobs=None,
+ finish_reason=None))
+ chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[choice_data],
+ model=model_name)
+ if include_continuous_usage:
+ chunk.usage = UsageInfo(
+ prompt_tokens=num_prompt_tokens,
+ completion_tokens=0,
+ total_tokens=num_prompt_tokens)
+
+ data = chunk.model_dump_json(
+ exclude_unset=True)
+ yield f"data: {data}\n\n"
+ first_iteration = False
+
+ for output in res.outputs:
+ i = output.index
+ tool_parser = tool_parsers[i]
+
+ if finish_reason_sent[i]:
+ continue
+
+ if request.logprobs and request.top_logprobs is not None:
+ assert output.logprobs is not None, (
+ "Did not output logprobs")
+ logprobs = self._create_chat_logprobs(
+ token_ids=output.token_ids,
+ top_logprobs=output.logprobs,
+ tokenizer=tokenizer,
+ num_output_top_logprobs=request.top_logprobs,
+ )
+ else:
+ logprobs = None
+
+ delta_text = output.text
+
+ if not delta_text and not output.token_ids and \
+ not previous_num_tokens[i]:
+ # Chunked prefill case, don't return empty chunks
+ continue
+
+ delta_message: Optional[DeltaMessage]
+
+ # handle streaming deltas for tools with named tool_choice
+ if tool_choice_function_name:
+ delta_message = DeltaMessage(tool_calls=[
+ DeltaToolCall(function=DeltaFunctionCall(
+ name=tool_choice_function_name,
+ arguments=delta_text),
+ index=i)
+ ])
+
+ # handle streaming deltas for tools with "auto" tool choice
+ elif tool_choice_auto:
+ assert previous_texts is not None
+ assert all_previous_token_ids is not None
+ assert tool_parser is not None
+ #TODO optimize manipulation of these lists
+ previous_text = previous_texts[i]
+ previous_token_ids = all_previous_token_ids[i]
+ current_text = previous_text + delta_text
+ current_token_ids = previous_token_ids + list(
+ output.token_ids)
+
+ delta_message = (
+ tool_parser.extract_tool_calls_streaming(
+ previous_text=previous_text,
+ current_text=current_text,
+ delta_text=delta_text,
+ previous_token_ids=previous_token_ids,
+ current_token_ids=current_token_ids,
+ delta_token_ids=output.token_ids,
+ request=request))
+
+ # update the previous values for the next iteration
+ previous_texts[i] = current_text
+ all_previous_token_ids[i] = current_token_ids
+ # reasoning_content cannot be enabled with tool_choice.
+ # If it is, the tool_choice will be used instead.
+ elif self.enable_reasoning:
+ # handle reasoning_content delta
+ assert reasoning_parser is not None
+ assert previous_texts is not None
+ assert all_previous_token_ids is not None
+ previous_text = previous_texts[i]
+ previous_token_ids = all_previous_token_ids[i]
+ current_text = previous_text + delta_text
+ current_token_ids = previous_token_ids + list(
+ output.token_ids)
+
+ delta_message = (reasoning_parser.
+ extract_reasoning_content_streaming(
+ previous_text,
+ current_text,
+ delta_text,
+ previous_token_ids,
+ current_token_ids,
+ output.token_ids,
+ ))
+
+ # update the previous values for the next iteration
+ previous_texts[i] = current_text
+ all_previous_token_ids[i] = current_token_ids
+
+ # handle streaming just a content delta
+ else:
+ delta_message = DeltaMessage(content=delta_text)
+
+ # set the previous values for the next iteration
+ previous_num_tokens[i] += len(output.token_ids)
+
+ # if the message delta is None (e.g. because it was a
+ # "control token" for tool calls or the parser otherwise
+ # wasn't ready to send a token, then
+ # get the next token without streaming a chunk
+ if delta_message is None:
+ continue
+
+ if output.finish_reason is None:
+ # Send token-by-token response for each request.n
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=delta_message,
+ logprobs=logprobs,
+ finish_reason=None)
+
+ # if the model is finished generating
+ else:
+ # check to make sure we haven't "forgotten" to stream
+ # any tokens that were generated but previously
+ # matched by partial json parsing
+ # only happens if we are NOT using guided decoding
+ auto_tools_called = False
+ if tool_parser:
+ auto_tools_called = len(
+ tool_parser.prev_tool_call_arr) > 0
+ index = len(tool_parser.prev_tool_call_arr
+ ) - 1 if auto_tools_called else 0
+ else:
+ index = 0
+
+ if self._should_check_for_unstreamed_tool_arg_tokens(
+ delta_message, output) and tool_parser:
+ latest_delta_len = 0
+ if ((isinstance(
+ delta_message.tool_calls[0].function,
+ DeltaFunctionCall)) and isinstance(
+ delta_message.tool_calls[0].function.
+ arguments, str)):
+ latest_delta_len = len(
+ delta_message.tool_calls[0].function.
+ arguments)
+
+ # get the expected call based on partial JSON
+ # parsing which "autocompletes" the JSON
+ expected_call = json.dumps(
+ tool_parser.prev_tool_call_arr[index].get(
+ "arguments", {}),
+ ensure_ascii=False)
+
+ # get what we've streamed so far for arguments
+ # for the current tool
+ actual_call = tool_parser.streamed_args_for_tool[
+ index]
+ if (latest_delta_len > 0):
+ actual_call = actual_call[:-latest_delta_len]
+
+ # check to see if there's anything left to stream
+ remaining_call = expected_call.replace(
+ actual_call, "", 1)
+ # set that as a delta message
+ delta_message = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=index,
+ function=DeltaFunctionCall(
+ arguments=remaining_call).
+ model_dump(exclude_none=True))
+ ])
+
+ # Send the finish response for each request.n only once
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=delta_message,
+ logprobs=logprobs,
+ finish_reason=output.finish_reason
+ if not auto_tools_called else "tool_calls",
+ stop_reason=output.stop_reason)
+
+ finish_reason_sent[i] = True
+
+ chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[choice_data],
+ model=model_name)
+
+ # handle usage stats if requested & if continuous
+ if include_continuous_usage:
+ completion_tokens = previous_num_tokens[i]
+ chunk.usage = UsageInfo(
+ prompt_tokens=num_prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=num_prompt_tokens + completion_tokens,
+ )
+
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield f"data: {data}\n\n"
+
+ # once the final token is handled, if stream_options.include_usage
+ # is sent, send the usage
+ if include_usage:
+ completion_tokens = sum(previous_num_tokens)
+ final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=num_prompt_tokens +
+ completion_tokens)
+ if self.enable_prompt_tokens_details and num_cached_tokens:
+ final_usage.prompt_tokens_details = PromptTokenUsageInfo(
+ cached_tokens=num_cached_tokens)
+
+ final_usage_chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[],
+ model=model_name,
+ usage=final_usage)
+ final_usage_data = (final_usage_chunk.model_dump_json(
+ exclude_unset=True, exclude_none=True))
+ yield f"data: {final_usage_data}\n\n"
+
+ # report to FastAPI middleware aggregate usage across all choices
+ num_completion_tokens = sum(previous_num_tokens)
+ request_metadata.final_usage_info = UsageInfo(
+ prompt_tokens=num_prompt_tokens,
+ completion_tokens=num_completion_tokens,
+ total_tokens=num_prompt_tokens + num_completion_tokens)
+
+ except Exception as e:
+ # TODO: Use a vllm-specific Validation Error
+ logger.exception("Error in chat completion stream generator.")
+ data = self.create_streaming_error_response(str(e))
+ yield f"data: {data}\n\n"
+ # Send the final done message after all response.n are finished
+ yield "data: [DONE]\n\n"
+
    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Build a complete (non-streaming) chat completion response.

        Drains ``result_generator`` to the final ``RequestOutput``, then for
        each candidate output selects how to render the message: reasoning
        content, a named tool call, auto-parsed tool calls, or plain text.
        Also handles ``echo`` and fills token usage (including cached prompt
        tokens when enabled).

        Returns an ``ErrorResponse`` on client disconnect or validation
        errors raised by the engine.
        """

        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            # Each yielded RequestOutput carries the cumulative state of the
            # request, so only the last one is needed.
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            should_stream_with_reasoning_parsing = (
                self._should_stream_with_reasoning_parsing(request))

            # In the OpenAI API the finish_reason is "tools_called"
            # if the tool choice is auto and the model produced a tool
            # call. The same is not true for named function calls
            auto_tools_called = False

            # Branch 1: reasoning parsing takes precedence over the tool
            # branches below.
            if should_stream_with_reasoning_parsing and \
                self.reasoning_parser is not None:
                try:
                    reasoning_parser = self.reasoning_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in reasoning parser creation.")
                    return self.create_error_response(str(e))

                reasoning_content, content = (
                    reasoning_parser.extract_reasoning_content(
                        output.text, request=request))

                if reasoning_content:
                    message = ChatMessage(role=role,
                                          content=content,
                                          reasoning_content=reasoning_content)
                else:
                    # No reasoning detected: fall back to the raw text.
                    message = ChatMessage(role=role, content=output.text)

            # if auto tools are not enabled, and a named tool choice using
            # outlines is not being used
            elif (not self.enable_auto_tools
                  or not self.tool_parser) and not isinstance(
                      request.tool_choice, ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(role=role, content=output.text)

            # if the request uses tools and specified a tool choice
            elif request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:

                # Named tool choice: the entire output text is treated as the
                # arguments of the single requested function.
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":

                message = ChatMessage(role=role, content=output.text)

            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:

                try:
                    tool_parser = self.tool_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in tool parser creation.")
                    return self.create_error_response(str(e))

                tool_call_info = tool_parser.extract_tool_calls(
                    output.text, request=request)
                # In the OpenAI API the finish_reason is "tools_called"
                # if the tool choice is auto and the model produced a tool
                # call. The same is not true for named function calls
                auto_tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)

                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later.
                    message = ChatMessage(role=role, content=output.text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                # "tool_calls" overrides the engine's finish reason only when
                # the auto tool parser actually extracted calls.
                finish_reason="tool_calls" if auto_tools_called else
                output.finish_reason if output.finish_reason else "stop",
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            last_msg_content: Union[str, List[Dict[str, str]]] = ""
            if conversation and "content" in conversation[-1] and conversation[
                    -1].get("role") == role:
                last_msg_content = conversation[-1]["content"] or ""
            if isinstance(last_msg_content, list):
                # assumes list items are content parts carrying a 'text' key
                # — TODO confirm against the chat_utils content format.
                last_msg_content = "\n".join(msg['text']
                                             for msg in last_msg_content)

            # Prepend the echoed conversation content to every choice.
            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            # Encoder-decoder models: encoder prompt tokens count too.
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                          completion_tokens=num_generated_tokens,
                          total_tokens=num_prompt_tokens +
                          num_generated_tokens)
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens)

        # Expose aggregate usage to FastAPI middleware.
        request_metadata.final_usage_info = usage

        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response
+
+ def _get_top_logprobs(
+ self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
+ tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
+ return [
+ ChatCompletionLogProb(token=(token := self._get_decoded_token(
+ p[1],
+ p[0],
+ tokenizer,
+ return_as_token_id=self.return_tokens_as_token_ids)),
+ logprob=max(p[1].logprob, -9999.0),
+ bytes=list(
+ token.encode("utf-8", errors="replace")))
+ for i, p in enumerate(logprobs.items())
+ if top_logprobs and i < top_logprobs
+ ]
+
    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        ``top_logprobs[i]`` holds the logprob candidates for ``token_ids[i]``,
        or ``None`` when no logprob info was recorded for that position.
        """
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprob info: emit only the decoded token text.
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"

                # NOTE(review): when return_tokens_as_token_ids is set, the
                # bytes field encodes the "token_id:N" string rather than the
                # decoded token — confirm this is intended.
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
            else:
                # The chosen token's own logprob entry.
                step_token = step_top_logprobs[token_id]
                step_decoded = step_token.decoded_token

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_token,
                            token_id,
                            tokenizer,
                            self.return_tokens_as_token_ids,
                        ),
                        # Clamp -inf to a JSON-serializable float.
                        logprob=max(step_token.logprob, -9999.0),
                        bytes=None if step_decoded is None else list(
                            step_decoded.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs,
                            num_output_top_logprobs,
                            tokenizer,
                        ),
                    ))

        return ChatCompletionLogProbs(content=logprobs_content)
+
+ def _should_stream_with_auto_tool_parsing(self,
+ request: ChatCompletionRequest):
+ """
+ Utility function to check if streamed tokens should go through the tool
+ call parser that was configured.
+
+ We only want to do this IF user-provided tools are set, a tool parser
+ is configured, "auto" tool choice is enabled, and the request's tool
+ choice field indicates that "auto" tool choice should be used.
+ """
+ return (request.tools and self.tool_parser and self.enable_auto_tools
+ and request.tool_choice in ['auto', None])
+
+ def _should_stream_with_reasoning_parsing(self,
+ request: ChatCompletionRequest):
+ """
+ Utility function to check if streamed tokens should go through the
+ reasoning parser that was configured.
+
+ We only want to do this IF reasoning is enabled and a reasoning
+ parser is configured.
+ """
+ return self.enable_reasoning and self.reasoning_parser is not None
+
+ def _should_check_for_unstreamed_tool_arg_tokens(
+ self,
+ delta_message: Optional[DeltaMessage],
+ output: CompletionOutput,
+ ) -> bool:
+ """
+ Check to see if we should check for unstreamed tool arguments tokens.
+ This is only applicable when auto tool parsing is enabled, the delta
+ is a tool call with arguments.
+ """
+
+ # yapf: disable
+ return bool(
+ # if there is a delta message that includes tool calls which
+ # include a function that has arguments
+ output.finish_reason is not None
+ and self.enable_auto_tools and self.tool_parser and delta_message
+ and delta_message.tool_calls and delta_message.tool_calls[0]
+ and delta_message.tool_calls[0].function
+ and delta_message.tool_calls[0].function.arguments is not None
+ )
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_completion.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_completion.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7ad263e7fbe5049dcab508139b67d0bbc2d5aa9
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_completion.py
@@ -0,0 +1,547 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import time
+from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
+from typing import Sequence as GenericSequence
+from typing import Tuple, Union, cast
+
+from fastapi import Request
+
+from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.logger import RequestLogger
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
+ CompletionRequest,
+ CompletionResponse,
+ CompletionResponseChoice,
+ CompletionResponseStreamChoice,
+ CompletionStreamResponse,
+ ErrorResponse,
+ RequestResponseMetadata,
+ UsageInfo)
+# yapf: enable
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.logger import init_logger
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sequence import Logprob
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import merge_async_iterators
+
+logger = init_logger(__name__)
+
+
+class OpenAIServingCompletion(OpenAIServing):
+
+ def __init__(
+ self,
+ engine_client: EngineClient,
+ model_config: ModelConfig,
+ models: OpenAIServingModels,
+ *,
+ request_logger: Optional[RequestLogger],
+ return_tokens_as_token_ids: bool = False,
+ ):
+ super().__init__(engine_client=engine_client,
+ model_config=model_config,
+ models=models,
+ request_logger=request_logger,
+ return_tokens_as_token_ids=return_tokens_as_token_ids)
+ diff_sampling_param = self.model_config.get_diff_sampling_param()
+ if diff_sampling_param:
+ logger.info(
+ "Overwriting default completion sampling param with: %s",
+ diff_sampling_param)
+
    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Returns an async generator of SSE strings when streaming, otherwise
        a complete ``CompletionResponse``, or an ``ErrorResponse`` on
        failure.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        request_id = f"cmpl-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        # Usage is reported to FastAPI middleware through request state.
        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            request_prompts, engine_prompts = await self._preprocess_completion(
                request,
                tokenizer,
                request.prompt,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                # Cap generation to the model context remaining after the
                # prompt.
                default_max_tokens = self.max_model_len - len(
                    engine_prompt["prompt_token_ids"])
                # Build default sampling params
                default_sampling_params = (
                    self.model_config.get_diff_sampling_param())
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens, default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
                        default_max_tokens,
                        self.model_config.logits_processor_pattern,
                        default_sampling_params)

                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                if isinstance(sampling_params, BeamSearchParams):
                    # NOTE(review): this passes the batch-level request_id,
                    # while generate() below uses the per-prompt
                    # request_id_item — with multiple prompts the beam-search
                    # requests would share one id. Confirm this is intended.
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        prompt_adapter_request=prompt_adapter_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when use
        # beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata)

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts
        try:
            # Results may arrive interleaved; slot each by its prompt index.
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    final_res.prompt = request_prompts[i]["prompt"]

            final_res_batch_checked = cast(List[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
+
    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Yield Server-Sent-Event chunks for a streaming completion.

        ``result_generator`` yields ``(prompt_index, RequestOutput)`` pairs
        for ``num_prompts`` prompts with up to ``request.n`` choices each.
        Per-choice bookkeeping lives in flat lists indexed by
        ``choice_index + prompt_index * n``. The stream always ends with a
        ``[DONE]`` event, even after an error chunk.
        """
        num_choices = 1 if request.n is None else request.n
        # Flat per-(prompt, choice) state; see the indexing note above.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts

        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage
            include_continuous_usage = include_usage and \
                stream_options.continuous_usage_stats
        else:
            include_usage, include_continuous_usage = False, False

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs
                prompt_text = res.prompt

                # Prompt details are excluded from later streamed outputs
                if res.prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: Optional[GenericSequence[Optional[Dict[
                    int, Logprob]]]]

                for output in res.outputs:
                    # Flat index for this (prompt, choice) pair.
                    i = output.index + prompt_idx * num_choices

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            assert prompt_logprobs is not None
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids, *output.token_ids
                            ]
                            out_logprobs = [
                                *prompt_logprobs,
                                *(output.logprobs or []),
                            ]
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                    if not delta_text and not delta_token_ids \
                        and not previous_num_tokens[i]:
                        # Chunked prefill case, don't return empty chunks
                        continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            # Continue text offsets from previous chunks.
                            initial_text_offset=previous_text_lens[i],
                        )
                    else:
                        logprobs = None

                    previous_text_lens[i] += len(output.text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens)

            if include_usage:
                # Final usage-only chunk, per stream_options.include_usage.
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            # NOTE(review): the error is only surfaced to the client here
            # and not logged server-side — confirm that is intended.
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"
+
    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Assemble the non-streaming ``CompletionResponse``.

        Flattens every output of every prompt in ``final_res_batch`` into
        choices, prepending the prompt text/tokens/logprobs when ``echo`` is
        set, and records aggregate token usage on ``request_metadata``.
        """
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = final_res.prompt_logprobs
            if prompt_logprobs:
                for logprob_dict in prompt_logprobs:
                    if logprob_dict:
                        for logprob_values in logprob_dict.values():
                            # Clamp -inf in place to a JSON-serializable
                            # float, matching the sentinel used elsewhere.
                            if logprob_values.logprob == float('-inf'):
                                logprob_values.logprob = -9999.0
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: Optional[GenericSequence[Optional[Dict[int,
                                                                 Logprob]]]]

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo:
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        # Echo-only: no generation was requested.
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        # Echo plus generation: concatenate prompt and
                        # output tokens/logprobs/text.
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    # Choices are numbered globally across all prompts.
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        # Expose aggregate usage to FastAPI middleware.
        request_metadata.final_usage_info = usage

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )
+
    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: AnyTokenizer,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        ``top_logprobs[i]`` holds the candidates for ``token_ids[i]`` (or
        ``None`` when no logprob info was recorded for that position).
        ``initial_text_offset`` shifts the reported text offsets; the
        streaming path uses it to continue offsets across chunks.
        """
        out_text_offset: List[int] = []
        out_token_logprobs: List[Optional[float]] = []
        out_tokens: List[str] = []
        out_top_logprobs: List[Optional[Dict[str, float]]] = []

        last_token_len = 0

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprob info: emit token text with null logprob entries.
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                # The chosen token's own entry.
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=self.return_tokens_as_token_ids,
                )
                # Clamp -inf to a JSON-serializable float.
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                # NOTE: the inclusive bound (>=) below is what yields n + 1
                # entries. The comprehension-local `i` shadows the loop
                # index but cannot leak — comprehensions have their own
                # scope in Python 3.
                out_top_logprobs.append({
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    self._get_decoded_token(top_lp[1],
                                            top_lp[0],
                                            tokenizer,
                                            return_as_token_id=self.return_tokens_as_token_ids):
                    max(top_lp[1].logprob, -9999.0)
                    for i, top_lp in enumerate(step_top_logprobs.items())
                    if num_output_top_logprobs >= i
                })

            # Text offsets are cumulative over decoded token lengths.
            if len(out_text_offset) == 0:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_embedding.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..45f8ad90ddcb3d67d56e49ccfa39bcc4ec2d135d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_embedding.py
@@ -0,0 +1,242 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import base64
+import time
+from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast
+
+import numpy as np
+from fastapi import Request
+from typing_extensions import assert_never
+
+from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
+from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
+ EmbeddingRequest,
+ EmbeddingResponse,
+ EmbeddingResponseData,
+ ErrorResponse, UsageInfo)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.logger import init_logger
+from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
+ PoolingRequestOutput)
+from vllm.utils import merge_async_iterators
+
+logger = init_logger(__name__)
+
+
def _get_embedding(
    output: EmbeddingOutput,
    encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
    """Serialize an embedding as a float list or a base64 string.

    Base64 encoding always goes through float32 to match the OpenAI python
    client's behavior. ``assert_never`` guards against an unhandled format.
    """
    if encoding_format == "base64":
        # Force to use float32 for base64 encoding
        # to match the OpenAI python client behavior
        raw_bytes = np.array(output.embedding, dtype="float32").tobytes()
        return base64.b64encode(raw_bytes).decode("utf-8")
    if encoding_format == "float":
        return output.embedding

    assert_never(encoding_format)
+
+
+class OpenAIServingEmbedding(OpenAIServing):
+
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
    ) -> None:
        """Initialize the embedding handler.

        ``chat_template`` and ``chat_template_content_format`` are only
        consulted for chat-style embedding requests
        (``EmbeddingChatRequest``) in ``create_embedding``.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

        self.chat_template = chat_template
        # Final: fixed for the lifetime of this handler.
        self.chat_template_content_format: Final = chat_template_content_format
+
    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Handles both completion-style (``input``) and chat-style
        (``messages``) requests; embeddings are never streamed. Returns an
        ``EmbeddingResponse`` or an ``ErrorResponse``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = request.encoding_format
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"embd-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        truncate_prompt_tokens = None

        if request.truncate_prompt_tokens is not None:
            # Truncation must fit within the model context window.
            if request.truncate_prompt_tokens <= self.max_model_len:
                truncate_prompt_tokens = request.truncate_prompt_tokens
            else:
                return self.create_error_response(
                    "truncate_prompt_tokens value is "
                    "greater than max_model_len."
                    " Please, select a smaller truncation size.")

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for embedding models")

            if isinstance(request, EmbeddingChatRequest):
                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    # In embedding requests, we are not generating tokens,
                    # so there is no need to append extra tokens to the input
                    add_generation_prompt=False,
                    continue_final_message=False,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                (request_prompts,
                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.input,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
        try:
            pooling_params = request.to_pooling_params()

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts
        try:
            # Results may arrive out of order; slot them by prompt index.
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_embedding_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
                encoding_format,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response
+
+ def request_output_to_embedding_response(
+ self,
+ final_res_batch: List[PoolingRequestOutput],
+ request_id: str,
+ created_time: int,
+ model_name: str,
+ encoding_format: Literal["float", "base64"],
+ ) -> EmbeddingResponse:
+ items: List[EmbeddingResponseData] = []
+ num_prompt_tokens = 0
+
+ for idx, final_res in enumerate(final_res_batch):
+ embedding_res = EmbeddingRequestOutput.from_base(final_res)
+
+ item = EmbeddingResponseData(
+ index=idx,
+ embedding=_get_embedding(embedding_res.outputs,
+ encoding_format),
+ )
+ prompt_token_ids = final_res.prompt_token_ids
+
+ items.append(item)
+ num_prompt_tokens += len(prompt_token_ids)
+
+ usage = UsageInfo(
+ prompt_tokens=num_prompt_tokens,
+ total_tokens=num_prompt_tokens,
+ )
+
+ return EmbeddingResponse(
+ id=request_id,
+ created=created_time,
+ model=model_name,
+ data=items,
+ usage=usage,
+ )
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_engine.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d39fdcb748330545539b1e71c8cae88d71483c6
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_engine.py
@@ -0,0 +1,524 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+from concurrent.futures.thread import ThreadPoolExecutor
+from http import HTTPStatus
+from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
+ Optional, Sequence, Tuple, TypedDict, Union)
+
+from fastapi import Request
+from pydantic import Field
+from starlette.datastructures import Headers
+from typing_extensions import Annotated
+
+from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+ ChatTemplateContentFormatOption,
+ ConversationMessage,
+ apply_hf_chat_template,
+ apply_mistral_chat_template,
+ parse_chat_messages_futures,
+ resolve_chat_template_content_format)
+from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ CompletionRequest,
+ DetokenizeRequest,
+ EmbeddingChatRequest,
+ EmbeddingCompletionRequest,
+ ErrorResponse, RerankRequest,
+ ScoreRequest,
+ TokenizeChatRequest,
+ TokenizeCompletionRequest)
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.tool_parsers import ToolParser
+# yapf: enable
+from vllm.inputs import TokensPrompt
+from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.pooling_params import PoolingParams
+from vllm.prompt_adapter.request import PromptAdapterRequest
+from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sequence import Logprob
+from vllm.tracing import (contains_trace_headers, extract_trace_headers,
+ log_tracing_disabled_warning)
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.utils import is_list_of, make_async, random_uuid
+
+logger = init_logger(__name__)
+
# Requests whose prompt arrives in "completion" form: plain text or token ids.
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, ScoreRequest,
                              TokenizeCompletionRequest]

# Requests whose prompt arrives as a list of chat messages.
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]

AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest]


class TextTokensPrompt(TypedDict):
    """A prompt's text together with its token ids."""
    prompt: str
    prompt_token_ids: List[int]


# A prompt in any of the forms the OpenAIServing helpers handle.
RequestPrompt = Union[List[int], str, TextTokensPrompt]
+
+
class OpenAIServing:
    """Common machinery shared by the OpenAI-compatible request handlers.

    Holds the engine client, model configuration and the model/adapter
    registry, and provides prompt tokenization, context-length validation
    and request-logging helpers used by the concrete serving classes.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        """Store shared serving state and set up async tokenization helpers."""
        super().__init__()

        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # Tokenization is offloaded to a single worker thread so it does not
        # block the event loop.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input, executor=self._tokenizer_executor)
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=self._tokenizer_executor)

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an ErrorResponse with the given message, type and status."""
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Like :meth:`create_error_response`, but serialized to a JSON
        string (for streaming replies)."""
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """Return None if ``request.model`` names a served base model or a
        loaded adapter; otherwise a NotFoundError ErrorResponse."""
        if self._is_model_supported(request.model):
            return None
        if request.model in [
                lora.lora_name for lora in self.models.lora_requests
        ]:
            return None
        if request.model in [
                prompt_adapter.prompt_adapter_name
                for prompt_adapter in self.models.prompt_adapter_requests
        ]:
            return None
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
        """Resolve ``request.model`` to a (lora_request,
        prompt_adapter_request) pair; (None, None) means the base model."""
        if self._is_model_supported(request.model):
            return None, None
        for lora in self.models.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.models.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt (optionally truncating) and validate its
        length against the model's context window."""
        # Honor the encoder config's lower-casing option if present.
        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids

        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        """Validate an already-tokenized prompt, keeping only the trailing
        ``truncate_prompt_tokens`` ids when truncation is requested."""
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            # Truncate from the left: keep the trailing tokens.
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the tokenized prompt against ``max_model_len`` according to
        the request type and wrap it in a TextTokensPrompt.

        Raises:
            ValueError: if the prompt (plus any requested completion
                tokens) exceeds the model's context length.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest and ScoreRequest don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):

            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> List[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to the OpenAI API, each input can be a string or an array
        of tokens. Note that each request can pass one or more inputs.
        """
        # Although our type checking is based on mypy,
        # VSCode Pyright extension should still work properly
        # "is True" is required for Pyright to perform type narrowing
        # See: https://github.com/microsoft/pyright/issues/7672
        return [
            self._normalize_prompt_text_to_input(
                request,
                tokenizer,
                prompt=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens)
            if prompt_input["is_tokens"] is False else
            self._normalize_prompt_tokens_to_input(
                request,
                tokenizer,
                prompt_ids=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens)
            for prompt_input in parse_and_batch_prompt(input_or_inputs)
        ]

    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]:
        """Tokenize completion-style input(s) off the event loop and build
        the corresponding engine TokensPrompt objects."""
        request_prompts = await self._tokenize_prompt_input_or_inputs_async(
            request,
            tokenizer,
            input_or_inputs,
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )

        engine_prompts = [
            TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
            for request_prompt in request_prompts
        ]

        return request_prompts, engine_prompts

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest,
        tokenizer: AnyTokenizer,
        messages: List[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[List[Dict[str, Any]]] = None,
        documents: Optional[List[Dict[str, str]]] = None,
        chat_template_kwargs: Optional[Dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
               List[TokensPrompt]]:
        """Render chat messages through the chat template and tokenize the
        result into a single engine prompt.

        Returns a (conversation, request_prompts, engine_prompts) triple;
        the two prompt lists always contain exactly one element.
        """
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            self.model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: Dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, List[int]]
        # MistralTokenizer renders directly from the raw messages and may
        # return token ids instead of a string.
        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
        if is_mistral_tokenizer:
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the LLM
        should_parse_tools = tool_parser is not None and (hasattr(
            request, "tool_choice") and request.tool_choice != "none")

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)

        if isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)

        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        return conversation, [request_prompt], [engine_prompt]

    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Forward the request's inputs to the RequestLogger, if configured."""
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Extract tracing headers when tracing is enabled; otherwise warn
        if trace headers were sent while tracing is disabled."""
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    @staticmethod
    def _base_request_id(raw_request: Optional[Request],
                         default: Optional[str] = None) -> Optional[str]:
        """Pulls the request id to use from a header, if provided"""
        default = default or random_uuid()
        if raw_request is None:
            return default

        return raw_request.headers.get("X-Request-Id", default)

    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        """Return the display text for a logprob entry, preferring its
        pre-decoded token; optionally format as ``token_id:<id>``."""
        if return_as_token_id:
            return f"token_id:{token_id}"

        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)

    def _is_model_supported(self, model_name) -> bool:
        """True if ``model_name`` is one of the served base models."""
        return self.models.is_base_model(model_name)
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_models.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..f917a48519016c7300cac638796649bc79fc7a5d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_models.py
@@ -0,0 +1,244 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import pathlib
+from dataclasses import dataclass
+from http import HTTPStatus
+from typing import List, Optional, Union
+
+from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.openai.protocol import (ErrorResponse,
+ LoadLoraAdapterRequest,
+ ModelCard, ModelList,
+ ModelPermission,
+ UnloadLoraAdapterRequest)
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.prompt_adapter.request import PromptAdapterRequest
+from vllm.utils import AtomicCounter
+
+logger = init_logger(__name__)
+
+
@dataclass
class BaseModelPath:
    # Name the model is served under (matched against `request.model`).
    name: str
    # Path reported as the model card's `root` in /v1/models.
    model_path: str
+
+
@dataclass
class PromptAdapterPath:
    # Name the prompt adapter is served under.
    name: str
    # Local directory containing the adapter's `adapter_config.json`.
    local_path: str
+
+
@dataclass
class LoRAModulePath:
    # Name the LoRA adapter is served under.
    name: str
    # Path the adapter is loaded from.
    path: str
    # Optional base model name; used as the `parent` of the adapter's model
    # card when set.
    base_model_name: Optional[str] = None
+
+
class OpenAIServingModels:
    """Shared instance to hold data about the loaded base model(s) and adapters.

    Handles the routes:
    - /v1/models
    - /v1/load_lora_adapter
    - /v1/unload_lora_adapter
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        lora_modules: Optional[List[LoRAModulePath]] = None,
        prompt_adapters: Optional[List[PromptAdapterPath]] = None,
    ):
        """Record base models and statically-configured adapters.

        Prompt adapter configs are read eagerly so that a bad path fails at
        construction time rather than on the first request using it.
        """
        super().__init__()

        self.base_model_paths = base_model_paths
        self.max_model_len = model_config.max_model_len
        self.engine_client = engine_client

        self.static_lora_modules = lora_modules
        self.lora_requests: List[LoRARequest] = []
        self.lora_id_counter = AtomicCounter(0)

        # Annotated for consistency with `lora_requests` above.
        self.prompt_adapter_requests: List[PromptAdapterRequest] = []
        if prompt_adapters is not None:
            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
                with pathlib.Path(prompt_adapter.local_path,
                                  "adapter_config.json").open() as f:
                    adapter_config = json.load(f)
                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
                    self.prompt_adapter_requests.append(
                        PromptAdapterRequest(
                            prompt_adapter_name=prompt_adapter.name,
                            prompt_adapter_id=i,
                            prompt_adapter_local_path=prompt_adapter.local_path,
                            prompt_adapter_num_virtual_tokens=num_virtual_tokens
                        ))

    async def init_static_loras(self):
        """Loads all static LoRA modules.
        Raises if any fail to load"""
        if self.static_lora_modules is None:
            return
        for lora in self.static_lora_modules:
            load_request = LoadLoraAdapterRequest(lora_path=lora.path,
                                                  lora_name=lora.name)
            load_result = await self.load_lora_adapter(
                request=load_request, base_model_name=lora.base_model_name)
            if isinstance(load_result, ErrorResponse):
                raise ValueError(load_result.message)

    def is_base_model(self, model_name) -> bool:
        """Return True if ``model_name`` names one of the base models."""
        return any(model.name == model_name for model in self.base_model_paths)

    def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
        """Returns the appropriate model name depending on the availability
        and support of the LoRA or base model.
        Parameters:
        - lora_request: optional LoRARequest; its lora_name is returned
          when provided.
        Returns:
        - str: The name of the LoRA adapter or the first base model.
        """
        if lora_request is not None:
            return lora_request.lora_name
        return self.base_model_paths[0].name

    async def show_available_models(self) -> ModelList:
        """Show available models. This includes the base model and all
        adapters"""
        model_cards = [
            ModelCard(id=base_model.name,
                      max_model_len=self.max_model_len,
                      root=base_model.model_path,
                      permission=[ModelPermission()])
            for base_model in self.base_model_paths
        ]
        lora_cards = [
            ModelCard(id=lora.lora_name,
                      root=lora.local_path,
                      parent=lora.base_model_name if lora.base_model_name else
                      self.base_model_paths[0].name,
                      permission=[ModelPermission()])
            for lora in self.lora_requests
        ]
        prompt_adapter_cards = [
            ModelCard(id=prompt_adapter.prompt_adapter_name,
                      root=self.base_model_paths[0].name,
                      permission=[ModelPermission()])
            for prompt_adapter in self.prompt_adapter_requests
        ]
        model_cards.extend(lora_cards)
        model_cards.extend(prompt_adapter_cards)
        return ModelList(data=model_cards)

    async def load_lora_adapter(
            self,
            request: LoadLoraAdapterRequest,
            base_model_name: Optional[str] = None
    ) -> Union[ErrorResponse, str]:
        """Validate and register a new LoRA adapter.

        Returns a success message, or an ErrorResponse if validation or
        engine-side loading fails.
        """
        error_check_ret = await self._check_load_lora_adapter_request(request)
        if error_check_ret is not None:
            return error_check_ret

        lora_name, lora_path = request.lora_name, request.lora_path
        unique_id = self.lora_id_counter.inc(1)
        lora_request = LoRARequest(lora_name=lora_name,
                                   lora_int_id=unique_id,
                                   lora_path=lora_path)
        if base_model_name is not None and self.is_base_model(base_model_name):
            lora_request.base_model_name = base_model_name

        # Validate that the adapter can be loaded into the engine
        # This will also pre-load it for incoming requests
        try:
            await self.engine_client.add_lora(lora_request)
        except BaseException as e:
            error_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            if isinstance(e, ValueError) and "No adapter found" in str(e):
                error_type = "NotFoundError"
                status_code = HTTPStatus.NOT_FOUND

            return create_error_response(message=str(e),
                                         err_type=error_type,
                                         status_code=status_code)

        self.lora_requests.append(lora_request)
        logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
                    lora_path)
        return f"Success: LoRA adapter '{lora_name}' added successfully."

    async def unload_lora_adapter(
            self,
            request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
        """Unregister a LoRA adapter by name.

        Returns a success message, or an ErrorResponse if the request is
        invalid or the adapter is not loaded.
        """
        error_check_ret = await self._check_unload_lora_adapter_request(request
                                                                        )
        if error_check_ret is not None:
            return error_check_ret

        lora_name = request.lora_name
        self.lora_requests = [
            lora_request for lora_request in self.lora_requests
            if lora_request.lora_name != lora_name
        ]
        logger.info("Removed LoRA adapter: name '%s'", lora_name)
        return f"Success: LoRA adapter '{lora_name}' removed successfully."

    async def _check_load_lora_adapter_request(
            self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
        """Validate a load request: both fields present, name not taken."""
        # Check if both 'lora_name' and 'lora_path' are provided
        if not request.lora_name or not request.lora_path:
            return create_error_response(
                message="Both 'lora_name' and 'lora_path' must be provided.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        # Check if the lora adapter with the given name already exists
        if any(lora_request.lora_name == request.lora_name
               for lora_request in self.lora_requests):
            return create_error_response(
                message=
                f"The lora adapter '{request.lora_name}' has already been "
                "loaded.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        return None

    async def _check_unload_lora_adapter_request(
            self,
            request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
        """Validate an unload request: an identifier is present and the
        named adapter is currently loaded."""
        # Check if either 'lora_name' or 'lora_int_id' is provided
        # (message fixed: the check is a disjunction, so "or", not "and")
        if not request.lora_name and not request.lora_int_id:
            return create_error_response(
                message=
                "either 'lora_name' or 'lora_int_id' needs to be provided.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        # Check if the lora adapter with the given name exists
        # NOTE(review): lookup is by name only; a request carrying only
        # 'lora_int_id' will not match here and unload_lora_adapter also
        # filters by name — confirm whether id-based unload is intended.
        if not any(lora_request.lora_name == request.lora_name
                   for lora_request in self.lora_requests):
            return create_error_response(
                message=
                f"The lora adapter '{request.lora_name}' cannot be found.",
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND)

        return None
+
+
def create_error_response(
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
    """Construct an ErrorResponse carrying the message, error type and the
    numeric HTTP status code."""
    payload = {
        "message": message,
        "type": err_type,
        "code": status_code.value,
    }
    return ErrorResponse(**payload)
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_pooling.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..01a3d211f6ba633988782cbd7af6d71e556f72b2
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_pooling.py
@@ -0,0 +1,235 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import base64
+import time
+from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast
+
+import numpy as np
+from fastapi import Request
+from typing_extensions import assert_never
+
+from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
+from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.protocol import (ErrorResponse,
+ PoolingChatRequest,
+ PoolingRequest, PoolingResponse,
+ PoolingResponseData, UsageInfo)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.logger import init_logger
+from vllm.outputs import PoolingOutput, PoolingRequestOutput
+from vllm.utils import merge_async_iterators
+
+logger = init_logger(__name__)
+
+
def _get_data(
    output: PoolingOutput,
    encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
    """Serialize a pooling output as a plain float list or a base64 string."""
    if encoding_format == "base64":
        # Force to use float32 for base64 encoding
        # to match the OpenAI python client behavior
        raw_bytes = np.asarray(output.data, dtype="float32").tobytes()
        return base64.b64encode(raw_bytes).decode("utf-8")
    if encoding_format == "float":
        return output.data.tolist()

    assert_never(encoding_format)
+
+
class OpenAIServingPooling(OpenAIServing):
    """Handler for pooling requests, mirroring the OpenAI embeddings API."""

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
    ) -> None:
        """Store shared serving state plus the chat-template configuration
        used for chat-style pooling requests."""
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

    async def create_pooling(
        self,
        request: PoolingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[PoolingResponse, ErrorResponse]:
        """
        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Validates the request, tokenizes the input(s), schedules one engine
        request per prompt, then gathers all outputs into a single
        non-streaming PoolingResponse.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = request.encoding_format
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"pool-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        truncate_prompt_tokens = None

        # Truncation beyond the context window is rejected outright.
        if request.truncate_prompt_tokens is not None:
            if request.truncate_prompt_tokens <= self.max_model_len:
                truncate_prompt_tokens = request.truncate_prompt_tokens
            else:
                return self.create_error_response(
                    "truncate_prompt_tokens value is "
                    "greater than max_model_len."
                    " Please, select a smaller truncation size.")

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for pooling models")

            if isinstance(request, PoolingChatRequest):
                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    # In pooling requests, we are not generating tokens,
                    # so there is no need to append extra tokens to the input
                    add_generation_prompt=False,
                    continue_final_message=False,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                (request_prompts,
                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.input,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
        try:
            pooling_params = request.to_pooling_params()

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts
        try:
            # The merged iterator yields (prompt_index, output) pairs; slot
            # each output back into its prompt's position.
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_pooling_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
                encoding_format,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_pooling_response(
        self,
        final_res_batch: List[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
    ) -> PoolingResponse:
        """Assemble the final PoolingResponse from per-prompt outputs."""
        items: List[PoolingResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            item = PoolingResponseData(
                index=idx,
                data=_get_data(final_res.outputs, encoding_format),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        # Pooling consumes prompt tokens only, so total == prompt.
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return PoolingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_rerank.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_rerank.py
new file mode 100644
index 0000000000000000000000000000000000000000..366df71217e9101c6d7b381bdf18efeef64752ff
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_rerank.py
@@ -0,0 +1,208 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
+
+from fastapi import Request
+
+from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
+ RerankRequest, RerankResponse,
+ RerankResult, RerankUsage)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.inputs.data import TokensPrompt
+from vllm.logger import init_logger
+from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
+from vllm.utils import make_async, merge_async_iterators
+
+logger = init_logger(__name__)
+
+
+class JinaAIServingRerank(OpenAIServing):
+
+ def __init__(
+ self,
+ engine_client: EngineClient,
+ model_config: ModelConfig,
+ models: OpenAIServingModels,
+ *,
+ request_logger: Optional[RequestLogger],
+ ) -> None:
+ super().__init__(engine_client=engine_client,
+ model_config=model_config,
+ models=models,
+ request_logger=request_logger)
+
+ async def do_rerank(
+ self,
+ request: RerankRequest,
+ raw_request: Optional[Request] = None
+ ) -> Union[RerankResponse, ErrorResponse]:
+ """
+ Rerank API based on JinaAI's rerank API; implements the same
+ API interface. Designed for compatibility with off-the-shelf
+ tooling, since this is a common standard for reranking APIs
+
+ See example client implementations at
+ https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
+ numerous clients use this standard.
+ """
+ error_check_ret = await self._check_model(request)
+ if error_check_ret is not None:
+ return error_check_ret
+
+ model_name = request.model
+ request_id = f"rerank-{self._base_request_id(raw_request)}"
+ truncate_prompt_tokens = request.truncate_prompt_tokens
+ query = request.query
+ documents = request.documents
+ request_prompts = []
+ engine_prompts = []
+ top_n = request.top_n if request.top_n > 0 else len(documents)
+
+ try:
+ (
+ lora_request,
+ prompt_adapter_request,
+ ) = self._maybe_get_adapters(request)
+
+ tokenizer = await self.engine_client.get_tokenizer(lora_request)
+
+ if prompt_adapter_request is not None:
+ raise NotImplementedError("Prompt adapter is not supported "
+ "for scoring models")
+
+ if isinstance(tokenizer, MistralTokenizer):
+ raise ValueError(
+ "MistralTokenizer not supported for cross-encoding")
+
+ if not self.model_config.is_cross_encoder:
+ raise ValueError("Model is not cross encoder.")
+
+ if truncate_prompt_tokens is not None and \
+ truncate_prompt_tokens > self.max_model_len:
+ raise ValueError(
+ f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
+ f"is greater than max_model_len ({self.max_model_len})."
+ f" Please, select a smaller truncation size.")
+ for doc in documents:
+ request_prompt = f"{query}{tokenizer.sep_token}{doc}"
+ tokenization_kwargs: Dict[str, Any] = {}
+ if truncate_prompt_tokens is not None:
+ tokenization_kwargs["truncation"] = True
+ tokenization_kwargs["max_length"] = truncate_prompt_tokens
+
+ tokenize_async = make_async(tokenizer.__call__,
+ executor=self._tokenizer_executor)
+ prompt_inputs = await tokenize_async(text=query,
+ text_pair=doc,
+ **tokenization_kwargs)
+
+ input_ids = prompt_inputs["input_ids"]
+ text_token_prompt = \
+ self._validate_input(request, input_ids, request_prompt)
+ engine_prompt = TokensPrompt(
+ prompt_token_ids=text_token_prompt["prompt_token_ids"],
+ token_type_ids=prompt_inputs.get("token_type_ids"))
+
+ request_prompts.append(request_prompt)
+ engine_prompts.append(engine_prompt)
+
+ except ValueError as e:
+ logger.exception("Error in preprocessing prompt inputs")
+ return self.create_error_response(str(e))
+
+ # Schedule the request and get the result generator.
+ generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
+
+ try:
+ pooling_params = request.to_pooling_params()
+
+ for i, engine_prompt in enumerate(engine_prompts):
+ request_id_item = f"{request_id}-{i}"
+
+ self._log_inputs(request_id_item,
+ request_prompts[i],
+ params=pooling_params,
+ lora_request=lora_request,
+ prompt_adapter_request=prompt_adapter_request)
+
+ trace_headers = (None if raw_request is None else await
+ self._get_trace_headers(raw_request.headers))
+
+ generator = self.engine_client.encode(
+ engine_prompt,
+ pooling_params,
+ request_id_item,
+ lora_request=lora_request,
+ trace_headers=trace_headers,
+ priority=request.priority,
+ )
+
+ generators.append(generator)
+ except ValueError as e:
+ # TODO: Use a vllm-specific Validation Error
+ return self.create_error_response(str(e))
+ result_generator = merge_async_iterators(*generators)
+
+ num_prompts = len(engine_prompts)
+
+ # Non-streaming response
+ final_res_batch: List[Optional[PoolingRequestOutput]]
+ final_res_batch = [None] * num_prompts
+
+ try:
+ async for i, res in result_generator:
+ final_res_batch[i] = res
+
+ assert all(final_res is not None for final_res in final_res_batch)
+
+ final_res_batch_checked = cast(List[PoolingRequestOutput],
+ final_res_batch)
+
+ response = self.request_output_to_rerank_response(
+ final_res_batch_checked, request_id, model_name, documents,
+ top_n)
+ except asyncio.CancelledError:
+ return self.create_error_response("Client disconnected")
+ except ValueError as e:
+ # TODO: Use a vllm-specific Validation Error
+ return self.create_error_response(str(e))
+
+ return response
+
+ def request_output_to_rerank_response(
+ self, final_res_batch: List[PoolingRequestOutput], request_id: str,
+ model_name: str, documents: List[str],
+ top_n: int) -> RerankResponse:
+ """
+ Convert the output of do_rank to a RerankResponse
+ """
+ results: List[RerankResult] = []
+ num_prompt_tokens = 0
+ for idx, final_res in enumerate(final_res_batch):
+ classify_res = ScoringRequestOutput.from_base(final_res)
+
+ result = RerankResult(
+ index=idx,
+ document=RerankDocument(text=documents[idx]),
+ relevance_score=classify_res.outputs.score,
+ )
+ results.append(result)
+ prompt_token_ids = final_res.prompt_token_ids
+ num_prompt_tokens += len(prompt_token_ids)
+
+ # sort by relevance, then return the top n if set
+ results.sort(key=lambda x: x.relevance_score, reverse=True)
+ if top_n < len(documents):
+ results = results[:top_n]
+
+ return RerankResponse(
+ id=request_id,
+ model=model_name,
+ results=results,
+ usage=RerankUsage(total_tokens=num_prompt_tokens))
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_score.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..832aa8516cc359777e5e6326276a656fe7038dab
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_score.py
@@ -0,0 +1,238 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import time
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
+
+from fastapi import Request
+
+from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
+ ScoreResponse, ScoreResponseData,
+ UsageInfo)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.inputs.data import TokensPrompt
+from vllm.logger import init_logger
+from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
+from vllm.utils import make_async, merge_async_iterators
+
+logger = init_logger(__name__)
+
+
+def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str],
+ str]) -> List:
+ if isinstance(text_1, (str, dict)):
+ # Convert a single prompt to a list.
+ text_1 = [text_1]
+ text_1 = [t for t in text_1]
+
+ if isinstance(text_2, (str, dict)):
+ # Convert a single prompt to a list.
+ text_2 = [text_2]
+ text_2 = [t for t in text_2]
+ if len(text_1) > 1 and len(text_1) != len(text_2):
+ raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
+ if len(text_1) == 0:
+ raise ValueError("At least one text element must be given")
+ if len(text_2) == 0:
+ raise ValueError("At least one text_pair element must be given")
+
+ if len(text_1) == 1:
+ text_1 = text_1 * len(text_2)
+
+ return [(t1, t2) for t1, t2 in zip(text_1, text_2)]
+
+
+class OpenAIServingScores(OpenAIServing):
+
+ def __init__(
+ self,
+ engine_client: EngineClient,
+ model_config: ModelConfig,
+ models: OpenAIServingModels,
+ *,
+ request_logger: Optional[RequestLogger],
+ ) -> None:
+ super().__init__(engine_client=engine_client,
+ model_config=model_config,
+ models=models,
+ request_logger=request_logger)
+
+ async def create_score(
+ self,
+ request: ScoreRequest,
+ raw_request: Optional[Request] = None,
+ ) -> Union[ScoreResponse, ErrorResponse]:
+ """
+ Score API similar to Sentence Transformers cross encoder
+
+ See https://sbert.net/docs/package_reference/cross_encoder
+ """
+ error_check_ret = await self._check_model(request)
+ if error_check_ret is not None:
+ return error_check_ret
+
+ model_name = request.model
+ request_id = f"score-{self._base_request_id(raw_request)}"
+ created_time = int(time.time())
+ truncate_prompt_tokens = request.truncate_prompt_tokens
+
+ request_prompts = []
+ engine_prompts = []
+
+ try:
+ (
+ lora_request,
+ prompt_adapter_request,
+ ) = self._maybe_get_adapters(request)
+
+ tokenizer = await self.engine_client.get_tokenizer(lora_request)
+
+ if prompt_adapter_request is not None:
+ raise NotImplementedError("Prompt adapter is not supported "
+ "for scoring models")
+
+ if isinstance(tokenizer, MistralTokenizer):
+ raise ValueError(
+ "MistralTokenizer not supported for cross-encoding")
+
+ if not self.model_config.is_cross_encoder:
+ raise ValueError("Model is not cross encoder.")
+
+ if truncate_prompt_tokens is not None and \
+ truncate_prompt_tokens > self.max_model_len:
+ raise ValueError(
+ f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
+ f"is greater than max_model_len ({self.max_model_len})."
+ f" Please, select a smaller truncation size.")
+
+ input_pairs = make_pairs(request.text_1, request.text_2)
+ for q, t in input_pairs:
+ request_prompt = f"{q}{tokenizer.sep_token}{t}"
+
+ tokenization_kwargs: Dict[str, Any] = {}
+ if truncate_prompt_tokens is not None:
+ tokenization_kwargs["truncation"] = True
+ tokenization_kwargs["max_length"] = truncate_prompt_tokens
+
+ tokenize_async = make_async(tokenizer.__call__,
+ executor=self._tokenizer_executor)
+ prompt_inputs = await tokenize_async(text=q,
+ text_pair=t,
+ **tokenization_kwargs)
+
+ input_ids = prompt_inputs["input_ids"]
+ text_token_prompt = \
+ self._validate_input(request, input_ids, request_prompt)
+ engine_prompt = TokensPrompt(
+ prompt_token_ids=text_token_prompt["prompt_token_ids"],
+ token_type_ids=prompt_inputs.get("token_type_ids"))
+
+ request_prompts.append(request_prompt)
+ engine_prompts.append(engine_prompt)
+
+ except ValueError as e:
+ logger.exception("Error in preprocessing prompt inputs")
+ return self.create_error_response(str(e))
+
+ # Schedule the request and get the result generator.
+ generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
+
+ try:
+ pooling_params = request.to_pooling_params()
+
+ for i, engine_prompt in enumerate(engine_prompts):
+ request_id_item = f"{request_id}-{i}"
+
+ self._log_inputs(request_id_item,
+ request_prompts[i],
+ params=pooling_params,
+ lora_request=lora_request,
+ prompt_adapter_request=prompt_adapter_request)
+
+ trace_headers = (None if raw_request is None else await
+ self._get_trace_headers(raw_request.headers))
+
+ generator = self.engine_client.encode(
+ engine_prompt,
+ pooling_params,
+ request_id_item,
+ lora_request=lora_request,
+ trace_headers=trace_headers,
+ priority=request.priority,
+ )
+
+ generators.append(generator)
+ except ValueError as e:
+ # TODO: Use a vllm-specific Validation Error
+ return self.create_error_response(str(e))
+
+ result_generator = merge_async_iterators(*generators)
+
+ num_prompts = len(engine_prompts)
+
+ # Non-streaming response
+ final_res_batch: List[Optional[PoolingRequestOutput]]
+ final_res_batch = [None] * num_prompts
+
+ try:
+ async for i, res in result_generator:
+ final_res_batch[i] = res
+
+ assert all(final_res is not None for final_res in final_res_batch)
+
+ final_res_batch_checked = cast(List[PoolingRequestOutput],
+ final_res_batch)
+
+ response = self.request_output_to_score_response(
+ final_res_batch_checked,
+ request_id,
+ created_time,
+ model_name,
+ )
+ except asyncio.CancelledError:
+ return self.create_error_response("Client disconnected")
+ except ValueError as e:
+ # TODO: Use a vllm-specific Validation Error
+ return self.create_error_response(str(e))
+
+ return response
+
+ def request_output_to_score_response(
+ self,
+ final_res_batch: List[PoolingRequestOutput],
+ request_id: str,
+ created_time: int,
+ model_name: str,
+ ) -> ScoreResponse:
+ items: List[ScoreResponseData] = []
+ num_prompt_tokens = 0
+
+ for idx, final_res in enumerate(final_res_batch):
+ classify_res = ScoringRequestOutput.from_base(final_res)
+
+ item = ScoreResponseData(
+ index=idx,
+ score=classify_res.outputs.score,
+ )
+ prompt_token_ids = final_res.prompt_token_ids
+
+ items.append(item)
+ num_prompt_tokens += len(prompt_token_ids)
+
+ usage = UsageInfo(
+ prompt_tokens=num_prompt_tokens,
+ total_tokens=num_prompt_tokens,
+ )
+
+ return ScoreResponse(
+ id=request_id,
+ created=created_time,
+ model=model_name,
+ data=items,
+ usage=usage,
+ )
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_tokenization.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_tokenization.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c79adf90c8ad13e9afb640c278ebdec9d6c59ff
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_tokenization.py
@@ -0,0 +1,146 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Final, List, Optional, Union
+
+from fastapi import Request
+
+from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
+from vllm.entrypoints.logger import RequestLogger
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
+ DetokenizeResponse,
+ ErrorResponse,
+ TokenizeChatRequest,
+ TokenizeRequest,
+ TokenizeResponse)
+# yapf: enable
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class OpenAIServingTokenization(OpenAIServing):
+    """Serves the /tokenize and /detokenize endpoints.
+
+    Tokenize accepts either a completion-style prompt or chat messages
+    (rendered through the chat template); detokenize converts a list of
+    token ids back into text.
+    """
+
+    def __init__(
+        self,
+        engine_client: EngineClient,
+        model_config: ModelConfig,
+        models: OpenAIServingModels,
+        *,
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+        chat_template_content_format: ChatTemplateContentFormatOption,
+    ) -> None:
+        super().__init__(engine_client=engine_client,
+                         model_config=model_config,
+                         models=models,
+                         request_logger=request_logger)
+
+        # Server-wide default chat template; individual requests may
+        # override it via `request.chat_template`.
+        self.chat_template = chat_template
+        self.chat_template_content_format: Final = chat_template_content_format
+
+    async def create_tokenize(
+        self,
+        request: TokenizeRequest,
+        raw_request: Request,
+    ) -> Union[TokenizeResponse, ErrorResponse]:
+        """Tokenize the request's prompt or chat messages.
+
+        Returns the concatenated token ids of all resulting prompts, their
+        count, and the model's max context length; returns an ErrorResponse
+        if the model check or prompt preprocessing fails.
+        """
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        request_id = f"tokn-{self._base_request_id(raw_request)}"
+
+        try:
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            tokenizer = await self.engine_client.get_tokenizer(lora_request)
+
+            if isinstance(request, TokenizeChatRequest):
+                # Chat input: render the messages through the chat template
+                # before tokenizing.
+                (
+                    _,
+                    request_prompts,
+                    engine_prompts,
+                ) = await self._preprocess_chat(
+                    request,
+                    tokenizer,
+                    request.messages,
+                    chat_template=request.chat_template or self.chat_template,
+                    chat_template_content_format=self.
+                    chat_template_content_format,
+                    add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
+                    chat_template_kwargs=request.chat_template_kwargs,
+                    add_special_tokens=request.add_special_tokens,
+                )
+            else:
+                # Plain completion-style prompt.
+                (request_prompts,
+                 engine_prompts) = await self._preprocess_completion(
+                     request,
+                     tokenizer,
+                     request.prompt,
+                     add_special_tokens=request.add_special_tokens,
+                 )
+        except ValueError as e:
+            logger.exception("Error in preprocessing prompt inputs")
+            return self.create_error_response(str(e))
+
+        input_ids: List[int] = []
+        for i, engine_prompt in enumerate(engine_prompts):
+            self._log_inputs(request_id,
+                             request_prompts[i],
+                             params=None,
+                             lora_request=lora_request,
+                             prompt_adapter_request=prompt_adapter_request)
+
+            # Silently ignore prompt adapter since it does not affect
+            # tokenization (Unlike in Embeddings API where an error is raised)
+
+            input_ids.extend(engine_prompt["prompt_token_ids"])
+
+        return TokenizeResponse(tokens=input_ids,
+                                count=len(input_ids),
+                                max_model_len=self.max_model_len)
+
+    async def create_detokenize(
+        self,
+        request: DetokenizeRequest,
+        raw_request: Request,
+    ) -> Union[DetokenizeResponse, ErrorResponse]:
+        """Convert ``request.tokens`` back into the corresponding text."""
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        request_id = f"tokn-{self._base_request_id(raw_request)}"
+
+        (
+            lora_request,
+            prompt_adapter_request,
+        ) = self._maybe_get_adapters(request)
+
+        tokenizer = await self.engine_client.get_tokenizer(lora_request)
+
+        self._log_inputs(request_id,
+                         request.tokens,
+                         params=None,
+                         lora_request=lora_request,
+                         prompt_adapter_request=prompt_adapter_request)
+
+        # Silently ignore prompt adapter since it does not affect tokenization
+        # (Unlike in Embeddings API where an error is raised)
+
+        prompt_input = await self._tokenize_prompt_input_async(
+            request,
+            tokenizer,
+            request.tokens,
+        )
+        input_text = prompt_input["prompt"]
+
+        return DetokenizeResponse(prompt=input_text)
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/utils.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9af37871d57c8afb26a0c534db4d98938f1b8aa9
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/utils.py
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import functools
+
+from fastapi import Request
+
+
+async def listen_for_disconnect(request: Request) -> None:
+ """Returns if a disconnect message is received"""
+ while True:
+ message = await request.receive()
+ if message["type"] == "http.disconnect":
+ break
+
+
+def with_cancellation(handler_func):
+ """Decorator that allows a route handler to be cancelled by client
+ disconnections.
+
+ This does _not_ use request.is_disconnected, which does not work with
+ middleware. Instead this follows the pattern from
+ starlette.StreamingResponse, which simultaneously awaits on two tasks- one
+ to wait for an http disconnect message, and the other to do the work that we
+ want done. When the first task finishes, the other is cancelled.
+
+ A core assumption of this method is that the body of the request has already
+ been read. This is a safe assumption to make for fastapi handlers that have
+ already parsed the body of the request into a pydantic model for us.
+ This decorator is unsafe to use elsewhere, as it will consume and throw away
+ all incoming messages for the request while it looks for a disconnect
+ message.
+
+ In the case where a `StreamingResponse` is returned by the handler, this
+ wrapper will stop listening for disconnects and instead the response object
+ will start listening for disconnects.
+ """
+
+ # Functools.wraps is required for this wrapper to appear to fastapi as a
+ # normal route handler, with the correct request type hinting.
+ @functools.wraps(handler_func)
+ async def wrapper(*args, **kwargs):
+
+ # The request is either the second positional arg or `raw_request`
+ request = args[1] if len(args) > 1 else kwargs["raw_request"]
+
+ handler_task = asyncio.create_task(handler_func(*args, **kwargs))
+ cancellation_task = asyncio.create_task(listen_for_disconnect(request))
+
+ done, pending = await asyncio.wait([handler_task, cancellation_task],
+ return_when=asyncio.FIRST_COMPLETED)
+ for task in pending:
+ task.cancel()
+
+ if handler_task in done:
+ return handler_task.result()
+ return None
+
+ return wrapper
diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/audio.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/audio.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3dcdedfffde30e05df58e99e64583f34c9862ebf
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/audio.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..22a8374fcf728ae62a7e82040720367a827b939e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/base.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/hasher.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/hasher.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9229e9c092770f53168da44dbd8b025afb33f03c
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/hasher.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/parse.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/parse.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c89dcbacc9888f124e3d5b20383c21901b3fcbe
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/parse.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/processing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/processing.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..587570e96f96f5e2e72abb1d268de82b4d3a5242
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/processing.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/registry.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/registry.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..328131d687cbf1b688f2efa32b1de71fb95a48f0
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/registry.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/video.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/video.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87366baa902609beffdab4763e4cd2ef870b6a29
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/video.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/audio.py b/.venv/lib/python3.11/site-packages/vllm/multimodal/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..f379ec1682a3c99eeecbda7a08b6f9097882c920
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/multimodal/audio.py
@@ -0,0 +1,77 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import base64
+from io import BytesIO
+from pathlib import Path
+
+import numpy as np
+import numpy.typing as npt
+
+from vllm.inputs.registry import InputContext
+from vllm.utils import PlaceholderModule
+
+from .base import MediaIO, MultiModalPlugin
+from .inputs import AudioItem, ModalityData, MultiModalKwargs
+
+try:
+ import librosa
+except ImportError:
+ librosa = PlaceholderModule("librosa") # type: ignore[assignment]
+
+try:
+ import soundfile
+except ImportError:
+ soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
+
+
+class AudioPlugin(MultiModalPlugin):
+    """Plugin for audio data.
+
+    Registers the "audio" modality; there are no modality-wide defaults,
+    so each audio model must supply its own input mapper and maximum
+    token count.
+    """
+
+    def get_data_key(self) -> str:
+        # Key identifying audio entries in the multi-modal data dict.
+        return "audio"
+
+    def _default_input_mapper(
+        self,
+        ctx: InputContext,
+        data: ModalityData[AudioItem],
+        **mm_processor_kwargs,
+    ) -> MultiModalKwargs:
+        # No sensible default exists: every audio model registers its own.
+        raise NotImplementedError("There is no default audio input mapper")
+
+    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
+        # Likewise, the audio token budget is model-specific.
+        raise NotImplementedError(
+            "There is no default maximum multimodal tokens")
+
+
+def resample_audio(
+    audio: npt.NDArray[np.floating],
+    *,
+    orig_sr: float,
+    target_sr: float,
+) -> npt.NDArray[np.floating]:
+    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` via librosa."""
+    return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
+
+
+class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
+
+ def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
+ return librosa.load(BytesIO(data), sr=None)
+
+ def load_base64(
+ self,
+ media_type: str,
+ data: str,
+ ) -> tuple[npt.NDArray, float]:
+ return self.load_bytes(base64.b64decode(data))
+
+ def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
+ return librosa.load(filepath, sr=None)
+
+ def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
+ audio, sr = media
+
+ with BytesIO() as buffer:
+ soundfile.write(buffer, audio, sr, format="WAV")
+ data = buffer.getvalue()
+
+ return base64.b64encode(data).decode('utf-8')
diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/base.py b/.venv/lib/python3.11/site-packages/vllm/multimodal/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c48d07ba365ba62a56c99842726d17a3261cc15c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/multimodal/base.py
@@ -0,0 +1,463 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from pathlib import Path
+from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
+ Optional, Sequence, Tuple, Type, TypeVar, Union)
+
+from torch import nn
+
+from vllm.inputs import InputContext
+from vllm.logger import init_logger
+from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
+ resolve_mm_processor_kwargs)
+
+if TYPE_CHECKING:
+ from vllm.config import ModelConfig
+ from vllm.sequence import SequenceGroupMetadata
+
+from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs,
+ PlaceholderRange)
+
+logger = init_logger(__name__)
+
MultiModalInputMapper = Callable[[InputContext, ModalityData[object]],
                                 MultiModalKwargs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

If the data is not supported, throw :exc:`TypeError`.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""

# Generic item type (used e.g. by MultiModalPlaceholderMap below).
_T = TypeVar("_T")
# A model class; return type of the registration decorators below.
N = TypeVar("N", bound=Type[nn.Module])
+
+
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).
    """

    def __init__(self) -> None:
        # Per-model-class registries of input mappers and max-token
        # calculators (populated by the register_* decorators below).
        self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]()
        self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]()

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[Any],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns:
            A decorator that registers the mapper against the decorated
            model class and returns the class unchanged.
        """

        def wrapper(model_cls: N) -> N:
            # NOTE(review): strict=True presumably restricts the check to the
            # exact class (not base classes) — confirm against ClassRegistry
            # in vllm.utils.
            if self._input_mappers.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            self._input_mappers[model_cls] = (mapper
                                              or self._default_input_mapper)

            return model_cls

        return wrapper

    def map_input(
        self,
        model_config: "ModelConfig",
        data: ModalityData[Any],
        mm_processor_kwargs: Optional[dict[str, Any]],
    ) -> MultiModalKwargs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper is registered for the model class.
            TypeError: If the data type is not supported.
        """

        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)

        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        # In the case of the default mapper, we have to get resource
        # processor through its HuggingFace autoclass; since this goes
        # through **kwargs, we can't inspect it the same way, so we allow
        # drop mm_processor_kwargs based on signature inspection
        # if we're using the default mapper.
        #
        # This should be safe in general due to the sanitation, since the
        # transformers resource should filter unused kwargs anyway.
        # NOTE: `==` (not `is`) is required here because each attribute access
        # on a bound method creates a new object; bound methods compare equal
        # when function and instance match.
        uses_default_mapper = mapper == self._default_input_mapper
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            mm_processor_kwargs,
            callable=mapper,
            allow_var_kwargs=uses_default_mapper,
        )
        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
        # Reject zero/negative budgets; a model must contribute at least one
        # multi-modal token to be meaningful.
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns:
            A decorator that registers the value/callable against the
            decorated model class and returns the class unchanged.
        """

        def wrapper(model_cls: N) -> N:
            if self._max_mm_tokens.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            # Fixed integer budgets are validated eagerly; callables are
            # validated later, in get_max_multimodal_tokens.
            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            self._max_mm_tokens[model_cls] = (
                max_mm_tokens or self._default_max_multimodal_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.

        Raises:
            ValueError: If the registered calculation yields a value < 1.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.model_executor.models import supports_multimodal

        model_cls, _ = get_model_architecture(model_config)

        if not supports_multimodal(model_cls):
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            return 0

        # A callable registration is evaluated lazily, with only the
        # mm_processor_kwargs its signature actually accepts.
        if callable(max_mm_tokens):
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens, overrides=model_config.mm_processor_kwargs)
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens
+
+
class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.

    Both sides are tracked as lists of ``range`` objects: ``src_ranges``
    index into the flattened multi-modal embedding tensor, while
    ``dest_ranges`` index into the destination (token-embedding) tensor.
    """

    class IndexMap(NamedTuple):
        # Flattened index lists produced by :meth:`index_map`.
        src: list[int]
        dest: list[int]

    src_ranges: list[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: list[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self):
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> Tuple[Optional[MultiModalDataDict], dict[str,
                                                  "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        Examples:

        .. code-block::

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |.................................|

                images      = [A, B]
                src_ranges  = [(0, 4), (4, 8)]
                dest_ranges = [(0, 4), (5, 9)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |  .....                          |

                images      = [A, B]
                src_ranges  = [(2, 4), (4, 6)]
                dest_ranges = [(0, 2), (3, 5)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |     .........                   |

                images      = [B]
                src_ranges  = [(0, 4)]
                dest_ranges = [(0, 4)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |          .......................|

                images      = []
                src_ranges  = []
                dest_ranges = []
        """
        seq_mm_data = seq_group.multi_modal_data
        seq_mm_placeholders = seq_group.multi_modal_placeholders

        # Nothing multi-modal in this sequence group: pass data through.
        if not seq_mm_data or not seq_mm_placeholders:
            return seq_mm_data, {}

        # For merged processor, we directly use mm_kwargs as mm_data
        if isinstance(seq_mm_data, MultiModalKwargs):
            placeholder_maps = dict[str, MultiModalPlaceholderMap]()

            for modality, placeholders in seq_mm_placeholders.items():
                placeholder_map = MultiModalPlaceholderMap()

                if positions:
                    placeholder_map.append_items_from_seq_group(
                        positions,
                        # Dummy, since we don't care about intersecting items
                        [None] * len(placeholders),
                        placeholders,
                    )

                placeholder_maps[modality] = placeholder_map

            return seq_mm_data, placeholder_maps

        # Legacy path: filter each modality's items down to those that
        # intersect ``positions``, building a placeholder map per modality.
        mm_data = {**seq_mm_data}
        placeholder_maps = defaultdict[str, MultiModalPlaceholderMap](
            MultiModalPlaceholderMap)

        for modality, placeholders in seq_mm_placeholders.items():
            mm_items = mm_data.pop(modality)
            if not isinstance(mm_items, list):
                mm_items = [mm_items]

            if positions:
                intersecting_items = placeholder_maps[modality] \
                    .append_items_from_seq_group(
                        positions,
                        mm_items,
                        placeholders,
                    )

                if intersecting_items:
                    mm_data[modality] = intersecting_items

        return mm_data, placeholder_maps

    def append_items_from_seq_group(
        self,
        positions: range,
        multi_modal_items: list[_T],
        multi_modal_placeholders: Sequence[PlaceholderRange],
    ) -> list[_T]:
        """
        Adds the multi-modal items that intersect ``positions`` to this
        placeholder map and returns the intersecting items.

        Raises:
            ValueError: If items and placeholders differ in length.
        """
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"],
            )
            # Overlap of this placeholder with ``positions``; an empty
            # range is falsy, i.e. no overlap.
            intersection = range(
                max(positions.start, placeholder.start),
                min(positions.stop, placeholder.stop),
            )

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Destination indices are relative to positions.start.
            token_embedding_range = range(
                intersection.start - positions.start,
                intersection.stop - positions.start,
            )

            # Source indices are relative to this item's own embeddings,
            # shifted by everything appended so far (self.src_len).
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len,
            )

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # The full placeholder length is counted even for a partial
            # intersection, so src offsets stay aligned per item.
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.
        """

        # Shift the other map's ranges past this map's current lengths.
        self.src_ranges.extend(
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges)
        self.src_len += other.src_len
        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len

    def index_map(self) -> "IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.

        Raises:
            ValueError: If the flattened src and dest lengths disagree.
        """

        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)
+
+
class MediaIO(ABC, Generic[_T]):
    """I/O interface for loading a media item of type ``_T`` from raw
    bytes, a base64 string, or a local file."""

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Decode a media item from its raw binary contents."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Decode a media item from base64-encoded ``data``.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Load a media item from a file on the local filesystem."""
        raise NotImplementedError
diff --git a/.venv/lib/python3.11/site-packages/vllm/plugins/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/plugins/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1f256106889c11c5316ad08a164588d139dc1c1
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/plugins/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__init__.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..588f098bf16ab7733f75e03c2931b31ff4dc5490
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/batch_expansion.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/batch_expansion.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f2116ac60728b052602ad23b4846a6a438ecb9c
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/batch_expansion.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/draft_model_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/draft_model_runner.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0f956c79818fff834c6c21bd278d0a7dea0e28b
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/draft_model_runner.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/interfaces.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/interfaces.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0fb28c8c22f9646cce2096f0876e9e961173050d
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/interfaces.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/medusa_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/medusa_worker.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82b1c11633029b0762fc125c7d1a7ab43240be15
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/medusa_worker.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/metrics.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/metrics.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f09a75eb579000d92e1db51ebffbdbb743a6e042
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/metrics.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mlp_speculator_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mlp_speculator_worker.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa3bf8ea708e2b4a71e15ef051f18bdef7f90831
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mlp_speculator_worker.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mqa_scorer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mqa_scorer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7a693ea92af46d2b4b739da83318433fe23cc1d
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/mqa_scorer.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/multi_step_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/multi_step_worker.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..818ca506c13fb927e7636afb04376de86a9a90bf
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/multi_step_worker.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/ngram_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/ngram_worker.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..236572f676cde82037968ac43923cf3c3c5243d2
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/ngram_worker.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/proposer_worker_base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/proposer_worker_base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75e297b6b73ea26fc0d584c0973373c56f082ffe
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/proposer_worker_base.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/smaller_tp_proposer_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/smaller_tp_proposer_worker.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c943e829eaf96434b6d103f37200b47292a23f0
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/smaller_tp_proposer_worker.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/spec_decode_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/spec_decode_worker.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b403889fe3a3da9f40a1ce104aa5b87947755b6
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/spec_decode_worker.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/target_model_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/target_model_runner.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e06c137d95a00c96e79f2c3fb382cc7de6cc31b
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/target_model_runner.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/top1_proposer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/top1_proposer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0832ee69b175bc3f3edf13e557e3f8737c490eab
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/top1_proposer.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..266d53f2698e71892c336661f196e5c993db3aaa
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/util.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/batch_expansion.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/batch_expansion.py
new file mode 100644
index 0000000000000000000000000000000000000000..e08ed742a5225186880dc60dc86017bcdc334bd7
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/batch_expansion.py
@@ -0,0 +1,505 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from array import array
+from itertools import chain, count
+from typing import Iterator, List, Optional, Tuple
+
+import torch
+
+from vllm import SamplingParams
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
+ ExecuteModelRequest, SequenceData,
+ SequenceGroupMetadata, get_all_seq_ids)
+from vllm.spec_decode.interfaces import (SpeculativeProposals,
+ SpeculativeScorer, SpeculativeScores)
+from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
+
# Type aliases for readability; all are plain ints.
SeqId = int
TargetSeqId = int
TokenId = int

# NOTE(review): presumably the sampling params applied to expanded target
# sequences — confirm at the use site (not visible in this chunk).
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
+
+
+class BatchExpansionTop1Scorer(SpeculativeScorer):
+ """Implements a speculative scorer that uses batch expansion to get
+ probabilities of speculative tokens according to the scoring model.
+
+ Batch expansion converts a list of sequences and multiple query positions
+ to a new batch of sequences, each with a single query position. This allows
+ for MQA-like scoring in speculative decoding without requiring an MQA
+ kernel.
+
+ It is strictly less efficient than MQA scoring.
+
+ It only supports scoring the top1 proposal tokens of the proposer, instead
+ of topk/tree.
+ """
+
    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences. The
        target sequences have the unique continuations to be scored and a
        unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length, then
        no speculation is produced for that sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token, along with
                which sequences were ignored during scoring.
        """

        # TODO(cade) perform this on GPU to remove blocking call.
        proposal_lens_list = proposals.proposal_lens.tolist()
        proposal_token_ids_list = proposals.proposal_token_ids.tolist()

        # Filter the list to ignore invalid proposals.
        # NOTE(review): rows containing VLLM_INVALID_TOKEN_ID presumably mark
        # sequences where no speculation was produced — confirm against the
        # proposer implementations.
        proposal_token_ids_list_without_skips = [
            proposals for proposals in proposal_token_ids_list
            if VLLM_INVALID_TOKEN_ID not in proposals
        ]

        # Expand the batch so every (sequence, proposal-token) pair becomes
        # its own single-query target sequence.
        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

        # Run the scorer model over the expanded batch (single step).
        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        if not non_spec_indices:
            # All sequence groups in batch have spec decoding enabled
            return self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )
        else:
            # Batch has a mix of spec decode enabled and disabled seq groups
            return self._contract_batch(
                execute_model_req.seq_group_metadata_list,
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )
+
    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[List[TokenId]],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Given the input sequences and potentially multiple corresponding
        proposal tokens, create a new batch where each sequence has a single
        query token.

        Returns:
            A tuple ``(spec_indices, non_spec_indices,
            target_seq_group_metadata_list, num_scoring_tokens)`` where the
            index lists refer to positions in the original batch.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
            split_batch_by_proposal_len(
                seq_group_metadata_list, proposal_lens_list)

        spec_expanded_seqs = self._create_scoring_model_input(
            seq_group_metadata_list=spec_seqs,
            proposal_token_ids=proposal_token_ids_list,
            # NOTE: We determine the seq ids in the expanded batch using the
            # full seq_group_metadata_list, instead of only spec_seqs.
            target_seq_ids_iter=self._create_target_seq_id_iterator(
                seq_ids=get_all_seq_ids(seq_group_metadata_list)),
        )

        num_scoring_tokens = len(spec_expanded_seqs)
        # Batch speculative and non-speculative (e.g. chunked prefill) requests
        # but make sure order is prefill|decode due to backend requirement.
        target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs

        return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                num_scoring_tokens)
+
    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """
        Augment input `scores` with non-speculative requests outputs.
        This includes decode requests with speculation turned off, as well
        as prefill requests when `enable_chunked_prefill` is set.
        For the latter, prefills are further separated into terminal and
        non-terminal chunks (from which no token is sampled).

        Args:
            scores: Contracted speculative scores, mutated in place.
            seq_group_metadata_list: Original (contracted) batch metadata.
            non_spec_indices: Batch positions of non-speculative requests.
            non_spec_outputs: Scorer outputs for those requests.
            has_prompt_log: Whether any request wants prompt logprobs.

        Returns:
            ``scores``, with row 0 of each non-speculative entry filled in.
        """
        if not non_spec_indices:
            return scores

        if has_prompt_log:
            # When prompt_logprobs is enabled, prefills yield output token
            # (and respective prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            seq_meta = seq_group_metadata_list
            nospec_sizes = torch.tensor([
                seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
                for i in non_spec_indices
            ])
            # cumsum - 1 gives the index of each request's last (sampled) slot.
            nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
        else:
            # In this case only sampled tokens are returned, select all.
            nospec_sampled_token_idxs = list(
                range(len(non_spec_outputs.token_ids)))

        # Each non-speculative request contributes exactly one token, written
        # into position 0 of its row.
        scores.token_ids[non_spec_indices, :1] = \
            non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
        scores.probs[non_spec_indices, :1, :] = \
            non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
        scores.logprobs[non_spec_indices, :1, :] = \
            non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[non_spec_indices, :1, :] = \
                non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
        return scores
+
    def _contract_batch(
            self,
            contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
            target_sampler_output: SamplerOutput,
            proposals: SpeculativeProposals, num_scoring_tokens: int,
            non_spec_indices: List[int], spec_indices: List[int],
            k: int) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        contracted_bs is the original batch size, and the batch size that the
        target_sampler_output will be contracted to.
        """
        contracted_bs = len(contracted_seq_group_metadata_list)
        # Split the flat scorer output into speculative / non-speculative
        # parts (the expanded batch was laid out as non-spec|spec).
        (target_token_ids, target_probs, target_logprobs, target_hidden_states,
         non_spec_target_token_ids, non_spec_target_probs,
         non_spec_target_logprobs,
         non_spec_target_hidden_states) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Map distinct sequences used to score each token
        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
        # NOTE(review): this rebinds the `k` parameter; the argument passed by
        # score_proposals (num_lookahead_slots) is effectively unused here.
        expanded_batch_size, k = proposals.proposal_token_ids.shape

        # The number of tokens in the expanded batch used for speculation is
        # equal to the total expanded batch size minus the number of samples for
        # non-speculative sequences, prefill chunks with no out tokens included
        non_spec_expanded_bs = len(non_spec_indices)
        spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
        target_probs = target_probs.reshape(*target_token_ids.shape,
                                            self._vocab_size)
        target_logprobs = target_logprobs.reshape(target_probs.shape)

        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        # Allocate full-size output tensors, pre-filled with sentinel values
        # (-1 token id, zero prob, -inf logprob) for rows not written below.
        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                               fill_value=-1)
        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                fill_value=-float("inf"))

        if target_sampler_output.hidden_states is not None:
            all_hidden_states = target_hidden_states.new_zeros(
                size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
        else:
            all_hidden_states = None

        has_prompt_log = any((sg.sampling_params.prompt_logprobs
                              and sg.sampling_params.prompt_logprobs > 0)
                             for sg in contracted_seq_group_metadata_list)
        # When prompt logprobs is enabled, lens of returned tensors go from
        # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
        # We adjust stride accordingly to get the generated tokens and
        # their probs, but pass on prompt_logprobs as is.
        prompt_logprobs = None
        if (not self._scorer_worker.model_runner.disable_logprobs\
            and has_prompt_log):
            prompt_logprobs = [
                o.prompt_logprobs for o in target_sampler_output.outputs
            ]
        elif not has_prompt_log:
            # When prompt logprobs are not to be returned,
            # we can ignore non-terminal chunks (no out token).
            non_spec_indices = [
                idx for idx in non_spec_indices
                if contracted_seq_group_metadata_list[idx].do_sample
            ]

        # "Contract" speculative.
        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs
            all_logprobs[spec_indices] = target_logprobs
            if all_hidden_states is not None:
                all_hidden_states[spec_indices] = target_hidden_states

        spec_scores = SpeculativeScores(probs=all_probs,
                                        token_ids=all_tokens,
                                        logprobs=all_logprobs,
                                        hidden_states=all_hidden_states,
                                        prompt_logprobs=prompt_logprobs)

        non_spec_outputs = SpeculativeScores(
            probs=non_spec_target_probs,
            token_ids=non_spec_target_token_ids,
            logprobs=non_spec_target_logprobs,
            hidden_states=non_spec_target_hidden_states)
        # Contract remaining nonspec entries based on non_spec_indices, if any.
        return self._contract_non_speculative(
            spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
            non_spec_outputs, has_prompt_log)
+
+    def _contract_batch_all_spec(
+        self,
+        target_sampler_output: SamplerOutput,
+        proposals: SpeculativeProposals,
+    ) -> SpeculativeScores:
+        """Contract the expanded batch back into its original size.
+        This maps the scores of speculative tokens back to their original
+        sequences.
+
+        It assumes all sequences in the batch were previously expanded.
+        """
+
+        # Map distinct sequences used to score each token
+        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
+        contracted_bs, k = proposals.proposal_token_ids.shape
+
+        # Reshape tensors to original batch size
+        target_token_ids = target_sampler_output.sampled_token_ids.reshape(
+            contracted_bs, k + 1)
+        target_probs = target_sampler_output.sampled_token_probs.reshape(
+            *target_token_ids.shape, self._vocab_size)
+        target_logprobs = target_sampler_output.logprobs.reshape(
+            target_probs.shape)
+        # Hidden states are optional; when present, mirror the token layout:
+        # [contracted_bs, k + 1, hidden_dim].
+        target_hidden_states = target_sampler_output.hidden_states
+        if target_hidden_states is not None:
+            target_hidden_states = target_hidden_states.reshape(
+                *target_token_ids.shape, target_hidden_states.shape[-1])
+
+        # All sequences were expanded (no prompt chunks in this path), so
+        # prompt_logprobs is always None here.
+        return SpeculativeScores(probs=target_probs,
+                                 token_ids=target_token_ids,
+                                 logprobs=target_logprobs,
+                                 hidden_states=target_hidden_states,
+                                 prompt_logprobs=None)
+
+    def _create_scoring_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
+        target_seq_ids_iter: Iterator[TargetSeqId],
+    ) -> List[SequenceGroupMetadata]:
+        """Given the original input sequences and proposed tokens from the draft
+        model, create a list of target sequences that can be used for scoring.
+
+        target_seq_ids_iter provides sequence ids for the expanded batch,
+        fulfilling the requirement that no seq id in the expanded batch is equal
+        to the seq id in the original batch.
+
+        Returns:
+            A flat list of target SequenceGroupMetadata, in input order.
+        """
+
+        if not seq_group_metadata_list:
+            return []
+
+        # Each input group expands into several target groups (one per proposal
+        # prefix); flatten the per-group lists into a single list.
+        target_seq_group_metadata = list(
+            chain.from_iterable(
+                self._create_target_seq_group_metadata(
+                    seq_group_metadata,
+                    proposal_token_ids,
+                    i,
+                    target_seq_ids_iter,
+                ) for i, seq_group_metadata in enumerate(
+                    seq_group_metadata_list)))
+
+        return target_seq_group_metadata
+
+    def _create_target_seq_group_metadata(
+        self,
+        input_seq_group_metadata: SequenceGroupMetadata,
+        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
+        batch_index: int,
+        target_seq_ids_iter: Iterator[TargetSeqId],
+    ) -> List[SequenceGroupMetadata]:
+        """Given an input sequence group metadata and a list of draft tokens,
+        create a list of target SequenceGroupMetadata, one for each
+        token id that needs to be scored.
+
+        Naive speculative decoding requires K target model scores, one for each
+        draft model token. However one can add a bonus token such that if each
+        token is accepted, then a final token may be sampled from the model.
+        This function creates K+1 target SequenceGroupMetadata to take
+        advantage of the bonus token.
+        """
+        # Only one sequence per group: beam search would fork sequences.
+        assert len(input_seq_group_metadata.seq_data) == 1, (
+            "Beam search "
+            "not supported in speculative decoding")
+        input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))
+
+        # K+1 prefixes of this sequence's proposal: [], [t0], [t0,t1], ...
+        token_ids_to_score = self._get_token_ids_to_score(
+            proposal_token_ids[batch_index])
+
+        sampling_params = input_seq_group_metadata.sampling_params
+        target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
+        # NOTE(review): the loop index i is unused; enumerate is retained as-is.
+        for i, token_ids in enumerate(token_ids_to_score):
+            target_seq_group_metadata_list.append(
+                self._create_single_target_seq_group_metadata(
+                    input_seq_group_metadata,
+                    input_seq_id,
+                    next(target_seq_ids_iter),
+                    token_ids,
+                    sampling_params=sampling_params,
+                ))
+
+        return target_seq_group_metadata_list
+
+    @staticmethod
+    def _create_single_target_seq_group_metadata(
+        seq_group_metadata: SequenceGroupMetadata,
+        seq_id: SeqId,
+        target_seq_id: TargetSeqId,
+        token_ids: List[TokenId],
+        sampling_params: SamplingParams,
+    ) -> SequenceGroupMetadata:
+        """Create a single target SequenceGroupMetadata.
+
+        Args:
+            seq_group_metadata: The metadata for the input sequence.
+            seq_id: The input sequence ID.
+            target_seq_id: The corresponding target sequence ID.
+            token_ids: The list of token ids that are to be appended to the
+                input sequence.
+            sampling_params: Sampling params to attach to the new group.
+
+        Returns:
+            A SequenceGroupMetadata holding one new sequence whose output is
+            the original output tokens plus the given draft token_ids.
+        """
+        seq_data = seq_group_metadata.seq_data[seq_id]
+        prompt_token_ids = seq_data.prompt_token_ids_array
+        # Append the draft tokens to the existing generated tokens.
+        new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
+        mrope_position_delta = seq_data.mrope_position_delta
+
+        new_seq_data_dict = {
+            target_seq_id:
+            SequenceData(
+                prompt_token_ids,
+                _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                        new_output_token_ids),
+            ),
+        }
+        # This is a hack. Technically, spec decoding should compute
+        # num_lookahead slots at one shot, but instead, it expands the batch
+        # and evaluate one by one right now. context_len is seq_len - 1 because
+        # the kv cache is filled by a previous batch in the batch expansion.
+        for data in new_seq_data_dict.values():
+            data.update_num_computed_tokens(data.get_len() - 1)
+            data.mrope_position_delta = mrope_position_delta
+
+        return SequenceGroupMetadata(
+            request_id=seq_group_metadata.request_id,
+            is_prompt=seq_group_metadata.is_prompt,
+            seq_data=new_seq_data_dict,
+            sampling_params=sampling_params,
+            block_tables={
+                # The target sequence reuses the input sequence's blocks.
+                target_seq_id: seq_group_metadata.block_tables[seq_id],
+            },
+            lora_request=None,
+            # Score exactly one new token per expanded sequence.
+            token_chunk_size=1,
+        )
+
+    @staticmethod
+    def _split_scoring_output(
+        sampler_output: SamplerOutput, num_scoring_tokens: int
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
+               Optional[torch.Tensor], torch.Tensor, torch.Tensor,
+               torch.Tensor, Optional[torch.Tensor]]:
+        """Split the target model output into speculative and non-speculative
+        output.
+
+        Returns a tuple of (spec tokens, spec probs, spec logprobs,
+        spec hidden states, non-spec tokens, non-spec probs,
+        non-spec logprobs, non-spec hidden states); the hidden-state entries
+        are None when the sampler output carries no hidden states.
+        """
+
+        # vLLM currently only supports proposal lens equal to zero or the batch
+        # proposal len. This adds some complexity (splitting the batch into spec
+        # and non spec sequences) and should be removed in the future. It can be
+        # done by supporting per-sequence proposal lens.
+        #
+        # First samples are non-speculative, latter samples are from speculative
+        # scoring (prefill|decode order).
+        split_sizes = (sampler_output.sampled_token_ids.numel() -
+                       num_scoring_tokens, num_scoring_tokens)
+        (non_spec_probs,
+         spec_probs) = sampler_output.sampled_token_probs.split(split_sizes)
+        (non_spec_sampled_tokens, spec_sampled_tokens
+         ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
+        (non_spec_logprobs,
+         spec_logprobs) = sampler_output.logprobs.split(split_sizes)
+
+        # Hidden states are optional on SamplerOutput.
+        if sampler_output.hidden_states is not None:
+            (non_spec_hidden_states, spec_hidden_states
+             ) = sampler_output.hidden_states.split(split_sizes)
+        else:
+            non_spec_hidden_states, spec_hidden_states = None, None
+
+        return (spec_sampled_tokens, spec_probs, spec_logprobs,
+                spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
+                non_spec_logprobs, non_spec_hidden_states)
+
+    @staticmethod
+    def _create_target_seq_id_iterator(
+            seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
+        """Create an iterator for creating target sequence ids.
+        Target sequence ids are distinct from sequence ids because we create a
+        distinct target sequence id for each proposal token to be scored.
+
+        This implementation increments a counter starting at 1 + max of all
+        provided input sequence ids.
+        """
+        # itertools.count yields an unbounded stream; every id produced is
+        # strictly greater than all input seq ids, so no collisions occur.
+        return count(start=max(seq_ids) + 1)
+
+    @staticmethod
+    def _get_token_ids_to_score(
+            full_spec_token_ids: List[TokenId]  # shape: [k]
+    ) -> List[List[TokenId]]:
+        """Given an int tensor of proposal token ids, return a list of
+        token ids that should be scored.
+
+        Returns k+1 output lists. The additional one is used for generating the
+        bonus token.
+
+        Example:
+            Input: [0, 1, 2, 3]   (k=4)
+            Output: (k+1 lists)
+                []
+                [0]
+                [0, 1]
+                [0, 1, 2]
+                [0, 1, 2, 3]
+        """
+        empty_token_ids: List[TokenId] = []
+
+        # The leading empty list scores the context alone (bonus token);
+        # entry i+1 scores the first i+1 proposal tokens.
+        token_ids_to_score = [empty_token_ids]
+        token_ids_to_score.extend(full_spec_token_ids[:i + 1]
+                                  for i in range(len(full_spec_token_ids)))
+        return token_ids_to_score
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/interfaces.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/interfaces.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd085ad77638462535cf2c3d9def11e8647de965
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/interfaces.py
@@ -0,0 +1,98 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List, Optional, Set, Union
+
+import torch
+
+from vllm.sequence import ExecuteModelRequest, PromptLogprobs
+from vllm.worker.worker_base import WorkerBase
+
+
+@dataclass
+class SpeculativeProposals:
+    """Datastructure used to represent proposal tokens from some proposer. It
+    also tracks how many speculative tokens each sequence has.
+    """
+
+    # Speculative proposal tokens.
+    proposal_token_ids: torch.Tensor
+
+    # Probabilities of the proposal tokens according to the proposer.
+    proposal_probs: torch.Tensor
+
+    # The valid length of each proposal; can be zero.
+    proposal_lens: torch.Tensor
+
+    # A flag to mark that there's no available proposals
+    no_proposals: bool = False
+
+    def __repr__(self):
+        # proposal_probs is shown by shape only to keep the repr compact.
+        return (f"SpeculativeProposals("
+                f"proposal_token_ids={self.proposal_token_ids}, "
+                f"proposal_probs={self.proposal_probs.shape}, "
+                f"proposal_lens={self.proposal_lens})")
+
+
+@dataclass
+class SpeculativeScores:
+    """Datastructure used to represent the scores of speculative tokens
+    according to the scoring model.
+    """
+
+    # Probabilities of the speculative tokens according to the scoring model.
+    probs: torch.Tensor
+
+    # Log-probabilities of the speculative tokens according to the scoring
+    # model. These values can be used to generate Logprob objects that are
+    # returned to the user.
+    logprobs: torch.Tensor
+
+    # Token ids sampled from the scoring model. Used for speculative bonus
+    # tokens and also non-speculative normal decoding.
+    token_ids: torch.Tensor
+
+    # Optional last hidden states from the scoring model.
+    hidden_states: Optional[torch.Tensor] = None
+
+    # Scoring model may also return logprobs for prompt tokens
+    # for each request, when chunked prefill is enabled.
+    prompt_logprobs: Optional[List[PromptLogprobs]] = None
+
+    def __repr__(self):
+        # Tensors are shown by shape only to keep the repr compact.
+        return (f"SpeculativeScores("
+                f"probs={self.probs.shape}, "
+                f"token_ids={self.token_ids.shape})")
+
+
+class SpeculativeProposer(ABC):
+    """Interface for components that produce SpeculativeProposals for a
+    batch of sequences."""
+
+    @abstractmethod
+    def get_spec_proposals(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        # If set, this contains all sequence IDs that were assigned
+        # bonus tokens in their last forward pass.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> SpeculativeProposals:
+        raise NotImplementedError
+
+
+class SpeculativeScorer(ABC):
+    """Interface for components that score speculative proposal tokens with
+    a target (scorer) worker, producing SpeculativeScores."""
+
+    def __init__(self, scorer_worker: WorkerBase,
+                 device: Union[torch.device, str], vocab_size: int):
+        self._scorer_worker = scorer_worker
+        # Normalize a torch.device to its string type (e.g. "cuda") so
+        # self._device is always a plain string.
+        if isinstance(device, torch.device):
+            device = device.type
+        self._device = device
+        self._vocab_size = vocab_size
+
+    @abstractmethod
+    def score_proposals(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        proposals: SpeculativeProposals,
+    ) -> SpeculativeScores:
+        raise NotImplementedError
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/medusa_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/medusa_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b62a988e8b267aed2ecad09cba59eec38808b74
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/medusa_worker.py
@@ -0,0 +1,137 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import weakref
+from typing import List, Optional, Set, Tuple
+
+import torch
+
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
+from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
+from vllm.spec_decode.top1_proposer import Top1Proposer
+from vllm.worker.worker_base import DelegateWorkerBase
+
+
+class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
+    """Worker for Medusa.
+    """
+
+    def __init__(self, *args, **kwargs):
+        DelegateWorkerBase.__init__(self, *args, **kwargs)
+        # Annotation only; the proposer is constructed in init_device().
+        # Lazy initialization list.
+        self._proposer: Top1Proposer
+
+    def init_device(self):
+        self.worker.init_device()
+
+        # weakref.proxy avoids a strong reference cycle between this worker
+        # and its proposer.
+        self._proposer = Top1Proposer(
+            weakref.proxy(self),  # type: ignore[arg-type]
+            self.device,
+            self.vocab_size,
+            max_proposal_len=self.max_model_len,
+        )
+
+    def set_include_gpu_probs_tensor(self):
+        # Intentionally a no-op for Medusa.
+        pass
+
+    def set_should_modify_greedy_probs_inplace(self):
+        # Intentionally a no-op for Medusa.
+        pass
+
+    @torch.inference_mode()
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+        # Unused parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> Tuple[List[SamplerOutput], bool]:
+        """Run the model forward pass to generate sample_len future tokens.
+        Returns the list of sampler output, one per layer, along with indicator
+        of whether torch tensor in sampler output need to be transposed in
+        latter sampler_output_to_torch logic.
+
+        For medusa worker, this indicator shall be False.
+        """
+        self._raise_if_unsupported(execute_model_req)
+
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+
+        seq_lens, query_lens = self._prepare_input_tensors(
+            seq_group_metadata_list)
+
+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list, seq_lens, query_lens, self.device,
+            self.model_runner.pin_memory, generators)
+
+        # Medusa proposes from the previous step's hidden states only;
+        # no input token ids are passed.
+        model_outputs = self.model_runner.model.generate_proposals(
+            previous_hidden_states=execute_model_req.previous_hidden_states.
+            hidden_states,
+            sampling_metadata=sampling_metadata)
+
+        return model_outputs, False
+
+    def _prepare_input_tensors(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ) -> Tuple[List[int], List[int]]:
+        """Compute per-sequence (seq_lens, query_lens) used to build
+        SamplingMetadata. Returns empty lists for an empty batch."""
+        if not seq_group_metadata_list:
+            return [], []
+
+        seq_lens: List[int] = []
+        query_lens: List[int] = []
+
+        for seq_group_metadata in seq_group_metadata_list:
+            is_prompt = seq_group_metadata.is_prompt
+
+            for seq_data in seq_group_metadata.seq_data.values():
+                seq_data_len = seq_data.get_len()
+                if is_prompt:
+                    # Prompt may be chunked: cap seq_len at the current chunk.
+                    context_len = seq_data.get_num_computed_tokens()
+                    seq_len = min(
+                        seq_data_len,
+                        context_len + seq_group_metadata.token_chunk_size)
+                    seq_lens.append(seq_len)
+                    query_lens.append(seq_len - context_len)
+                else:
+                    # Decode: query is always the single new token.
+                    seq_lens.append(seq_data_len)
+                    query_lens.append(1)
+
+        return seq_lens, query_lens
+
+    def get_spec_proposals(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> SpeculativeProposals:
+        """Produce speculations given an input batch of sequences. The number of
+        speculative tokens per sequence is determined by max_proposal_len.
+        """
+
+        return self._proposer.get_spec_proposals(
+            execute_model_req, seq_ids_with_bonus_token_in_last_step)
+
+    def _raise_if_unsupported(
+        self,
+        execute_model_req: ExecuteModelRequest,
+    ) -> None:
+        """MedusaWorker does not yet implement support for cache swap
+        operations or beam search.
+        """
+        if any([
+                execute_model_req.blocks_to_swap_in,
+                execute_model_req.blocks_to_swap_out,
+                execute_model_req.blocks_to_copy
+        ]):
+            raise NotImplementedError(
+                "MedusaWorker does not support cache operations")
+
+        # More than one sequence per group implies beam search forking.
+        if any(
+                len(seq_group_metadata.seq_data.keys()) != 1
+                for seq_group_metadata in
+                execute_model_req.seq_group_metadata_list):
+            raise NotImplementedError(
+                "MedusaWorker does not support beam search.")
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/mlp_speculator_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/mlp_speculator_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdaf31895e25dee2e8ec47ff6f0a41f2c208f623
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/mlp_speculator_worker.py
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List, Optional, Set, Tuple
+
+import torch
+
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
+
+
+class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
+    """Worker for MLPSpeculator models.
+
+    Not currently compatible with LoRA or chunked prefill.
+    """
+
+    @torch.inference_mode()
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+        # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
+        # therefore does not need this parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> Tuple[List[SamplerOutput], bool]:
+        """Run the model forward pass to generate sample_len future tokens.
+        Returns the list of sampler output, one per layer, along with indicator
+        of whether torch tensor in sampler output need to be transposed in
+        latter sampler_output_to_torch logic.
+
+        For mlp spec worker, this indicator shall be True.
+        """
+        self._raise_if_unsupported(execute_model_req)
+
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+
+        (input_tokens, seq_lens,
+         query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
+
+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list, seq_lens, query_lens, self.device,
+            self.model_runner.pin_memory, generators)
+
+        # The MLP speculator consumes both the last input tokens and the
+        # previous step's hidden states, and predicts sample_len tokens.
+        model_outputs = self.model_runner.model.generate_proposals(
+            input_ids=input_tokens,
+            previous_hidden_states=execute_model_req.previous_hidden_states.
+            hidden_states,
+            num_predict_tokens=sample_len,
+            sampling_metadata=sampling_metadata)
+
+        # One SamplerOutput per predicted position.
+        assert len(model_outputs) == sample_len
+
+        return model_outputs, True
+
+    def _prepare_input_tensors(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ) -> Tuple[torch.Tensor, List[int], List[int]]:
+        """Flatten the batch into (input token tensor, seq_lens, query_lens).
+        Returns an empty tensor and empty lists for an empty batch."""
+        if not seq_group_metadata_list:
+            return torch.empty(0, device=self.device), [], []
+
+        input_tokens: List[int] = []
+        seq_lens: List[int] = []
+        query_lens: List[int] = []
+
+        for seq_group_metadata in seq_group_metadata_list:
+            is_prompt = seq_group_metadata.is_prompt
+
+            for seq_data in seq_group_metadata.seq_data.values():
+                seq_data_len = seq_data.get_len()
+                if is_prompt:
+                    # Prompt may be chunked: take only this chunk's tokens.
+                    context_len = seq_data.get_num_computed_tokens()
+                    seq_len = min(
+                        seq_data_len,
+                        context_len + seq_group_metadata.token_chunk_size)
+                    tokens = seq_data.get_token_ids()[context_len:seq_len]
+                    seq_lens.append(seq_len)
+                    input_tokens.extend(tokens)
+                    query_lens.append(seq_len - context_len)
+                else:
+                    # Decode: only the last generated token is fed in.
+                    seq_lens.append(seq_data_len)
+                    input_tokens.append(seq_data.get_last_token_id())
+                    query_lens.append(1)
+
+        input_tokens_tensor = torch.tensor(input_tokens,
+                                           dtype=torch.long,
+                                           device=self.device)
+        return input_tokens_tensor, seq_lens, query_lens
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/mqa_scorer.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/mqa_scorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6275c460ecefa0aaca2fe2d6be7e3dc90ccd3aa0
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/mqa_scorer.py
@@ -0,0 +1,159 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm.sequence import (ExecuteModelRequest, SequenceData,
+ SequenceGroupMetadata, get_all_seq_ids)
+from vllm.spec_decode.interfaces import (SpeculativeProposals,
+ SpeculativeScorer, SpeculativeScores)
+
+SeqId = int
+TargetSeqId = int
+
+
+class MQAScorer(SpeculativeScorer):
+    """Scorer that appends each sequence's proposal tokens to the sequence
+    and scores all of them with a single target-model forward pass (no batch
+    expansion)."""
+
+    def score_proposals(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        proposals: SpeculativeProposals,
+    ) -> SpeculativeScores:
+        target_seq_group_metadata_list = []
+        # Target seq ids start above every existing seq id to avoid clashes.
+        target_seq_id_start = max(
+            get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
+        all_proposal_tokens = proposals.proposal_token_ids.tolist()
+        all_proposal_lengths = proposals.proposal_lens.tolist()
+        for i, seq_group_metadata in enumerate(
+                execute_model_req.seq_group_metadata_list):
+            if all_proposal_lengths[i] == 0:
+                # Keep prompt seqs untouched (keep computed_tokens for chunks).
+                target_seq_group_metadata_list.append(seq_group_metadata)
+                continue
+
+            # One sequence per group (no beam search).
+            seq_data_dict = seq_group_metadata.seq_data
+            assert len(seq_data_dict) == 1
+            seq_id = next(iter(seq_data_dict.keys()))
+
+            seq_data: SequenceData = seq_data_dict[seq_id]
+            prompt_token_ids = seq_data.get_prompt_token_ids()
+            output_token_ids = seq_data.get_output_token_ids()
+            # Only the valid prefix of this sequence's proposal is used.
+            proposal_token_ids = all_proposal_tokens[
+                i][:all_proposal_lengths[i]]
+            new_output_token_ids = [*output_token_ids, *proposal_token_ids]
+
+            target_seq_id = target_seq_id_start + i
+            new_seq_data = SequenceData.from_seqs(
+                prompt_token_ids=prompt_token_ids,
+                output_token_ids=new_output_token_ids,
+            )
+            # Everything up to (but excluding) the last original output token
+            # is already in the KV cache.
+            new_seq_data.update_num_computed_tokens(
+                len(prompt_token_ids) + len(output_token_ids) - 1)
+
+            # Ensure that the new decode sequence has at least one token.
+            assert len(output_token_ids) >= 1
+            new_seq_data_dict = {target_seq_id: new_seq_data}
+
+            new_seq_group_metadata = SequenceGroupMetadata(
+                request_id=seq_group_metadata.request_id,
+                is_prompt=seq_group_metadata.is_prompt,
+                seq_data=new_seq_data_dict,
+                sampling_params=seq_group_metadata.sampling_params,
+                block_tables={
+                    # Reuse the original sequence's block table.
+                    target_seq_id: seq_group_metadata.block_tables[seq_id],
+                },
+                lora_request=None,
+            )
+            target_seq_group_metadata_list.append(new_seq_group_metadata)
+
+        # Single forward pass over the whole (non-expanded) batch.
+        target_sampler_output = self._scorer_worker.execute_model(
+            execute_model_req=execute_model_req.clone(
+                seq_group_metadata_list=target_seq_group_metadata_list))
+
+        target_sampler_output = target_sampler_output[0]
+
+        k = execute_model_req.num_lookahead_slots
+        bs = len(execute_model_req.seq_group_metadata_list)
+        target_token_ids = target_sampler_output.sampled_token_ids
+        target_probs = target_sampler_output.sampled_token_probs
+        target_logprobs = target_sampler_output.logprobs
+        prompt_logprobs = None
+
+        # If all requests have the same number of query tokens, we can avoid
+        # the for loop to build output for better performance.
+        if min(all_proposal_lengths) == k:
+            # Regular decodes only.
+            # NOTE(review): this assert is equivalent to checking that no
+            # group in the list is a prompt (the filter keeps only prompts
+            # and asserts each is not one).
+            assert all(not sg.is_prompt
+                       for sg in target_seq_group_metadata_list
+                       if sg.is_prompt)
+            bs, _ = proposals.proposal_token_ids.shape
+            all_tokens = target_token_ids.reshape(bs, k + 1)
+            all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
+            all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
+        else:
+            # We either have decodes with different lens or prefill+decodes.
+            # Pre-fill outputs with sentinel values: -1 tokens, zero probs,
+            # -inf logprobs; valid positions are overwritten below.
+            all_tokens = target_token_ids.new_full(size=(bs, k + 1),
+                                                   fill_value=-1)
+            all_probs = target_probs.new_zeros(*all_tokens.shape,
+                                               self._vocab_size)
+            all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                    fill_value=-float("inf"))
+            target_token_ids = target_token_ids.flatten()
+
+            # When prompt logprobs is enabled, lens of returned tensors go from
+            # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
+            # We adjust stride accordingly to get the generated tokens and
+            # their probs, but pass on prompt_logprobs as is, since it may be
+            # that n_prompts >> K.
+            has_prompt_log = any((sg.sampling_params.prompt_logprobs
+                                  and sg.sampling_params.prompt_logprobs > 0)
+                                 for sg in target_seq_group_metadata_list)
+            # TODO (NickLucche) we should surface `disable_logprobs` as to not
+            # break abstraction to get its value.
+            if (not self._scorer_worker.model_runner.disable_logprobs\
+                and has_prompt_log):
+                prompt_logprobs = [
+                    o.prompt_logprobs for o in target_sampler_output.outputs
+                ]
+
+            # Split loop into prefill|decode for readability.
+            start_loc, i = 0, 0
+            while i < len(target_seq_group_metadata_list
+                          ) and target_seq_group_metadata_list[i].is_prompt:
+                seq_meta = target_seq_group_metadata_list[i]
+                # Stride over the flat output differs depending on whether
+                # prompt logprobs were returned for the whole chunk.
+                end_loc = start_loc
+                if has_prompt_log:
+                    end_loc += seq_meta.token_chunk_size
+                elif seq_meta.do_sample:
+                    end_loc += 1
+
+                # Skip chunks with no output tokens.
+                if seq_meta.do_sample:
+                    # Get sampled token (last position in chunk) and its prob.
+                    all_tokens[i, 0] = target_token_ids[end_loc - 1]
+                    all_probs[i, 0] = target_probs[end_loc - 1]
+                    all_logprobs[i, 0] = target_logprobs[end_loc - 1]
+
+                i += 1
+                start_loc = end_loc
+            # Decodes.
+            while i < len(target_seq_group_metadata_list):
+                proposed_len, seq_meta = all_proposal_lengths[
+                    i], target_seq_group_metadata_list[i]
+                # proposed_len scored tokens plus one bonus token.
+                output_len = proposed_len + 1
+                end_loc = start_loc + output_len
+                all_tokens[
+                    i, :output_len] = target_token_ids[start_loc:end_loc]
+                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
+                all_logprobs[
+                    i, :output_len] = target_logprobs[start_loc:end_loc]
+                start_loc = end_loc
+                i += 1
+
+        hidden_states = None
+        if target_sampler_output.hidden_states is not None:
+            hidden_states = target_sampler_output.hidden_states.reshape(
+                bs, (k + 1), -1)
+
+        return SpeculativeScores(probs=all_probs,
+                                 token_ids=all_tokens,
+                                 logprobs=all_logprobs,
+                                 hidden_states=hidden_states,
+                                 prompt_logprobs=prompt_logprobs)
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/multi_step_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/multi_step_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..5474917a6fab7f436cd2e8905e9777de0abc0727
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/multi_step_worker.py
@@ -0,0 +1,388 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import copy
+import weakref
+from typing import Dict, List, Set, Tuple
+
+import torch
+
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
+from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
+ SequenceGroupMetadata)
+
+if current_platform.is_cuda_alike():
+ from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
+from vllm.spec_decode.interfaces import (SpeculativeProposals,
+ SpeculativeProposer)
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
+from vllm.spec_decode.top1_proposer import Top1Proposer
+from vllm.worker.worker_base import DelegateWorkerBase
+
+
+class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
+    """The MultiStepWorker is equivalent to a Worker except that it allows
+    multiple forward passes in a single call, assuming the scheduler has
+    allocated enough space to store the additional KV. This reduces overhead
+    by invoking the scheduler less.
+
+    The MultiStepWorker does not support cache swap operations, or beam search.
+    Cache swap operations do not require large modifications. On the other hand,
+    beam search requires memory allocations during sequence forks and thus
+    requires more thought for MultiStepWorker support.
+    """
+
+    def __init__(self, *args, **kwargs):
+        DelegateWorkerBase.__init__(self, *args, **kwargs)
+        # Lazy initialization list.
+        # The proposer is created in init_device(), once the wrapped worker's
+        # device attributes are available.
+        self._proposer: SpeculativeProposer
+
+    def init_device(self) -> None:
+        self.worker.init_device()
+        self._proposer = Top1Proposer(
+            # NOTE: a weakref proxy (not a strong reference) is handed to the
+            # proposer — presumably to avoid a reference cycle between this
+            # worker and the proposer it owns.
+            weakref.proxy(self),  # type: ignore[arg-type]
+            self.device,
+            self.vocab_size,
+            max_proposal_len=self.max_model_len,
+        )
+
+    def set_include_gpu_probs_tensor(self) -> None:
+        # Need include_gpu_probs_tensor for MultiStepWorker
+        self.model_runner.model.sampler.include_gpu_probs_tensor = True
+
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
+            True)
+
+    @torch.inference_mode()
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> Tuple[List[SamplerOutput], bool]:
+        """Run the model forward pass sample_len times. Returns the list of
+        sampler output, one per model forward pass, along with an indicator of
+        whether the torch tensors in the sampler output need to be transposed
+        in the later sampler_output_to_torch logic.
+
+        For the multi step worker, this indicator shall be True.
+        """
+        self._raise_if_unsupported(execute_model_req)
+        # Expand the batch for sequences with a bonus token.
+        # Perform a forward pass on the expanded batch and filter the
+        # response to retain only the original sequences' responses.
+        expanded_request, indices_of_seq_with_bonus_tokens =\
+            self._expand_execute_model_request(
+                execute_model_req, seq_ids_with_bonus_token_in_last_step)
+
+        # Run model sample_len times.
+        model_outputs: List[SamplerOutput] = []
+        if current_platform.is_cuda_alike() and isinstance(
+                self.model_runner, TP1DraftModelRunner
+        ) and self.model_runner.supports_gpu_multi_step(expanded_request):
+            # Here we run the draft_model_runner with multi-step prepare
+            # on the GPU directly
+            expanded_request.num_steps = sample_len
+            self.model_runner.set_indices_of_seq_with_bonus_tokens(
+                indices_of_seq_with_bonus_tokens)
+            model_outputs = self.execute_model(
+                execute_model_req=expanded_request)
+        else:
+            # Here we run multi-step directly, with every step prepared
+            # on the CPU.
+            # TODO: Remove this branch once DraftModelRunner supports TP>1
+            # and other restrictions that are part of DraftModelRunner's
+            # supports_gpu_multi_step(..)
+            for _ in range(sample_len):
+                model_output: List[SamplerOutput] = self.worker.execute_model(
+                    execute_model_req=expanded_request)
+                assert (len(model_output) == 1
+                        ), "composing multistep workers not supported"
+                # Unwrap the single step's output.
+                model_output = model_output[0]
+
+                # Feed the sampled tokens back into the sequence metadata so
+                # the next step's CPU-side prepare sees them.
+                self._append_new_tokens(
+                    model_output, expanded_request.seq_group_metadata_list,
+                    indices_of_seq_with_bonus_tokens)
+                model_outputs.append(model_output)
+
+        # move indices to device to avoid stream sync
+        indices_of_seq_with_bonus_tokens = torch.tensor(
+            indices_of_seq_with_bonus_tokens, device=self.device)
+        filtered_model_outputs = self._filter_model_output(
+            model_outputs, indices_of_seq_with_bonus_tokens)
+        return filtered_model_outputs, True
+
+    @staticmethod
+    def _expand_execute_model_request(
+        execute_model_req: ExecuteModelRequest,
+        seq_with_bonus_token_in_last_step: set,
+    ) -> Tuple[ExecuteModelRequest, List[int]]:
+        """
+        Expands the execute model request based on sequences with bonus
+        tokens.
+
+        For each sequence with a bonus token, this method creates a new
+        sequence without the bonus token and adds it to the execute model
+        request. The original sequence groups are also retained. The indices
+        of the original sequence groups are returned for further processing.
+
+        Args:
+            execute_model_req (ExecuteModelRequest): The original execute
+            model request.
+            seq_with_bonus_token_in_last_step (set): Set of sequence IDs that
+            contain bonus tokens.
+
+        Returns:
+            Tuple[ExecuteModelRequest, List[int]]: The updated execute model
+            request with expanded sequences and a list of indices corresponding
+            to the original sequence groups.
+        """
+        updated_seq_group_metadata_list: List[SequenceGroupMetadata] = []
+        updated_execute_model_req = execute_model_req.clone(
+            updated_seq_group_metadata_list)
+        indices_of_original_sequence_groups = []
+        for seq_group in execute_model_req.seq_group_metadata_list:
+            seq_group_has_bonus_tokens = False
+            for seq_id, _ in seq_group.seq_data.items():
+                # Identify sequences with bonus tokens in the sequence group.
+                if seq_id in seq_with_bonus_token_in_last_step:
+                    seq_group_has_bonus_tokens = True
+                    break
+            if seq_group_has_bonus_tokens:
+                # Create new sequences without the last bonus token. These new
+                # sequences have the same sequence id as the original sequence.
+                # We create a new sequence group and add them there.
+                # NOTE: the copy is appended *before* its original, so each
+                # expanded pair appears as [copy-without-bonus, original];
+                # _append_new_tokens relies on this ordering.
+                updated_seq_group_without_bonus_token = \
+                    MultiStepWorker._copy_seq_metadata_excluding_last_token(
+                        seq_group, seq_with_bonus_token_in_last_step)
+                updated_seq_group_metadata_list.append(
+                    updated_seq_group_without_bonus_token)
+            # Add the original sequence group.
+            updated_seq_group_metadata_list.append(
+                MultiStepWorker._shallow_copy_seq_group_metadata(seq_group))
+            # Record the index of the original sequence group.
+            indices_of_original_sequence_groups.append(
+                len(updated_seq_group_metadata_list) - 1)
+
+        updated_execute_model_req.seq_group_metadata_list =\
+            updated_seq_group_metadata_list
+
+        if isinstance(updated_execute_model_req.previous_hidden_states,
+                      HiddenStates):
+            # Keep previous hidden states aligned with the expanded batch.
+            updated_execute_model_req.previous_hidden_states\
+                .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step)
+
+        return updated_execute_model_req, indices_of_original_sequence_groups
+
+    @staticmethod
+    def _filter_model_output(
+            expanded_batch_outputs: List[SamplerOutput],
+            output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
+        """
+        Filters the model output to include only the specified sequence
+        outputs. This method contracts the expanded batch output from the
+        model to retain the outputs of only those sequences indicated by the
+        provided indices.
+
+        Args:
+            expanded_batch_outputs (List[SamplerOutput]): The expanded output
+            batch from the model.
+            output_indices_to_retain (torch.Tensor): Indices of the model
+            outputs to retain.
+
+        Returns:
+            List[SamplerOutput]: A list containing the filtered model
+            outputs for the specified indices.
+        """
+        return [
+            SamplerOutput(
+                outputs=[
+                    expanded_batch_output.outputs[i]
+                    for i in output_indices_to_retain
+                ] if len(expanded_batch_output.outputs) > 0 else [],
+                sampled_token_probs=(
+                    expanded_batch_output.
+                    sampled_token_probs[output_indices_to_retain]
+                    if expanded_batch_output.sampled_token_probs is not None
+                    else None),
+                logprobs=(
+                    expanded_batch_output.logprobs[output_indices_to_retain]
+                    if expanded_batch_output.logprobs is not None else None),
+                sampled_token_ids=(expanded_batch_output.
+                                   sampled_token_ids[output_indices_to_retain]
+                                   if expanded_batch_output.sampled_token_ids
+                                   is not None else None))
+            for expanded_batch_output in expanded_batch_outputs
+        ]
+
+    def get_spec_proposals(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: set,
+    ) -> SpeculativeProposals:
+        """Produce speculations given an input batch of sequences. The number of
+        speculative tokens per sequence is determined by max_proposal_len.
+        """
+        return self._proposer.get_spec_proposals(
+            execute_model_req, seq_ids_with_bonus_token_in_last_step)
+
+    @staticmethod
+    def _append_new_tokens(
+            model_output: List[SamplerOutput],
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+            indices_of_seq_with_bonus_tokens: List[int]) -> None:
+        """Given model output from a single run, append the tokens to the
+        sequences. This is normally done outside of the worker, but it is
+        required if the worker is to perform multiple forward passes.
+        """
+        # `count` walks through indices_of_seq_with_bonus_tokens (positions of
+        # the *original* sequence groups in the expanded batch).
+        count = 0
+        for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
+                zip(seq_group_metadata_list, model_output)):
+            seq_group_metadata.is_prompt = False
+
+            for seq_output in sequence_group_outputs.samples:
+                # NOTE: Beam search is not supported, so we can assume that
+                # parent_seq_id == seq_id.
+                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
+
+                token_id = seq_output.output_token
+                token_logprob = seq_output.logprobs[token_id]
+                # Determine the actual token ID to be generated,
+                # considering bonus tokens
+                if index != indices_of_seq_with_bonus_tokens[count]:
+                    # This entry is the copy-without-bonus paired with the
+                    # next original group: append that group's last output
+                    # token (the bonus token) instead of the freshly sampled
+                    # one, keeping the copy consistent with its original.
+                    bonus_seq_metadata = seq_group_metadata_list[
+                        indices_of_seq_with_bonus_tokens[count]]
+                    _, bonus_token_seq_data = next(
+                        iter(bonus_seq_metadata.seq_data.items()))
+                    token_id = bonus_token_seq_data.output_token_ids[-1]
+                else:
+                    # An original sequence group: advance to the next
+                    # recorded original index.
+                    count += 1
+
+                seq.append_token_id(token_id, token_logprob.logprob)
+                seq.update_num_computed_tokens(1)
+
+    @staticmethod
+    def _shallow_copy_seq_group_metadata(
+            seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata:
+        """Copy input data structures to remove side-effects when input data
+        structures are shared with other modules.
+
+        Helpful when the vLLM scheduler runs in the same process as the worker.
+        The alternative is deep-copying (or other form of deep copy); this has
+        performance downsides.
+        """
+        # Shallow-copy the SequenceGroupMetadata. This allows us to
+        # append tokens and change is_prompt without external side-effects.
+        # We must shallow-copy seq_group_metadata as is_prompt could change.
+        new_seq_group_metadata = copy.copy(seq_group_metadata)
+
+        # We must shallow-copy seq_data as we will append token ids
+        new_seq_data: Dict[int, SequenceData] = {}
+        for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+            new_seq_data[seq_id] = copy.copy(old_seq_data)
+            # Copy the token list itself so appends don't mutate the shared
+            # original list.
+            new_seq_data[seq_id].output_token_ids =\
+                old_seq_data.output_token_ids[:]
+
+        new_seq_group_metadata.seq_data = new_seq_data
+        return new_seq_group_metadata
+
+    @staticmethod
+    def _copy_seq_metadata_excluding_last_token(
+        seq_group_metadata: SequenceGroupMetadata,
+        seq_ids_to_copy: Set[int],
+    ) -> SequenceGroupMetadata:
+        """
+        Creates a shallow copy of the given SequenceGroupMetadata, retaining
+        only the sequence IDs specified in seq_ids_to_copy. For each of these
+        sequence IDs, all output_token_ids except the last one are copied.
+        Sequence IDs not in seq_ids_to_copy are excluded from the copy.
+
+        Parameters:
+        seq_group_metadata (SequenceGroupMetadata): The original sequence
+        group metadata.
+        seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
+        copy.
+
+        Returns:
+        SequenceGroupMetadata: A shallow copy of the sequence group metadata
+        with the specified modifications.
+        """
+        # Shallow-copy the SequenceGroupMetadata.
+        new_seq_group_metadata = copy.copy(seq_group_metadata)
+        # Shallow-copy seq_data and modify the output_token_ids.
+        new_seq_data: Dict[int, SequenceData] = {}
+        for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+            if (seq_id in seq_ids_to_copy):
+                new_seq_data[seq_id] = copy.copy(old_seq_data)
+                # Copy all the output token ids except the last.
+                # Also reduce num_computed_tokens by 1 since we are not
+                # including the last output token.
+                # NOTE: num_computed_tokens is not directly used by the
+                # speculative decoding workers, as it is only relevant for
+                # chunked prefill, which is disabled for speculative decoding.
+                # However, to maintain consistency in num_computed_tokens,
+                # we update it here.
+                new_seq_data[seq_id].output_token_ids =\
+                    old_seq_data.output_token_ids[:-1]
+                new_seq_data[seq_id].update_num_computed_tokens(-1)
+        new_seq_group_metadata.seq_data = new_seq_data
+        return new_seq_group_metadata
+
+    def _assert_enough_kv_space(
+            self, seq_group_metadata_list: List[SequenceGroupMetadata],
+            num_steps: int) -> None:
+        """Assert there are enough physical blocks per sequence to store the
+        current KV plus additional KV from num_steps tokens.
+        """
+        assert self.model_runner.block_size is not None
+        for seq_group_metadata in seq_group_metadata_list:
+            # Only one seq_id is guaranteed because there is no beam search.
+            seq_id = list(seq_group_metadata.seq_data.keys())[0]
+            seq = seq_group_metadata.seq_data[seq_id]
+
+            # After num_steps, the seq len will be the current seq len
+            # plus one token per step.
+            final_seq_len = seq.get_len() + num_steps
+
+            # We will have final_seq_len - 1 KV because vLLM saves KV for a
+            # token in the iteration after the token was generated.
+            required_num_kv_slots = final_seq_len - 1
+
+            # The allocated number of kv slots is the number of allocated blocks
+            # times the number of slots of block.
+            number_physical_blocks = len(
+                seq_group_metadata.block_tables[seq_id])
+            allocated_kv_slots = (number_physical_blocks *
+                                  self.model_runner.block_size)
+
+            if required_num_kv_slots > allocated_kv_slots:
+                request_id = seq_group_metadata.request_id
+                raise ValueError(
+                    "The worker attempted to run "
+                    f"{num_steps} times but found insufficient KV space for "
+                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
+                    f"{required_num_kv_slots=}).")
+
+    def _raise_if_unsupported(
+        self,
+        execute_model_req: ExecuteModelRequest,
+    ) -> None:
+        """MultiStepWorker does not yet implement support for cache swap
+        operations or beam search.
+        """
+        if any([
+                execute_model_req.blocks_to_swap_in,
+                execute_model_req.blocks_to_swap_out,
+                execute_model_req.blocks_to_copy
+        ]):
+            raise NotImplementedError(
+                "MultiStepWorker does not support cache operations")
+
+        if any(
+                len(seq_group_metadata.seq_data.keys()) != 1
+                for seq_group_metadata in
+                execute_model_req.seq_group_metadata_list):
+            raise NotImplementedError(
+                "MultiStepWorker does not support beam search.")
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/ngram_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/ngram_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..86390c99c2fbced6163eac2374cac7afe681b602
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/ngram_worker.py
@@ -0,0 +1,187 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import weakref
+from typing import List, Optional, Set, Tuple
+
+import torch
+import torch.nn as nn
+
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import ExecuteModelRequest
+from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
+from vllm.spec_decode.top1_proposer import Top1Proposer
+
+
+class _DummyModel(nn.Module):
+    """Placeholder nn.Module returned by get_model() — the ngram worker has
+    no real draft model."""
+    pass
+
+
+class NGramWorker(NonLLMProposerWorkerBase):
+    """NGramWorker provides a light drafter without need for model.
+
+    Current NGramWorker only implements prompt lookup decoding,
+    and in future we may also do RAG type drafter and other scenarios
+    which don't rely on LLM model to give proposals.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # Get local_rank/vocab_size from kwargs attribute
+        self.local_rank = kwargs["local_rank"]
+        self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
+        self.device_type = kwargs.get("device_type", "cuda")
+
+        # Lazy initialization list.
+        # The proposer itself is created in init_device().
+        self._proposer: Top1Proposer
+
+    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
+                              ngram_prompt_lookup_max: int):
+        # Search valid candidate window between
+        # ngram_prompt_lookup_min/ngram_prompt_lookup_max
+        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
+        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
+
+    def init_device(self):
+        self.device = torch.device(f"{self.device_type}:{self.local_rank}")
+
+        # Current NGramWorker only supports Top1Proposer
+        self._proposer = Top1Proposer(
+            weakref.proxy(self),  # type: ignore[arg-type]
+            device=self.device,
+            vocab_size=self.vocab_size,
+        )
+
+    def load_model(self) -> None:
+        pass  # Dummy
+
+    def get_model(self) -> nn.Module:
+        # There is no real draft model; return a placeholder module.
+        return _DummyModel()
+
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+        # Unused parameter. NGramWorker does not use the KV Cache and
+        # therefore does not need this parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
+        """NGram match algo to pick proposal candidate. Returns the list of
+        sampler output, one per SequenceGroupMetadata.
+
+        For the ngram worker, the needed transpose has already been done
+        internally, so the indicator passed to sampler_output_to_torch shall
+        be False.
+        """
+        self._raise_if_unsupported(execute_model_req)
+
+        has_spec_out = False
+        token_id_list: List[Optional[torch.Tensor]] = []
+        token_prob_list: List[Optional[torch.Tensor]] = []
+        for idx, seq_group_metadata in enumerate(
+                execute_model_req.seq_group_metadata_list):
+            seq_data = next(iter(seq_group_metadata.seq_data.values()))
+
+            seq_len = seq_data.get_len()
+            # When seq_len is less than 3072 (3K), we use CPU to perform
+            # the ngram match. Otherwise, we use the device specified in
+            # the model config (normally GPU). 3072 is a rough threshold
+            # based on profiling on H100, and it can be adjusted based
+            # on the actual performance on different hardware.
+            cur_device = "cpu" if seq_len < 3072 else self.device
+            input_ids = torch.as_tensor(seq_data.get_token_ids(),
+                                        dtype=torch.long,
+                                        device=cur_device)
+            input_length = seq_data.get_len()  # NOTE(review): same as seq_len
+
+            # Try the longest lookup window first, falling back to shorter
+            # windows until a match is found.
+            for ngram_size in range(
+                    min(self.ngram_prompt_lookup_max, input_length - 1),
+                    self.ngram_prompt_lookup_min - 1,
+                    -1,
+            ):
+                # The last ngram_size tokens form the pattern to look up.
+                ngram_tensor = input_ids[-ngram_size:]
+                if ngram_size == 1:
+                    # Do not match itself and do not use unfold and all
+                    matches = (input_ids[:-1] == ngram_tensor)
+                else:
+                    windows = input_ids.unfold(dimension=0,
+                                               size=ngram_size,
+                                               step=1)
+                    # Do not match itself
+                    matches = (windows[:-1] == ngram_tensor).all(dim=-1)
+
+                # first_match includes "values" (bool), indicating whether
+                # the match is found, and "indices", indicating the index
+                # of the first match.
+                first_match = matches.max(dim=-1)
+                if first_match.values.item():
+                    # Propose the sample_len tokens that followed the first
+                    # earlier occurrence of the pattern.
+                    proposal_start_idx = first_match.indices.add_(ngram_size)
+                    spec_indices = (
+                        proposal_start_idx).repeat(sample_len) + torch.arange(
+                            sample_len, device=cur_device)
+                    # Clamp so the proposal never indexes past the end of the
+                    # sequence (the last token repeats if the window runs off
+                    # the end).
+                    spec_indices.clamp_(max=input_ids.shape[-1] - 1)
+                    res = input_ids.gather(dim=-1,
+                                           index=spec_indices).to(self.device)
+                    token_id_list.append(res)
+                    # One-hot probabilities: each proposed token gets prob 1.
+                    token_prob_list.append(
+                        torch.nn.functional.one_hot(
+                            res,
+                            num_classes=self.vocab_size).to(torch.float32))
+                    has_spec_out = True
+                    break
+            else:
+                # No window size produced a match for this sequence.
+                token_id_list.append(None)
+                token_prob_list.append(None)
+
+        if not has_spec_out:
+            return None, False
+
+        outputs: List[Optional[SamplerOutput]] = []
+        for idx in range(len(execute_model_req.seq_group_metadata_list)):
+            if token_id_list[idx] is None:
+                outputs.append(None)
+            else:
+                outputs.append(
+                    SamplerOutput(
+                        outputs=None,
+                        sampled_token_probs=token_prob_list[idx],
+                        # Dummy all-zero logprobs; only ids/probs are derived
+                        # from the ngram match.
+                        logprobs=torch.zeros((sample_len, self.vocab_size),
+                                             dtype=torch.float32,
+                                             device=self.device),
+                        sampled_token_ids=token_id_list[idx],
+                    ))
+
+        return outputs, False
+
+    def get_spec_proposals(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        # Unused parameter. NGramWorker does not use the KV Cache and
+        # therefore does not need this parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> SpeculativeProposals:
+        """Produce speculations given an input batch of sequences. The number of
+        speculative tokens per sequence is determined by max_proposal_len.
+        """
+        return self._proposer.get_spec_proposals(
+            execute_model_req, seq_ids_with_bonus_token_in_last_step)
+
+    def _raise_if_unsupported(
+        self,
+        execute_model_req: ExecuteModelRequest,
+    ) -> None:
+        """NGramWorker does not yet implement support for cache swap
+        operations or beam search.
+        """
+        if any([
+                execute_model_req.blocks_to_swap_in,
+                execute_model_req.blocks_to_swap_out,
+                execute_model_req.blocks_to_copy
+        ]):
+            raise NotImplementedError(
+                "NGramWorker does not support cache operations")
+
+        if any(
+                len(seq_group_metadata.seq_data.keys()) != 1
+                for seq_group_metadata in
+                execute_model_req.seq_group_metadata_list):
+            raise NotImplementedError(
+                "NGramWorker does not support beam search.")
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/proposer_worker_base.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/proposer_worker_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bebf80fadae5e3e637053f95740340bd6a98f7f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/proposer_worker_base.py
@@ -0,0 +1,58 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from typing import List, Optional, Set, Tuple
+
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import ExecuteModelRequest
+from vllm.spec_decode.interfaces import SpeculativeProposer
+from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+
+
+class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
+    """Interface for proposer workers"""
+
+    @abstractmethod
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+        # A set containing all sequence IDs that were assigned bonus tokens
+        # in their last forward pass. This set is used to backfill the KV cache
+        # with the key-value pairs of the penultimate token in the sequences.
+        # This parameter is only used by the MultiStepWorker, which relies on
+        # the KV cache for token generation. It is not used by workers that
+        # do not utilize the KV cache.
+        seq_ids_with_bonus_token_in_last_step: Set[int]
+    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
+        # Returns (sampler outputs or None, transpose indicator): the bool
+        # tells the caller whether the output tensors still need transposing
+        # in the later sampler_output_to_torch logic.
+        raise NotImplementedError
+
+    def set_include_gpu_probs_tensor(self) -> None:
+        """Implementation optional"""
+        pass
+
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        """Implementation optional"""
+        pass
+
+
+class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
+    """Proposer worker which does not use a model with kvcache"""
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """get_spec_proposals is used to get the proposals"""
+        # No model to execute; proposals come from get_spec_proposals.
+        return []
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """This is never called on the proposer, only the target model"""
+        raise NotImplementedError
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        # No KV cache to initialize for a non-LLM proposer.
+        pass
+
+    def get_cache_block_size_bytes(self) -> int:
+        # Zero cache footprint: this worker reserves no KV-cache memory.
+        return 0
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/smaller_tp_proposer_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/smaller_tp_proposer_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1466ba5db756d59f9ea4e709d27a343fd0943a7
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/smaller_tp_proposer_worker.py
@@ -0,0 +1,175 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List, Optional, Set, Tuple
+
+import torch
+import torch.nn as nn
+
+from vllm.distributed.parallel_state import (get_tp_group,
+ init_model_parallel_group,
+ patch_tensor_parallel_group)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import ExecuteModelRequest
+from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
+
+logger = init_logger(__name__)
+
+
+class _DummyModel(nn.Module):
+    """Placeholder nn.Module returned by get_model() on dummy ranks that hold
+    no draft model."""
+    pass
+
+
+class SmallerTpProposerWorker(ProposerWorkerBase):
+    """Class which allows a speculative draft model to run with smaller tensor
+    parallel degree than target model.
+    This reduces the communication overhead of small draft models.
+
+    To implement this feature, this class differs behavior based on is_dummy
+    flag, where dummy means worker that does not participate draft generation.
+    Participating workers use a smaller tp group by patching vLLM's tensor
+    parallel group temporarily during forward passes of draft models.
+    """
+
+    @classmethod
+    def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int,
+                          target_tensor_parallel_size: int):
+        """Wrap the worker in a SmallerTpProposerWorker if necessary.
+
+        Wrapping is only needed when the draft model runs with a smaller
+        tensor-parallel degree than the target model; otherwise the worker
+        is returned unchanged.
+        """
+        if draft_tensor_parallel_size == target_tensor_parallel_size:
+            return worker
+
+        # gpu ranks that will generate draft tokens together
+        draft_ranks = list(range(draft_tensor_parallel_size))
+
+        logger.info("Wrapping {%s} in {%s}", type(worker), cls)
+        return cls(worker, draft_ranks)
+
+    def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]):
+        """Create a SmallerTpProposerWorker.
+
+        Args:
+            worker (MultiStepWorker): an actual worker wrapped with this class
+            draft_ranks (List[int]): if this value is given, only the GPU ranks
+            written in this value participate in draft generation
+        """
+        self._worker = worker
+        self._draft_ranks = draft_ranks
+
+        # init during init_device
+        self._is_dummy = False
+        self._tp_group = None
+
+    def _patch_tensor_parallel_group(self):
+        """Temporarily patch the global tp group state with its own tp group
+        state.
+
+        Returns a context manager; the patch is reverted when it exits.
+        """
+        return patch_tensor_parallel_group(self._tp_group)
+
+    def init_device(self) -> None:
+        # A rank outside the draft group becomes a dummy worker and skips
+        # all draft-side work from here on.
+        self._is_dummy = get_tp_group().rank not in self._draft_ranks
+
+        # dummy workers do nothing
+        if self._is_dummy:
+            return
+
+        # creates tp process group containing only a subset of gpu ranks
+        local_rank = get_tp_group().local_rank
+        tp_backend = torch.distributed.get_backend(get_tp_group().device_group)
+        self._tp_group = init_model_parallel_group([self._draft_ranks],
+                                                   local_rank, tp_backend)
+
+        with self._patch_tensor_parallel_group():
+            self._worker.init_device()
+
+    def set_include_gpu_probs_tensor(self) -> None:
+        if self._is_dummy:
+            return
+
+        # Need include_gpu_probs_tensor for multi_step_worker
+        self._worker.set_include_gpu_probs_tensor()
+
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        if self._is_dummy:
+            return
+
+        self._worker.set_should_modify_greedy_probs_inplace()
+
+    def load_model(self) -> None:
+        if self._is_dummy:
+            return
+
+        with self._patch_tensor_parallel_group():
+            self._worker.load_model()
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        if self._is_dummy:
+            # this case is not used now
+            return -1, -1
+
+        with self._patch_tensor_parallel_group():
+            return self._worker.determine_num_available_blocks()
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        if self._is_dummy:
+            return
+
+        with self._patch_tensor_parallel_group():
+            self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> Tuple[List[SamplerOutput], bool]:
+        # Do not check _is_dummy, as it's always called by get_spec_proposals
+        return self._worker.sampler_output(
+            execute_model_req, sample_len,
+            seq_ids_with_bonus_token_in_last_step)
+
+    def get_spec_proposals(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> SpeculativeProposals:
+        """Produce speculations given an input batch of sequences. The number of
+        speculative tokens per sequence is determined by max_proposal_len.
+        """
+        if self._is_dummy:
+            # Dummy ranks return an empty proposal container.
+            return SpeculativeProposals(None, None, None)
+
+        with self._patch_tensor_parallel_group():
+            return self._worker.get_spec_proposals(
+                execute_model_req, seq_ids_with_bonus_token_in_last_step)
+
+    def get_model(self) -> nn.Module:
+        if self._is_dummy:
+            return _DummyModel()
+
+        with self._patch_tensor_parallel_group():
+            return self._worker.get_model()
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        if self._is_dummy:
+            return []
+
+        with self._patch_tensor_parallel_group():
+            return self._worker.execute_model(execute_model_req)
+
+    def get_cache_block_size_bytes(self) -> int:
+        if self._is_dummy:
+            # by returning zero, target worker can use the entire kv cache space
+            return 0
+
+        return self._worker.get_cache_block_size_bytes()
+
+    @property
+    def vocab_size(self) -> int:
+        # Delegate to the wrapped worker (valid on dummy ranks too, since the
+        # wrapped worker object always exists).
+        return self._worker.vocab_size
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/spec_decode_worker.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/spec_decode_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..8653bece8b5a59b616f41cad0bc8f4b201f6ac06
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/spec_decode_worker.py
@@ -0,0 +1,1282 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import copy
+from collections import defaultdict
+from functools import cached_property
+from typing import Any, Dict, List, Optional, Set, Tuple, Type
+
+import torch
+import torch.nn as nn
+
+from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
+from vllm.distributed.communication_op import broadcast_tensor_dict
+from vllm.logger import init_logger
+from vllm.model_executor.layers.rejection_sampler import RejectionSampler
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.layers.spec_decode_base_sampler import (
+ SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
+from vllm.model_executor.layers.typical_acceptance_sampler import (
+ TypicalAcceptanceSampler)
+from vllm.platforms import current_platform
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
+ CompletionSequenceGroupOutput, ExecuteModelRequest,
+ HiddenStates, SequenceGroupMetadata,
+ get_all_seq_ids_and_request_ids)
+from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+
+if current_platform.is_cuda_alike():
+ from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
+from vllm.spec_decode.interfaces import (SpeculativeProposals,
+ SpeculativeScorer, SpeculativeScores)
+from vllm.spec_decode.medusa_worker import MedusaWorker
+from vllm.spec_decode.metrics import AsyncMetricsCollector
+from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
+from vllm.spec_decode.mqa_scorer import MQAScorer
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.ngram_worker import NGramWorker
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
+from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
+from vllm.spec_decode.target_model_runner import TargetModelRunner
+from vllm.spec_decode.util import (Timer, create_logprobs_output,
+ create_sequence_group_output,
+ get_all_num_logprobs,
+ get_sampled_token_logprobs, nvtx_range,
+ split_batch_by_proposal_len)
+from vllm.utils import resolve_obj_by_qualname
+from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
+
+logger = init_logger(__name__)
+
+
+def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
+    """Helper method that is the entrypoint for Executors which use
+    WorkerWrapper. It constructs a SpecDecodeWorker from the speculative
+    config: a target (scorer) worker built from ``sd_worker_cls`` plus a
+    draft (proposer) worker built from the draft model/parallel configs.
+    """
+    vllm_config: VllmConfig = kwargs.get("vllm_config")
+    speculative_config: SpeculativeConfig = vllm_config.speculative_config
+    assert speculative_config is not None
+
+    if vllm_config.parallel_config.pipeline_parallel_size > 1:
+        raise NotImplementedError("Speculative decoding is currently "
+                                  "incompatible with pipeline parallelism")
+
+    # Snapshot kwargs for the draft worker *before* the target-only
+    # model_runner_cls key is injected below.
+    draft_worker_kwargs = kwargs.copy()
+
+    kwargs["model_runner_cls"] = TargetModelRunner
+    target_worker_config = copy.deepcopy(vllm_config)
+    target_worker_config.parallel_config.worker_cls =\
+        target_worker_config.parallel_config.sd_worker_cls
+    cls = resolve_obj_by_qualname(
+        target_worker_config.parallel_config.worker_cls)
+    target_worker = cls(*args, **kwargs)
+    # Set the disable_logprobs variable in the TargetModelRunner instance
+    # as per its value specified in the SpeculativeConfig.
+    target_worker.model_runner.disable_logprobs =\
+        speculative_config.disable_logprobs
+
+    # Draft worker reuses the target config but swaps in the draft model,
+    # its quantization config, and its (possibly smaller) parallel config.
+    draft_worker_config = copy.deepcopy(vllm_config)
+    draft_worker_config.model_config = speculative_config.draft_model_config
+    draft_worker_config.quant_config = VllmConfig._get_quantization_config(
+        draft_worker_config.model_config,
+        vllm_config.load_config,
+    )
+    speculative_config.draft_parallel_config.worker_cls =\
+        draft_worker_config.parallel_config.sd_worker_cls
+    draft_worker_config.parallel_config = speculative_config.draft_parallel_config  # noqa
+    # TODO allow draft-model specific load config.
+
+    # Override draft-model specific worker args.
+    draft_worker_kwargs.update(
+        vllm_config=draft_worker_config,
+        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
+        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
+    )
+
+    spec_decode_worker = SpecDecodeWorker.create_worker(
+        scorer_worker=target_worker,
+        draft_worker_kwargs=draft_worker_kwargs,
+        disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
+        disable_by_batch_size=speculative_config.
+        speculative_disable_by_batch_size,
+        draft_token_acceptance_method=speculative_config.
+        draft_token_acceptance_method,
+        typical_acceptance_sampler_posterior_threshold=speculative_config.
+        typical_acceptance_sampler_posterior_threshold,
+        typical_acceptance_sampler_posterior_alpha=speculative_config.
+        typical_acceptance_sampler_posterior_alpha,
+        disable_logprobs=speculative_config.disable_logprobs,
+        disable_log_stats=speculative_config.disable_log_stats,
+    )
+
+    return spec_decode_worker
+
+
+# Reminder: Please update docs/source/features/compatibility_matrix.md
+# If the feature combo become valid
+class SpecDecodeWorker(LoraNotSupportedWorkerBase):
+ """Worker which implements speculative decoding.
+
+ Speculative decoding reduces decoding per-token latency by using a proposal
+ method, such as a small draft model, to speculate ahead of a larger LLM. The
+ probabilities of the speculative tokens are then determined by the larger
+ LLM, after which some verification routine determines which (if any) of the
+ speculative tokens are accepted by the larger LLM.
+
+ See https://github.com/vllm-project/vllm/pull/2188 and
+ https://github.com/vllm-project/vllm/pull/3103 for more info.
+
+ The current implementation has the following limitations:
+ * Only draft-model proposal is implemented (contributions for more forms are
+ welcome!).
+ * Only top-1 proposal and scoring are implemented. Tree-attention is left as
+ future work.
+ * All sequences in a batch must have the same proposal length, or zero. This
+ can be improved by having per-sequence speculation in the future.
+ * The scoring forward pass is done without an MQA kernel, which is
+ suboptimal especially as the batch size, proposal length, and sequence
+ lengths grow. Contributions to add a MQA scoring are welcome once
+ correctness tests pass.
+ More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
+ """
+
+    @classmethod
+    def create_worker(
+        cls,
+        scorer_worker: WorkerBase,
+        draft_worker_kwargs: Dict[str, Any],
+        disable_mqa_scorer: bool,
+        disable_by_batch_size: Optional[int],
+        draft_token_acceptance_method: str,
+        typical_acceptance_sampler_posterior_threshold: float,
+        typical_acceptance_sampler_posterior_alpha: float,
+        disable_logprobs: bool,
+        disable_log_stats: bool,
+    ) -> "SpecDecodeWorker":
+        """Select and build the proposer worker and acceptance sampler, apply
+        the MQA-scorer fallback heuristics, and assemble a SpecDecodeWorker.
+
+        Args:
+            scorer_worker: worker running the (large) target model.
+            draft_worker_kwargs: constructor kwargs for the draft worker; the
+                ngram_prompt_lookup_{min,max} entries are popped here.
+            disable_mqa_scorer: force batch-expansion scoring; may also be
+                flipped on below when MQA prerequisites are not met.
+            disable_by_batch_size: auto-disable speculation above this
+                running-queue size.
+            draft_token_acceptance_method: "rejection_sampler" or
+                "typical_acceptance_sampler".
+            typical_acceptance_sampler_posterior_threshold: threshold knob
+                for the typical acceptance sampler.
+            typical_acceptance_sampler_posterior_alpha: alpha knob for the
+                typical acceptance sampler.
+            disable_logprobs: see __init__.
+            disable_log_stats: see __init__.
+        """
+
+        allow_zero_draft_token_step = True
+        ngram_prompt_lookup_max = (
+            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
+        ngram_prompt_lookup_min = (
+            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
+        draft_model_config = draft_worker_kwargs["vllm_config"].model_config
+        draft_parallel_config: ParallelConfig = draft_worker_kwargs[
+            'vllm_config'].parallel_config
+        # An ngram window implies the model-free n-gram proposer.
+        if ngram_prompt_lookup_max > 0:
+            draft_worker_kwargs[
+                "device_type"] = scorer_worker.device_config.device.type
+            proposer_worker = NGramWorker(**draft_worker_kwargs)
+            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
+                                                  ngram_prompt_lookup_max)
+        else:
+            draft_tp = draft_parallel_config.tensor_parallel_size
+            target_tp = scorer_worker.parallel_config.tensor_parallel_size
+
+            if draft_model_config.hf_config.model_type == "mlp_speculator":
+                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
+            elif draft_model_config.hf_config.model_type == "medusa":
+                proposer_worker = MedusaWorker(**draft_worker_kwargs)
+            else:
+                if draft_tp == 1:
+                    if current_platform.is_cuda_alike():
+                        draft_worker_kwargs[
+                            "model_runner_cls"] = TP1DraftModelRunner
+                else:
+                    if draft_model_config.hf_config.model_type == "eagle":
+                        raise NotImplementedError(
+                            "EAGLE does not support TP > 1 yet")
+
+                    # TP>1 draft workers cannot handle a step that produces
+                    # zero draft tokens (see #5814).
+                    allow_zero_draft_token_step = False
+                proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+
+            # If draft TP is smaller than target TP, run the draft model on
+            # a subset of ranks only.
+            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
+                proposer_worker, draft_tp, target_tp)
+
+        logger.info("Configuring SpecDecodeWorker with proposer=%s",
+                    type(proposer_worker))
+
+        spec_decode_sampler: Optional[SpecDecodeBaseSampler] = None
+        if draft_token_acceptance_method == "rejection_sampler":
+            spec_decode_sampler = RejectionSampler()
+        elif draft_token_acceptance_method == "typical_acceptance_sampler":
+            spec_decode_sampler = TypicalAcceptanceSampler(
+                posterior_threshold=\
+                    typical_acceptance_sampler_posterior_threshold,
+                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
+            )
+        logger.info(
+            "[Speculative Decoding] Configuring"
+            " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))
+
+        # MQA scoring has extra prerequisites; fall back to batch expansion
+        # when any of them is not satisfied.
+        if not disable_mqa_scorer:
+            if scorer_worker.model_runner.attn_backend.get_name(
+            ) != "FLASH_ATTN":
+                disable_mqa_scorer = True
+                logger.info(
+                    "[Speculative Decoding] Disabling MQA scorer as the "
+                    "MQA is only available with flash attn backend.")
+
+            if draft_model_config and \
+                draft_model_config.max_model_len < \
+                    scorer_worker.model_config.max_model_len:
+                disable_mqa_scorer = True
+                logger.info(
+                    "[Speculative Decoding] Disabling MQA scorer as the "
+                    "draft model max_model_len is smaller than the target "
+                    "model max_model_len.")
+
+            if not scorer_worker.model_runner.model_config.enforce_eager:
+                disable_mqa_scorer = True
+                logger.info(
+                    "[Speculative Decoding] Disabling MQA scorer as the "
+                    "target model is not running in eager mode.")
+
+        return SpecDecodeWorker(
+            proposer_worker,
+            scorer_worker,
+            disable_mqa_scorer=disable_mqa_scorer,
+            disable_logprobs=disable_logprobs,
+            disable_log_stats=disable_log_stats,
+            disable_by_batch_size=disable_by_batch_size,
+            spec_decode_sampler=spec_decode_sampler,
+            allow_zero_draft_token_step=allow_zero_draft_token_step)
+
+    def __init__(
+        self,
+        proposer_worker: ProposerWorkerBase,
+        scorer_worker: WorkerBase,
+        spec_decode_sampler: SpecDecodeBaseSampler,
+        disable_mqa_scorer: bool = False,
+        disable_logprobs: bool = False,
+        disable_log_stats: bool = False,
+        metrics_collector: Optional[AsyncMetricsCollector] = None,
+        disable_by_batch_size: Optional[int] = None,
+        allow_zero_draft_token_step: Optional[bool] = True,
+    ):
+        """
+        Create a SpecDecodeWorker.
+
+        Args:
+            proposer_worker: A worker that can produce speculative tokens for
+                sequences.
+            scorer_worker: A worker that produces probabilities of speculative
+                tokens according to some base model. Typically a vanilla vLLM
+                Worker.
+            spec_decode_sampler: A Torch module used to perform acceptance
+                sampling of the draft tokens in the verification step of
+                speculative decoding. Currently we support two different
+                types of sampler namely RejectionSampler and
+                TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
+                instance of RejectionSampler or TypicalAcceptanceSampler.
+            disable_mqa_scorer: If set to True, disable the MQA scorer and use
+                the BatchExpansionTop1Scorer instead.
+            disable_logprobs: If set to True, token log probabilities will
+                not be output in both the draft worker and the target worker.
+                If set to False, log probabilities will be output by both.
+            disable_log_stats: If set to True, disable periodic printing of
+                speculative stage times.
+            disable_by_batch_size: If the batch size is larger than this,
+                disable speculative decoding for new incoming requests.
+            metrics_collector: Helper class for collecting metrics; can be set
+                for testing purposes.
+            allow_zero_draft_token_step: whether to allow a step where the
+                draft model generates no draft token; should disallow when
+                the tp of draft model is larger than 1 (TODO: #5814)
+        """
+        self.proposer_worker = proposer_worker
+        self.scorer_worker = scorer_worker
+        # Seeded-sampling generators live on the scorer's model runner; not
+        # every WorkerBase exposes one, hence the getattr guard.
+        scorer_runner = getattr(self.scorer_worker, "model_runner", None)
+        self.generators = scorer_runner.get_generators(
+        ) if scorer_runner else None
+        # No threshold given -> float("inf"), i.e. never auto-disable.
+        self.disable_by_batch_size = disable_by_batch_size or float("inf")
+        self.spec_decode_sampler = spec_decode_sampler
+        self._allow_zero_draft_token_step = allow_zero_draft_token_step
+        self._metrics = AsyncMetricsCollector(
+            self.spec_decode_sampler
+        ) if metrics_collector is None else metrics_collector
+        # Tracks the sequence IDs that received a bonus token ID in
+        # their last forward pass. Needed only if KV cache is being
+        # used for token generation such as in the case of MultiStepWorker.
+        self._seq_with_bonus_token_in_last_step: Set[int] = set()
+        # Tracks the currently active request ids and the sequence IDs
+        # corresponding to them
+        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
+        # Tracks if the proposer worker uses the KV cache or not.
+
+        self.probs_dtype = self.spec_decode_sampler.probs_dtype
+        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
+        # Lazy initialization: self.scorer is assigned in init_device().
+        self.scorer: SpeculativeScorer
+        self.disable_mqa_scorer = disable_mqa_scorer
+
+        # Hidden states from target model to pass to proposer
+        # in the subsequent step.
+        self.previous_hidden_states: Optional[HiddenStates] = None
+        self._disable_logprobs = disable_logprobs
+        self._disable_log_stats = disable_log_stats
+
+    def init_device(self) -> None:
+        """Initialize both scorer and proposer models.
+
+        Also initializes metric/sampler tensors on this rank's device and
+        selects the scoring implementation (batch expansion vs MQA)
+        according to ``disable_mqa_scorer``.
+        """
+        # The scorer worker model is initialized first in case the proposer
+        # model has a smaller TP degree than the target worker.
+        self.scorer_worker.init_device()
+        self.proposer_worker.init_device()
+
+        # NOTE(cade): load_model is not part of the WorkerBase interface.
+        self.scorer_worker.load_model()
+        self.proposer_worker.load_model()
+
+        self._metrics.init_tensors(self.rank, device_type=self.device)
+        self.spec_decode_sampler.init_tensors(self.rank,
+                                              device_type=self.device)
+
+        scorer_cls: Type[SpeculativeScorer]
+        if self.disable_mqa_scorer:
+            scorer_cls = BatchExpansionTop1Scorer
+            logger.info("[Speculative Decoding] Use batch "
+                        "expansion for scoring proposals.")
+        else:
+            scorer_cls = MQAScorer
+            logger.info(
+                "[Speculative Decoding] Use MQA scorer for scoring proposals.")
+
+        self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
+                                 device=self.device,
+                                 vocab_size=self._vocab_size)
+
+        self._configure_model_sampler_for_spec_decode()
+
+ def load_model(self, *args, **kwargs):
+ pass
+
+ def _configure_model_sampler_for_spec_decode(self):
+ """Configure model sampler to emit GPU tensors. This allows spec decode
+ to keep data on device without transferring to CPU and serializing,
+ which significantly reduces overhead of sampling during verification.
+
+ NOTE(cade): This breaks abstraction boundaries pretty badly. The better
+ design is to have the "move to CPU and serialize" sampling decision be
+ done outside of the model/sampler; this way the "last-mile" worker
+ object which interfaces with the scheduler can serialize and incur the
+ performance hit as necessary. This allows us to run the worker several
+ iterations in a row without incurring the "move to CPU and serialize"
+ performance penalty.
+
+ Since this requires a large change to vLLM, we defer it to later and
+ temporarily accept this broken abstraction boundary.
+
+ NOTE(cade): This will require a special check if the proposer worker
+ does not have a sampler (e.g. ngram speculation).
+ """
+ (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
+ ) = True
+ (self.scorer_worker.model_runner.model.sampler.
+ should_modify_greedy_probs_inplace) = True
+ self.proposer_worker.set_include_gpu_probs_tensor()
+ self.proposer_worker.set_should_modify_greedy_probs_inplace()
+
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
+ """Determine the number of cache blocks to use.
+
+ This is done by profiling the scorer model (which is typically the
+ larger of the two). Then the total memory which would be used by the
+ scorer cache is divided evenly between the proposer and scorer model KV,
+ such that the number of blocks is equal in both KV caches.
+ """
+ num_gpu_blocks, num_cpu_blocks = (
+ self.scorer_worker.determine_num_available_blocks())
+
+ scorer_cache_block_size_bytes = (
+ self.scorer_worker.get_cache_block_size_bytes())
+ proposer_cache_block_size_bytes = (
+ self.proposer_worker.get_cache_block_size_bytes())
+
+ new_num_gpu_blocks = split_num_cache_blocks_evenly(
+ scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
+ num_gpu_blocks)
+ return new_num_gpu_blocks, num_cpu_blocks
+
+ def initialize_cache(self, num_gpu_blocks: int,
+ num_cpu_blocks: int) -> None:
+ """Initialize the cache engine of the scorer and proposer workers.
+ """
+ self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
+ num_cpu_blocks=num_cpu_blocks)
+ self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
+ num_cpu_blocks=num_cpu_blocks)
+
+    def get_model(self) -> nn.Module:
+        """Return the target (scorer) model."""
+        return self.scorer_worker.get_model()
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Perform speculative decoding on the input batch.
+
+        Driver-rank entrypoint for one scheduler step. Classifies the batch,
+        decides whether this step speculates at all, broadcasts that
+        decision so every rank runs the same sequence of proposer/scorer
+        forward passes, then dispatches to either _run_no_spec or
+        _run_speculative_decoding_step.
+
+        Args:
+            execute_model_req: the batch to run, or None to signal there is
+                no more work (which also stops non-driver ranks' loops).
+
+        Returns:
+            A list of SamplerOutput; empty on non-driver ranks and when
+            terminating.
+        """
+        if self.rank != self._driver_rank:
+            self._run_non_driver_rank()
+            return []
+
+        if execute_model_req is None:
+            # This signals that there's no more requests to process for now.
+            # All workers are running infinite loop with broadcast_tensor_dict,
+            # and it stops the loop when the driver broadcasts an empty input.
+            # Send an empty input to notify all other workers to stop their
+            # execution loop.
+            broadcast_tensor_dict({}, src=0)
+            return []
+
+        self._track_finished_requests(execute_model_req)
+        disable_all_speculation = self._should_disable_all_speculation(
+            execute_model_req)
+        num_lookahead_slots = execute_model_req.num_lookahead_slots
+        # One pass over the batch computes three classifications at once.
+        all_prompt = True
+        atleast_one_prompt = False
+        all_zero_spec_tokens = True
+        for sgm in execute_model_req.seq_group_metadata_list:
+            all_prompt = all_prompt and sgm.is_prompt
+            atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
+            all_zero_spec_tokens = all_zero_spec_tokens and (
+                sgm.num_speculative_tokens == 0)
+
+        if all_prompt and execute_model_req.seq_group_metadata_list:
+            assert num_lookahead_slots == 0, (
+                "Prompt only runs should have num_lookahead_slots equal to 0. "
+                "This should never happen, please file a bug at "
+                "https://github.com/vllm-project/vllm/issues")
+        # Speculative decoding is disabled in the following cases:
+        # 1. Prefill phase: Speculative decoding is not
+        #    used during the prefill phase.
+        # 2. Auto-disable enabled: The running queue size exceeds
+        #    the specified threshold.
+        # 3. No request: There are no requests in the batch, or
+        #    none of the requests in the batch have spec decoding enabled.
+        # In any of these cases, the proposer and scorer workers
+        # are called normally.
+        # We expect `num_speculative_tokens` to be None for prefills.
+        no_spec = (num_lookahead_slots == 0 or disable_all_speculation
+                   or all_zero_spec_tokens)
+
+        # Broadcast how many lookahead slots are scheduled for this step, and
+        # whether all speculation is disabled, to all non-driver workers.
+
+        # This is required as if the number of draft model runs changes
+        # dynamically, the non-driver workers won't know unless we perform a
+        # communication to inform them.
+
+        # no_spec is used to signal non-driver worker about prefill vs decode
+        # stage. This is needed to ensure that order of execution of proposer
+        # and scorer is same in both driver and non-driver workers (i.e.,
+        # scorer -> proposer for prefill and proposer -> scorer in decode). This
+        # order is needed to support models like EAGLE that take scorer states
+        # as inputs.
+        broadcast_dict = dict(
+            num_lookahead_slots=num_lookahead_slots,
+            no_spec=no_spec,
+            disable_all_speculation=disable_all_speculation,
+            # When both chunked prefill and speculative decoding are enabled
+            # it is possible that the same batch contains both prefill
+            # and decodes. If that happens in the scorer we run the batch
+            # as one single forward pass. However, in the proposer we
+            # run them as 2 different batches - one for prefill and
+            # the other for decodes. The variable indicates to the non-driver
+            # worker that there are prefills as part of the speculative batch
+            # and hence it needs to run an extra prefill forward pass.
+            run_spec_proposer_for_prefill=atleast_one_prompt,
+        )
+        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
+
+        assert execute_model_req.seq_group_metadata_list is not None, (
+            "speculative decoding requires non-None seq_group_metadata_list")
+
+        self._maybe_disable_speculative_tokens(
+            disable_all_speculation, execute_model_req.seq_group_metadata_list)
+
+        if no_spec:
+            return self._run_no_spec(execute_model_req,
+                                     skip_proposer=disable_all_speculation)
+        return self._run_speculative_decoding_step(execute_model_req,
+                                                   num_lookahead_slots)
+
+ @torch.inference_mode()
+ def start_worker_execution_loop(self) -> None:
+ """Execute model loop to perform speculative decoding
+ in parallel worker."""
+ while self._run_non_driver_rank():
+ pass
+
+ def _should_disable_all_speculation(
+ self, execute_model_req: ExecuteModelRequest) -> bool:
+ # When the batch size is too large, disable speculative decoding
+ # to stop trading off throughput for latency.
+ return (execute_model_req.running_queue_size
+ >= self.disable_by_batch_size)
+
+ def _maybe_disable_speculative_tokens(
+ self, disable_all_speculation: bool,
+ seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
+ if not disable_all_speculation:
+ return
+
+ for seq_group_metadata in seq_group_metadata_list:
+ # Once num_speculative_tokens is set to 0, the spec decode
+ # of this request will be disabled forever.
+ # TODO(comaniac): We currently store spec decoding specific
+ # state in the global data structure, but we should maintain
+ # this state within spec decode worker.
+ seq_group_metadata.num_speculative_tokens = 0
+
+    def _serialize_sampler_output_no_logprobs(
+            self, execute_model_req: ExecuteModelRequest,
+            sampler_output: SamplerOutput) -> List[SamplerOutput]:
+        """
+        Creates and returns a `SamplerOutput` with only the token IDs being
+        serialized to CPU and populated in `CompletionSequenceGroupOutput`.
+        All other parameters in `CompletionSequenceGroupOutput` related to log
+        probabilities are skipped.
+
+        Args:
+            execute_model_req (ExecuteModelRequest): The model request that
+            was executed.
+            sampler_output (SamplerOutput): The output from the sampler with
+            only GPU tensors populated.
+
+        Returns:
+            SamplerOutput: A new `SamplerOutput` instance containing a list of
+            `CompletionSequenceGroupOutput` objects with only token IDs
+            populated.
+        """
+        # Per sequence group: does it require (placeholder) prompt logprobs?
+        seq_output_prompt_logprobs = [
+            seq.is_prompt and seq.sampling_params.prompt_logprobs is not None
+            and seq.sampling_params.prompt_logprobs > 0
+            for seq in execute_model_req.seq_group_metadata_list
+        ]
+        # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
+        sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where(
+            # subtracting is faster than testing for equality
+            sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \
+            if any(seq_output_prompt_logprobs) else \
+            sampler_output.sampled_token_ids).tolist()
+
+        # Flatten (seq_id, seq_data) pairs across all groups.
+        seq_data_entries = [
+            (seq_id, seq_data) for sg in \
+            execute_model_req.seq_group_metadata_list \
+            for seq_id, seq_data in sg.seq_data.items()
+        ]
+        completion_seq_group_output_list: List[
+            CompletionSequenceGroupOutput] = []
+        output_index = 0
+        # Make sure the non-terminal prefill chunks are still aligned with
+        # their own empty output.
+        for idx, seq_group_meta in enumerate(
+                execute_model_req.seq_group_metadata_list):
+            needs_prompt_logprobs = seq_output_prompt_logprobs[idx]
+            # NOTE(review): indexing the flattened seq_data_entries by the
+            # group index assumes exactly one sequence per group here —
+            # presumably guaranteed upstream; confirm before relying on it.
+            seq_id, seq_data = seq_data_entries[idx]
+            if needs_prompt_logprobs:
+                prompt_token_ids = seq_data.get_prompt_token_ids()
+
+                # Some of these sequences may belong to non-terminal chunks,
+                # which may still have to report logprobs for prompts.
+                start = 1 if seq_data._num_computed_tokens == 0 \
+                    else seq_data._num_computed_tokens
+                end = (seq_data._num_computed_tokens + \
+                       seq_group_meta.token_chunk_size)
+                prompt_token_ids = prompt_token_ids[start:end]
+                # Placeholder entries (rank -1, logprob 0.0) since logprob
+                # computation is disabled.
+                prompt_logprobs = [
+                    create_logprobs_output(
+                        token_id=p_token_id,
+                        token_id_logprob_rank=-1,
+                        token_id_logprob=0.0,
+                        topk_token_ids=[],
+                        topk_logprobs=[],
+                    ) for p_token_id in prompt_token_ids
+                ]
+            else:
+                prompt_logprobs = None
+
+            # Since we can get chunks here, we dont always have a sampled token
+            # (only on last chunk) but we still have to provide an output.
+            if not seq_group_meta.do_sample:
+                completion_seq_group_output_list.append(
+                    CompletionSequenceGroupOutput(
+                        samples=[], prompt_logprobs=prompt_logprobs))
+                continue
+
+            # Sequence with output.
+            completion_seq_group_output_list.append(
+                create_sequence_group_output(
+                    token_id=sampled_token_ids_list[output_index][0],
+                    token_id_logprob_rank=-1,
+                    token_id_logprob=0.0,
+                    seq_id=seq_id,
+                    topk_token_ids=[],
+                    topk_logprobs=[],
+                    prompt_logprobs=prompt_logprobs))
+            output_index += 1
+
+        return [SamplerOutput(outputs=completion_seq_group_output_list)]
+
+    @nvtx_range("spec_decode_worker._run_no_spec")
+    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
+                     skip_proposer: bool) -> List[SamplerOutput]:
+        """Run a single generation step without any speculation. The input is
+        sent to the proposer and scorer model so that the KV cache is consistent
+        between the two. When skip_proposer is True, the proposer model is
+        not called, meaning that the kv-cache in proposer for requests is not
+        updated, so they cannot enable spec decode in the rest decoding.
+
+        Returns a single-element list with the scorer's SamplerOutput; its
+        device tensors are cleared before returning to cut IPC overhead.
+        """
+
+        sampler_output = self.scorer_worker.execute_model(execute_model_req)
+        assert len(sampler_output) == 1
+        sampler_output = sampler_output[0]
+
+        # Store hidden states from target model execution, BxD.
+        hidden_states = sampler_output.hidden_states
+        if hidden_states is not None:
+            # Only decodes and prefill terminal chunks need a hidden state.
+            seq_group_meta_with_hidden = [
+                sg for sg in execute_model_req.seq_group_metadata_list
+                if sg.do_sample
+            ]
+            if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
+                # Drop hidden_states with no prediction (eg non-terminal chunks)
+                hidden_states = hidden_states[
+                    torch.where(sampler_output.sampled_token_ids -
+                                VLLM_INVALID_TOKEN_ID)[0]]
+            if self.previous_hidden_states is None and len(
+                    seq_group_meta_with_hidden):
+                self.previous_hidden_states = HiddenStates(
+                    hidden_states, seq_group_meta_with_hidden)
+            elif self.previous_hidden_states and len(
+                    seq_group_meta_with_hidden):
+                self.previous_hidden_states.update(hidden_states,
+                                                   seq_group_meta_with_hidden)
+
+        if not skip_proposer:
+            # We prepare the prefill hidden states here so that there no
+            # additional complexity in worker for spec_decode vs non_spec_decode
+            # flow and execute_model doesn't need additional modifications.
+            execute_model_req.previous_hidden_states = \
+                prepare_prefill_hidden_states(
+                    sampler_output.prefill_hidden_states)
+
+            self.proposer_worker.execute_model(execute_model_req)
+
+        sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
+            execute_model_req=execute_model_req, sampler_output=sampler_output)
+                                    if self._disable_logprobs else
+                                    [sampler_output])
+
+        # Clear device tensors from sampler output. This reduces communication
+        # overhead when the engine runs in a different process than the workers.
+        sampler_output.sampled_token_probs = None
+        sampler_output.sampled_token_ids = None
+        sampler_output.logprobs = None
+        return sampler_output_to_return
+
+    def _run_non_driver_rank(self) -> bool:
+        """Run proposer and verifier model in non-driver workers. This is used
+        for both speculation cases (num_lookahead_slots>0) and non-speculation
+        cases (e.g. prefill).
+
+        The broadcast received from the driver dictates the call order, so
+        it mirrors the driver's (scorer -> proposer for prefill,
+        proposer -> scorer for decode).
+
+        Returns True if there are remaining sequences to process.
+        """
+        assert self.rank != self._driver_rank
+
+        data = broadcast_tensor_dict(src=self._driver_rank)
+        if not data:
+            # Empty broadcast is the driver's termination signal.
+            return False
+        num_lookahead_slots = data["num_lookahead_slots"]
+
+        # In case of prefill, scorer_worker has to be run before proposer so
+        # that the hidden states can be propagated to proposer when needed.
+        if data["no_spec"]:
+            self.scorer_worker.execute_model()
+
+        if not data["disable_all_speculation"]:
+            # Even if num_lookahead_slots is zero, we want to run the
+            # proposer model as it may have KV.
+            #
+            # We run the proposer once per lookahead slot. In the future we
+            # should delegate how many times it runs to the proposer.
+            for _ in range(max(num_lookahead_slots, 1)):
+                self.proposer_worker.execute_model()
+
+        if not data["no_spec"]:
+            self.scorer_worker.execute_model()
+            if data["run_spec_proposer_for_prefill"]:
+                self.proposer_worker.execute_model()
+
+        return True
+
+    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
+    def _run_speculative_decoding_step(
+            self, execute_model_req: ExecuteModelRequest,
+            num_lookahead_slots: int) -> List[SamplerOutput]:
+        """Execute a single step of speculative decoding.
+
+        This invokes the proposer worker to get k speculative tokens for each
+        sequence, then scores each speculative token using the scoring worker.
+
+        When `enable_chunked_prefill` is set, scorer will batch decodes and
+        prefills, while proposer will sync its KV-cache by running an extra
+        forward on prefills.
+
+        Returns a list of SamplerOutput, each containing a single token per
+        sequence.
+        """
+        # With prefill chunking, expect requests to have prompts first
+        # so that backend gets prefill|decode.
+        assert num_lookahead_slots == execute_model_req.num_lookahead_slots
+
+        # Pass last hidden states from target model to proposer
+        execute_model_req.previous_hidden_states = self.previous_hidden_states
+        self.previous_hidden_states = None
+
+        with Timer() as proposal_timer:
+            # Generate proposals using draft worker.
+            proposals = self.proposer_worker.get_spec_proposals(
+                execute_model_req, self._seq_with_bonus_token_in_last_step)
+
+        if not self._allow_zero_draft_token_step and proposals.no_proposals:
+            #TODO: Fix it #5814
+            raise RuntimeError("Cannot handle cases where distributed draft "
+                               "workers generate no tokens")
+
+        execute_model_req.previous_hidden_states = None
+
+        with Timer() as scoring_timer:
+            proposal_scores = self.scorer.score_proposals(
+                execute_model_req,
+                proposals,
+            )
+
+        _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len(
+            execute_model_req.seq_group_metadata_list, proposals.proposal_lens)
+        # With prefill chunking enabled, `non_spec_seqs` contains prefills too:
+        # discard decodes that have already been processed by proposer.
+        non_spec_indices = [
+            idx for idx in non_spec_indices
+            if execute_model_req.seq_group_metadata_list[idx].is_prompt
+        ]
+        if len(non_spec_indices):
+            all_hidden_states = proposal_scores.hidden_states
+            if all_hidden_states is not None:
+                prefill_hidden_states = all_hidden_states[non_spec_indices]
+                execute_model_req.previous_hidden_states = \
+                    prepare_prefill_hidden_states(prefill_hidden_states)
+            # Sync proposer KV cache for prefills.
+            prefill_req = execute_model_req.clone(non_spec_seqs)
+            # TODO avoid sampling here?
+            self.proposer_worker.execute_model(prefill_req)
+
+        with Timer() as verification_timer:
+            accepted_token_ids, target_logprobs = self._verify_tokens(
+                execute_model_req.seq_group_metadata_list, proposal_scores,
+                proposals, execute_model_req.num_lookahead_slots)
+
+        # Proposal time is averaged per lookahead slot (one draft run each).
+        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
+                       scoring_timer.elapsed_time_ms,
+                       verification_timer.elapsed_time_ms)
+
+        return self._create_output_sampler_list(
+            execute_model_req.seq_group_metadata_list,
+            accepted_token_ids,
+            target_logprobs=target_logprobs,
+            prompt_logprobs=proposal_scores.prompt_logprobs
+            if not self._disable_logprobs else None,
+            k=execute_model_req.num_lookahead_slots,
+            stage_times=stage_times)
+
    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        Args:
            seq_group_metadata_list: Batch metadata in original scheduler
                order.
            proposal_scores: Target-model scores (probs, token ids, logprobs,
                hidden states) for the proposed tokens.
            proposals: Draft-model proposals (token ids, probs, lens).
            max_proposal_len: Number of lookahead slots (k).

        Returns a tuple of Tensors, one for the accepted token ids and one for
        the logprobs according to the scoring model.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, including bonus tokens.
        proposal_verifier_probs = proposal_scores.probs[spec_indices]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            # Seeded sequences carry their own torch generator so that
            # stochastic acceptance sampling stays reproducible per request.
            sampler_extra_kwargs["seeded_seqs"] = {
                idx: self.generators[sgm.request_id]
                for idx, sgm in enumerate(seq_group_metadata_list)
                if sgm.sampling_params.seed is not None
            }

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        # Non-spec rows contribute exactly one real token; pad the rest of the
        # row with -1 so all rows have k+1 entries.
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        # B x K+1 x D
        hidden_states = proposal_scores.hidden_states
        if hidden_states is not None:
            # Only get terminal hidden states for next step
            terminal_metadata = [
                sg for sg in seq_group_metadata_list if sg.do_sample
            ]

            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[-1]
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            # Index of the last accepted token per row (count of nonzero - 1).
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
            # Drop non-terminal prefill chunks hidden states.
            hidden_states = hidden_states[accepted_index !=
                                          VLLM_INVALID_TOKEN_ID]
            accepted_index = accepted_index[accepted_index !=
                                            VLLM_INVALID_TOKEN_ID]
            assert len(accepted_index) == hidden_states.shape[0] == len(
                terminal_metadata)
            index = accepted_index[:, None, None].expand(-1, 1,
                                                         hs_size)  # b x 1 x d
            second_last_token_hidden_states = hidden_states[:, -2]  # b x d
            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
            self.previous_hidden_states = HiddenStates(
                hidden_states, terminal_metadata,
                second_last_token_hidden_states)
        return accepted_token_ids, logprobs
+
    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        prompt_logprobs: Optional[
            torch.Tensor],  # shape: [nprompt_tokens, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
        batch_size, num_steps = accepted_token_ids.shape
        # Reorganize by step (rows) instead of by sequence (columns).
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
        if self._disable_logprobs:
            # We are skipping the logprobs. Hence don't serialize the
            # logprobs related tensors from the GPU. Instead create
            # empty/dummy lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step, topk_indices_by_step) =\
            self._create_dummy_logprob_lists(
                batch_size, num_steps,
                self.scorer_worker.model_config.max_logprobs)
        else:
            # Organize input tensors by step instead of by sequence.
            target_logprobs_by_step = target_logprobs.transpose(0, 1)
            # Serialize all tensors into Python lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step, topk_indices_by_step) =\
            self._create_logprob_lists_from_tensors(
                target_logprobs_by_step, accepted_token_ids_by_step,
                self.scorer_worker.model_config.max_logprobs)

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)

        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize tensor to CPU Python list.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        # Non-terminal prefill chunks will end up here as rows with just -1s
        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
        # terminal chunks will only have one generated token at time 0.
        sampler_output_list: List[SamplerOutput] = []

        # Prefills are not multi-step (return at most 1 token), in order to
        # avoid padding or repetition to fit decodes, we separate them.
        for i, sg in enumerate(seq_group_metadata_list):
            if not sg.is_prompt:
                # Requests are ordered as prefills|decodes=>no more prefills.
                break
            num_logprobs = num_logprobs_per_seq[i]
            # Default entry: padding token with sentinel logprob values.
            seq_kwargs = dict(token_id=-1,
                              token_id_logprob_rank=0,
                              token_id_logprob=-float('inf'),
                              topk_token_ids=[-1] * num_logprobs,
                              topk_logprobs=[-float('inf')] * num_logprobs,
                              seq_id=seq_ids[i])
            # Terminal chunk, has token.
            if sg.do_sample:
                seq_kwargs.update(
                    dict(
                        token_id=accepted_token_ids[i][0].item(),
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            0][i],
                        token_id_logprob=accepted_token_id_logprobs_by_step[0]
                        [i],
                        topk_token_ids=topk_indices_by_step[0][i]
                        [:num_logprobs],
                        # output only so step is 0
                        topk_logprobs=topk_logprobs_by_step[0][i]
                        [:num_logprobs],
                    ))
            needs_plogs = (sg.sampling_params.prompt_logprobs
                           and sg.sampling_params.prompt_logprobs > 0)
            plogs = None
            if prompt_logprobs is not None:
                # Even non-terminal prompt chunks can have logprobs here.
                plogs = prompt_logprobs[i]
            elif needs_plogs:
                # Prompt logprobs are requested but `_disable_logprobs` is set.
                seq_data = next(iter(sg.seq_data.values()))
                # Get only the tokens in this chunk!
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_token_ids = prompt_token_ids[
                    seq_data.
                    _num_computed_tokens:seq_data._num_computed_tokens +
                    sg.token_chunk_size]

                is_first_chunk = seq_data._num_computed_tokens == 0
                # There's no prob generated for the first token in a sequence.
                if is_first_chunk:
                    prompt_token_ids = prompt_token_ids[1:]
                # Emit placeholder logprob entries (rank -1, logprob 0.0) for
                # every prompt token in this chunk.
                plogs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in prompt_token_ids
                ]
            seq_kwargs.update(dict(prompt_logprobs=plogs))

            sampler_output_list.append(
                SamplerOutput(
                    outputs=[create_sequence_group_output(
                        **seq_kwargs)]))  # type: ignore

        # Decodes, create one SamplerOutput per-step (at most K+1).
        for step_index in range(num_steps):
            # Stop as soon as every decode row holds only padding (-1) tokens
            # at this step.
            if all(token_id == -1 for sg, token_id in zip(
                    seq_group_metadata_list,
                    accepted_token_ids_by_step[step_index])
                   if not sg.is_prompt):
                break

            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                seq_meta = seq_group_metadata_list[sequence_index]
                # Prompts already processed above.
                if seq_meta.is_prompt:
                    continue

                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)
        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        # Log time spent in each stage periodically.
        # This is periodic because the rejection sampler emits metrics
        # periodically.
        self._maybe_log_stage_times(*stage_times)
        # First `n_prefills` entries will contain prefills SamplerOutput when
        # chunked prefill is enabled, the rest is decodes in multi-step format.
        return sampler_output_list
+
+ def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
+ scoring_time_ms: float,
+ verification_time_ms: float) -> None:
+ """Log the speculative stage times. If stat logging is disabled, do
+ nothing.
+ """
+ if self._disable_log_stats:
+ return
+
+ logger.info(
+ "SpecDecodeWorker stage times: "
+ "average_time_per_proposal_tok_ms=%.02f "
+ "scoring_time_ms=%.02f verification_time_ms=%.02f",
+ average_time_per_proposal_tok_ms, scoring_time_ms,
+ verification_time_ms)
+
+ def _create_dummy_logprob_lists(
+ self,
+ batch_size: int,
+ num_steps: int,
+ num_top_k: int,
+ ) -> Tuple[List[List[int]], List[List[float]],
+ List[List[List[Optional[float]]]],
+ List[List[List[Optional[int]]]]]:
+ """
+ Creates and returns four dummy lists representing token probabilities
+ and their ranks.
+
+ This method initializes and returns:
+ - The ranks of the accepted tokens, shaped (num_steps, batch_size)
+ - The log probabilities of the accepted tokens,
+ shaped (num_steps, batch_size)
+ - The log probabilities of the top k tokens,
+ shaped (num_steps, batch_size, num_top_k)
+ - The token IDs of the top k tokens,
+ shaped (num_steps, batch_size, num_top_k)
+
+ Args:
+ batch_size (int): The size of the batch.
+ num_steps (int): The number of steps in the sequence.
+ num_top_k (int): The number of top-k token log probabilities to
+ return.
+
+ Returns:
+ A tuple containing four dummy lists as described above.
+ """
+ accepted_token_id_ranks_by_step = [[-1] * batch_size
+ for _ in range(num_steps)]
+ accepted_token_id_logprobs_by_step = [[0.0] * batch_size
+ for _ in range(num_steps)]
+ topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
+ [None] * num_top_k for _ in range(batch_size)
+ ] for _ in range(num_steps)]
+ topk_indices_by_step: List[List[List[Optional[int]]]] = [[
+ [None] * num_top_k for _ in range(batch_size)
+ ] for _ in range(num_steps)]
+ return (accepted_token_id_ranks_by_step,
+ accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
+ topk_indices_by_step)
+
+ def _create_logprob_lists_from_tensors(
+ self,
+ target_logprobs_by_step: torch.Tensor,
+ accepted_token_ids_by_step: torch.Tensor,
+ num_top_k: int,
+ ) -> Tuple[List[List[int]], List[List[float]],
+ List[List[List[Optional[float]]]],
+ List[List[List[Optional[int]]]]]:
+ """
+ Creates and returns four lists representing token probabilities and
+ their ranks.
+
+ This method initializes and returns four lists containing:
+ - The ranks of the accepted tokens, shaped (num_steps, batch_size)
+ - The log probabilities of the accepted tokens,
+ shaped (num_steps, batch_size)
+ - The log probabilities of the top k tokens,
+ shaped (num_steps, batch_size, num_top_k)
+ - The token IDs of the top k tokens,
+ shaped (num_steps, batch_size, num_top_k)
+
+ Args:
+ target_logprobs_by_step (torch.Tensor): Tensor representing the
+ log probabilities of the target model,
+ shaped (num_steps, batch_size, vocab_size)
+ accepted_token_ids_by_step (torch.Tensor): Tensor representing
+ the accepted token_ids, shaped (num_steps, batch_size)
+ num_top_k (int): The number of top-k token log probabilities to
+ return.
+
+ Returns:
+ A tuple containing the lists as described above.
+ """
+ # Serialize all tensors to CPU Python lists.
+ # Get the logprobs/rank of the accepted tokens.
+ (accepted_token_id_ranks_by_step_tensor,
+ accepted_token_id_logprobs_by_step_tensor
+ ) = get_sampled_token_logprobs(
+ logprob_tensor=target_logprobs_by_step,
+ sampled_token_ids=accepted_token_ids_by_step,
+ )
+ # Get the top-k logprobs (which may or may not include the
+ # logprob of the accepted token).
+ (topk_logprobs_by_step_tensor,
+ topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
+ k=num_top_k,
+ dim=-1,
+ )
+ accepted_token_id_ranks_by_step = (
+ accepted_token_id_ranks_by_step_tensor.tolist())
+ accepted_token_id_logprobs_by_step = (
+ accepted_token_id_logprobs_by_step_tensor.tolist())
+ topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
+ topk_indices_by_step = topk_indices_by_step_tensor.tolist()
+ return (accepted_token_id_ranks_by_step,
+ accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
+ topk_indices_by_step)
+
+ def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
+ """
+ Removes the finished requests and their associated sequence ids from
+ internal book keeping data structures.
+ """
+ for finished_request in execute_model_req.finished_requests_ids:
+ for seq_id in self._request_id_seq_id_mapping[finished_request]:
+ self._seq_with_bonus_token_in_last_step.discard(seq_id)
+ del self._request_id_seq_id_mapping[finished_request]
+
+ def _track_sequences_with_bonus_tokens(
+ self, seq_ids: List[int],
+ request_ids_seq_ids_mapping: Dict[str, Set[int]],
+ accepted_token_ids_by_step: List[List[int]]):
+ """
+ Updates the internal data structures which keep track of sequences
+ which have been assigned bonus tokens in their last forward pass.
+ """
+ for seq_index, seq_id in enumerate(seq_ids):
+ last_token_id = accepted_token_ids_by_step[-1][seq_index]
+ if last_token_id == -1:
+ self._seq_with_bonus_token_in_last_step.discard(seq_id)
+ else:
+ self._seq_with_bonus_token_in_last_step.add(seq_id)
+ for request_id, sequences in request_ids_seq_ids_mapping.items():
+ self._request_id_seq_id_mapping[request_id].update(sequences)
+
+ @cached_property
+ def _vocab_size(self) -> int:
+ """Get the vocab size of the model and make sure it's consistent between
+ draft and target workers.
+ """
+ vocab_sizes = [
+ worker.vocab_size
+ for worker in [self.proposer_worker, self.scorer_worker]
+ ]
+ assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
+ return vocab_sizes[0]
+
    @property
    def rank(self):
        """Distributed rank of this worker; delegates to the scorer worker."""
        return self.scorer_worker.rank
+
    @property
    def device(self):
        """Device of the scorer (target) worker."""
        return self.scorer_worker.device
+
    @property
    def _driver_rank(self) -> int:
        """Rank of the driver worker; always 0."""
        return 0
+
    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

        This function is only used to compose workers within a SpecDecodeWorker.
        We leave composing a SpecDecodeWorker within a SpecDecodeWorker
        undefined for now, although it could be implemented in the future.
        See https://arxiv.org/abs/2308.04623.

        Raises:
            NotImplementedError: always; query the composed workers instead.
        """
        raise NotImplementedError
+
+ def start_profile(self):
+ if isinstance(self.scorer_worker, WorkerBase):
+ self.scorer_worker.start_profile()
+
+ def stop_profile(self):
+ if isinstance(self.scorer_worker, WorkerBase):
+ self.scorer_worker.stop_profile()
+
+
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Split a GPU KV-cache block budget between draft and target models.

    ``total_num_gpu_blocks`` is the number of blocks the target model could
    allocate on its own. Since draft and target allocate the same *count* of
    blocks but their per-block sizes differ (block size depends on KV per
    layer, head count and hidden dimension), we weight the budget by the
    scorer's share of the combined block size so that both models together
    use no more memory than the target alone would have.

    Returns:
        The number of blocks each model should allocate.
    """
    combined_block_size_bytes = (proposer_cache_block_size_bytes +
                                 scorer_cache_block_size_bytes)
    # Keep the exact arithmetic order of the original formula.
    return int(total_num_gpu_blocks * scorer_cache_block_size_bytes /
               combined_block_size_bytes)
+
+
def prepare_prefill_hidden_states(
        prefill_hidden_states: torch.Tensor) -> HiddenStates:
    """Wrap prefill hidden states for the proposer's first decode step.

    The proposer runs prefill over N-1 tokens because the Nth token is
    handled by the first decode step; its input must pair hidden states
    0:N-1 with tokens 1:N (the scorer's output feeds the proposer). Rolling
    by one along dim 0 aligns the (n-1)th hidden state with the nth token.

    Returns None when no hidden states were given.
    """
    if prefill_hidden_states is None:
        return None
    shifted = prefill_hidden_states.roll(shifts=1, dims=0)
    return HiddenStates(shifted)
diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/util.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c04680a6a7ab37196633eddc1a218876164a18b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/util.py
@@ -0,0 +1,276 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import time
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Sequence, Tuple
+
+import torch
+
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+ PromptLogprobs, SequenceGroupMetadata,
+ SequenceOutput)
+
# Type alias: sequence ids are plain integers throughout spec decode.
SeqId = int
+
+
def get_all_num_logprobs(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Return each sequence group's requested logprob count.

    Groups whose sampling params do not request logprobs (logprobs is None)
    contribute 0.
    """
    return [
        seq_group_metadata.sampling_params.logprobs or 0
        for seq_group_metadata in seq_group_metadata_list
    ]
+
+
def get_sampled_token_logprobs(
    # shape [num_steps, batch_size, vocab_size]
    logprob_tensor: torch.Tensor,
    sampled_token_ids: torch.Tensor,  # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
    """
    num_steps, batch_size, vocab_size = logprob_tensor.shape

    step_index = torch.arange(num_steps).unsqueeze(1)
    batch_index = torch.arange(batch_size)
    # Advanced indexing broadcasts to shape [num_steps, batch_size].
    selected_logprobs = logprob_tensor[step_index, batch_index,
                                       sampled_token_ids]
    # Rank = 1 + number of vocab entries with strictly greater logprob.
    sampled_token_ids_ranks = (
        logprob_tensor > selected_logprobs.unsqueeze(-1)).sum(-1) + 1

    return sampled_token_ids_ranks, selected_logprobs
+
+
def create_logprobs_output(
    token_id: int,
    token_id_logprob_rank: int,
    token_id_logprob: float,
    topk_token_ids: List[Optional[int]],
    topk_logprobs: List[Optional[float]],
) -> Dict[int, Logprob]:
    """Create a Logprob Dict for a token given the sampling results.

    vLLM logprobs always include the sampled token; the user may additionally
    request top-k logprobs (k varies per user up to max_logprobs).

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
    """
    logprobs: Dict[int, Logprob] = {
        token_id: Logprob(
            logprob=token_id_logprob,
            rank=token_id_logprob_rank,
        ),
    }
    for topk_index, (topk_token_id, topk_logprob) in enumerate(
            zip(topk_token_ids, topk_logprobs)):
        # None entries are padding from dummy logprob lists; skip them.
        if topk_token_id is None:
            continue
        logprobs[topk_token_id] = Logprob(
            logprob=topk_logprob if topk_logprob is not None else 0.0,
            rank=topk_index + 1,
        )

    return logprobs
+
+
def create_sequence_group_output(
    token_id: int,
    token_id_logprob_rank: int,
    token_id_logprob: float,
    seq_id: SeqId,
    topk_token_ids: List[Optional[int]],
    topk_logprobs: List[Optional[float]],
    prompt_logprobs: Optional[PromptLogprobs] = None,
) -> CompletionSequenceGroupOutput:
    """Create a SequenceGroupOutput given the sampling results.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        seq_id (int): The sequence id.
        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
        prompt_logprobs: Optional per-prompt-token logprobs.
    """
    sample = SequenceOutput(
        parent_seq_id=seq_id,
        output_token=token_id,
        logprobs=create_logprobs_output(
            token_id,
            token_id_logprob_rank,
            token_id_logprob,
            topk_token_ids,
            topk_logprobs,
        ),
    )
    return CompletionSequenceGroupOutput(samples=[sample],
                                         prompt_logprobs=prompt_logprobs)
+
+
def split_batch_by_proposal_len(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_lens: List[int],
) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[
        List[SequenceGroupMetadata], List[int]]]:
    """Utility function that splits a batch based on whether the proposal len is
    zero or not. We should remove this once vLLM supports per-sequence proposal
    lens in a batch.

    Returns ((nonzero_groups, nonzero_indices), (zero_groups, zero_indices)),
    with indices referring to positions in the original batch order.
    """
    nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
    zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
    for i, (seq_group, proposal_len) in enumerate(
            zip(seq_group_metadata_list, proposal_lens)):
        # Route each sequence group into the bucket matching its proposal len.
        bucket = nonzero_lists if proposal_len else zero_lists
        bucket[0].append(seq_group)
        bucket[1].append(i)
    return nonzero_lists, zero_lists
+
+
def sampler_output_to_torch(
    sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Utility function which converts a list of SamplerOutput to tensors.

    Each per-output tensor is stacked along a new leading dimension of size
    len(sampler_output_list). When ``sampler_transposed`` is True, the first
    two dimensions of every result are swapped.

    Returns:
        sampled_token_ids: torch.Tensor
            shape: [batch_size, len(sampler_output_list)]

        sampled_token_probs: torch.Tensor
            shape: [batch_size, len(sampler_output_list), vocab_size]

        plus the stacked logprobs tensor and (optionally) hidden states.
    """

    def _maybe_transpose(tensor: torch.Tensor) -> torch.Tensor:
        # Swap the step and batch dimensions when requested.
        return tensor.transpose(0, 1) if sampler_transposed else tensor

    sampled_token_probs = _maybe_transpose(
        torch.stack(
            [out.sampled_token_probs for out in sampler_output_list],
            dim=0,
        ))

    sampled_token_logprobs = _maybe_transpose(
        torch.stack(
            [out.logprobs for out in sampler_output_list],
            dim=0,
        ))

    sampled_token_ids = _maybe_transpose(
        torch.stack(
            [out.sampled_token_ids.flatten() for out in sampler_output_list],
            dim=0,
        ))

    if sampler_output_list[0].hidden_states is not None:
        sampled_hidden_states = _maybe_transpose(
            torch.stack(
                [out.hidden_states for out in sampler_output_list],
                dim=0,
            ))
    else:
        sampled_hidden_states = None

    return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
            sampled_hidden_states)
+
+
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
                              vocab_size: int, device: str) -> None:
    """Helper method which mocks out the GPU tensors in SamplerOutput with dummy
    values. This will be removed in PR 7/9.
    https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
    """
    probs = sampler_output.sampled_token_probs
    token_ids = sampler_output.sampled_token_ids
    # Either both tensors are already present or both are missing.
    assert (probs is None) == (token_ids is None)
    if probs is not None:
        # Tensors already created (usually in unit tests): nothing to do.
        return

    # Softmax over random logits yields a valid probability distribution.
    sampler_output.sampled_token_probs = torch.nn.functional.softmax(
        torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device),
        dim=-1)

    sampler_output.sampled_token_ids = torch.randint(low=10,
                                                     high=100,
                                                     size=(batch_size, ),
                                                     dtype=torch.long,
                                                     device=device)
+
+
@contextmanager
def nvtx_range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes an NVTX range at the beginning
    of its scope, and pops it at the end. If extra arguments are given,
    they are passed as arguments to msg.format().

    If running with cuda graphs, you must enable nsys cuda graph profiling.

    Arguments:
        msg (string): message to associate with the range
    """
    if not current_platform.is_cuda_alike():
        # No CUDA-like device: nothing to annotate, just run the scope.
        yield
        return
    torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
    try:
        yield
    finally:
        # Always pop the range, even if the scope raised.
        torch.cuda.nvtx.range_pop()
+
+
class Timer:
    """Basic timer context manager for measuring CPU time.

    After the with-block exits, exposes `start_time`, `end_time`,
    `elapsed_time_s` and `elapsed_time_ms`.
    """

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        elapsed = self.end_time - self.start_time
        self.elapsed_time_s = elapsed
        self.elapsed_time_ms = elapsed * 1000