koichi12 commited on
Commit
7bc7e14
·
verified ·
1 Parent(s): a1956b1

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/vllm/entrypoints/__init__.py +0 -0
  2. .venv/lib/python3.11/site-packages/vllm/entrypoints/api_server.py +169 -0
  3. .venv/lib/python3.11/site-packages/vllm/entrypoints/chat_utils.py +1007 -0
  4. .venv/lib/python3.11/site-packages/vllm/entrypoints/launcher.py +105 -0
  5. .venv/lib/python3.11/site-packages/vllm/entrypoints/llm.py +1414 -0
  6. .venv/lib/python3.11/site-packages/vllm/entrypoints/logger.py +44 -0
  7. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__init__.py +0 -0
  8. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/__init__.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/api_server.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/cli_args.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/logits_processors.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/protocol.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/run_batch.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_chat.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_completion.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_embedding.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_engine.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_models.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_pooling.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_rerank.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_score.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_tokenization.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/api_server.py +911 -0
  24. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/cli_args.py +305 -0
  25. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/logits_processors.py +88 -0
  26. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/protocol.py +1428 -0
  27. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/run_batch.py +342 -0
  28. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_chat.py +955 -0
  29. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_completion.py +547 -0
  30. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_embedding.py +242 -0
  31. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_engine.py +524 -0
  32. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_models.py +244 -0
  33. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_pooling.py +235 -0
  34. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_rerank.py +208 -0
  35. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_score.py +238 -0
  36. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_tokenization.py +146 -0
  37. .venv/lib/python3.11/site-packages/vllm/entrypoints/utils.py +59 -0
  38. .venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/audio.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/base.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/hasher.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/parse.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/processing.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/registry.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/video.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/vllm/multimodal/audio.py +77 -0
  46. .venv/lib/python3.11/site-packages/vllm/multimodal/base.py +463 -0
  47. .venv/lib/python3.11/site-packages/vllm/plugins/__pycache__/__init__.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/vllm/spec_decode/__init__.py +0 -0
  49. .venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/__init__.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/batch_expansion.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/vllm/entrypoints/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/entrypoints/api_server.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ NOTE: This API server is used only for demonstrating usage of AsyncEngine
4
+ and simple performance benchmarks. It is not intended for production use.
5
+ For production use, we recommend using our OpenAI compatible server.
6
+ We are also not going to accept PRs modifying this file, please
7
+ change `vllm/entrypoints/openai/api_server.py` instead.
8
+ """
9
+ import asyncio
10
+ import json
11
+ import ssl
12
+ from argparse import Namespace
13
+ from typing import Any, AsyncGenerator, Optional
14
+
15
+ from fastapi import FastAPI, Request
16
+ from fastapi.responses import JSONResponse, Response, StreamingResponse
17
+
18
+ from vllm.engine.arg_utils import AsyncEngineArgs
19
+ from vllm.engine.async_llm_engine import AsyncLLMEngine
20
+ from vllm.entrypoints.launcher import serve_http
21
+ from vllm.entrypoints.utils import with_cancellation
22
+ from vllm.logger import init_logger
23
+ from vllm.sampling_params import SamplingParams
24
+ from vllm.usage.usage_lib import UsageContext
25
+ from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
26
+ from vllm.version import __version__ as VLLM_VERSION
27
+
28
+ logger = init_logger("vllm.entrypoints.api_server")
29
+
30
+ TIMEOUT_KEEP_ALIVE = 5 # seconds.
31
+ app = FastAPI()
32
+ engine = None
33
+
34
+
35
@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answers with an empty HTTP 200."""
    ok_status = 200
    return Response(status_code=ok_status)
39
+
40
+
41
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Handle a completion request.

    The JSON body must contain:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - any other keys: forwarded as sampling parameters
      (see `SamplingParams` for details).
    """
    payload = await request.json()
    return await _generate(payload, raw_request=request)
52
+
53
+
54
@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """Run generation for a parsed request body.

    Pops ``prompt`` and ``stream`` from ``request_dict``; every remaining
    key is passed through as a ``SamplingParams`` field. Returns either a
    streaming newline-delimited-JSON response or a single JSON response.
    """
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    # The module-level `engine` is bound by init_app() before serving starts.
    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        # Each chunk is one JSON object per engine step, containing the
        # prompt concatenated with every candidate output so far.
        async for request_output in results_generator:
            prompt = request_output.prompt
            assert prompt is not None
            text_outputs = [
                prompt + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case: drain the generator, keeping only the final output.
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        # 499: client closed the connection before generation finished.
        return Response(status_code=499)

    assert final_output is not None
    prompt = final_output.prompt
    assert prompt is not None
    text_outputs = [prompt + output.text for output in final_output.outputs]
    ret = {"text": text_outputs}
    return JSONResponse(ret)
92
+
93
+
94
def build_app(args: Namespace) -> FastAPI:
    """Apply CLI settings to the module-level app and return it."""
    global app
    app.root_path = args.root_path
    return app
99
+
100
+
101
async def init_app(
    args: Namespace,
    llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
    """Build the FastAPI app and bind the module-level engine.

    When ``llm_engine`` is provided it is used as-is; otherwise a new
    ``AsyncLLMEngine`` is constructed from the CLI arguments.
    """
    configured_app = build_app(args)

    global engine

    # Parse engine args unconditionally so invalid CLI values are
    # rejected even when an engine instance is injected.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    if llm_engine is not None:
        engine = llm_engine
    else:
        engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER)

    return configured_app
115
+
116
+
117
async def run_server(args: Namespace,
                     llm_engine: Optional[AsyncLLMEngine] = None,
                     **uvicorn_kwargs: Any) -> None:
    """Initialize the app and engine, then serve HTTP until shutdown."""
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Raise the soft file-descriptor limit before accepting connections.
    set_ulimit()

    fastapi_app = await init_app(args, llm_engine)
    assert engine is not None

    shutdown_coro = await serve_http(
        fastapi_app,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        **uvicorn_kwargs,
    )

    # serve_http returns a task that completes once the server shuts down.
    await shutdown_coro
142
+
143
+
144
if __name__ == "__main__":
    # Build the CLI: server-level flags first, then the engine arguments.
    parser = FlexibleArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--ssl-ca-certs",
        type=str,
        default=None,
        help="The CA certificates file")
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)"
    )
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument("--log-level", type=str, default="debug")

    # Mix in all AsyncLLMEngine CLI options (model, parallelism, etc.).
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))
.venv/lib/python3.11/site-packages/vllm/entrypoints/chat_utils.py ADDED
@@ -0,0 +1,1007 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ import codecs
5
+ import json
6
+ from abc import ABC, abstractmethod
7
+ from collections import defaultdict, deque
8
+ from functools import cache, lru_cache, partial
9
+ from pathlib import Path
10
+ from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
11
+ Literal, Optional, Tuple, TypeVar, Union, cast)
12
+
13
+ import jinja2.nodes
14
+ import transformers.utils.chat_template_utils as hf_chat_utils
15
+ # yapf conflicts with isort for this block
16
+ # yapf: disable
17
+ from openai.types.chat import (ChatCompletionAssistantMessageParam,
18
+ ChatCompletionContentPartImageParam,
19
+ ChatCompletionContentPartInputAudioParam)
20
+ from openai.types.chat import (
21
+ ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
22
+ from openai.types.chat import (ChatCompletionContentPartRefusalParam,
23
+ ChatCompletionContentPartTextParam)
24
+ from openai.types.chat import (
25
+ ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
26
+ from openai.types.chat import (ChatCompletionMessageToolCallParam,
27
+ ChatCompletionToolMessageParam)
28
+ from openai.types.chat.chat_completion_content_part_input_audio_param import (
29
+ InputAudio)
30
+ # yapf: enable
31
+ # pydantic needs the TypedDict from typing_extensions
32
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
33
+ from typing_extensions import Required, TypeAlias, TypedDict
34
+
35
+ from vllm.config import ModelConfig
36
+ from vllm.logger import init_logger
37
+ from vllm.multimodal import MultiModalDataDict
38
+ from vllm.multimodal.utils import MediaConnector
39
+ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
40
+
41
+ logger = init_logger(__name__)
42
+
43
+
44
class AudioURL(TypedDict, total=False):
    """Audio reference carried inside an ``audio_url`` content part."""

    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """
49
+
50
+
51
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part referencing audio by URL (defined locally; not part of
    the upstream ``openai`` type stubs imported above)."""

    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""
56
+
57
+
58
class VideoURL(TypedDict, total=False):
    """Video reference carried inside a ``video_url`` content part."""

    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """
63
+
64
+
65
class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part referencing video by URL (defined locally; not part of
    the upstream ``openai`` type stubs imported above)."""

    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""
70
+
71
+
72
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.

    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
    # No "type" discriminator here; the parser distinguishes this shape by
    # the value of image_url being a plain string rather than an object.
    image_url: Required[str]
82
+
83
+
84
class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
    audio_url: Required[str]
93
+
94
+
95
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain video_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
    video_url: Required[str]
104
+
105
+
106
# Union of every content-part shape accepted by vLLM's chat API: the
# upstream OpenAI parts, the locally-defined audio/video parts, the
# "simple" single-key variants above, or a bare string.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentSimpleImageParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam, str]
113
+
114
+
115
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API.

    Unlike the upstream OpenAI message params, ``role`` may be any string.
    """
    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
135
+
136
+
137
# A request message: either an upstream OpenAI-typed message or the
# custom-role variant defined above.
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]
139
+
140
+
141
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
    """Internal, normalized representation of a chat message after parsing."""

    role: Required[str]
    """The role of the message's author."""

    content: Union[Optional[str], List[Dict[str, str]]]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
157
+
158
+
159
# Passed in by user (e.g. via --chat-template-content-format)
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Used internally, after "auto" has been resolved to a concrete format
_ChatTemplateContentFormat = Literal["string", "openai"]
164
+
165
+
166
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return True if ``node`` is a load of the variable ``varname``."""
    if not isinstance(node, jinja2.nodes.Name):
        return False

    return node.ctx == "load" and node.name == varname
171
+
172
+
173
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return True if ``node`` is ``varname[key]`` or ``varname.key``."""
    if isinstance(node, jinja2.nodes.Getattr):
        return _is_var_access(node.node, varname) and node.attr == key

    if isinstance(node, jinja2.nodes.Getitem):
        subscript = node.arg
        return (isinstance(subscript, jinja2.nodes.Const)
                and subscript.value == key
                and _is_var_access(node.node, varname))

    return False
183
+
184
+
185
def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: Optional[str] = None,
) -> bool:
    """Return True if ``node`` accesses ``varname`` (or ``varname[key]``
    when ``key`` is given), possibly wrapped in filters, tests, or slices."""
    # Unwrap filter applications; a filter with no subject cannot match.
    if isinstance(node, jinja2.nodes.Filter):
        subject = node.node
        if subject is None:
            return False
        return _is_var_or_elems_access(subject, varname, key)

    # Unwrap test expressions (e.g. `x is defined`).
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    # Unwrap slicing (e.g. `messages[1:]`), which preserves identity.
    if (isinstance(node, jinja2.nodes.Getitem)
            and isinstance(node.arg, jinja2.nodes.Slice)):
        return _is_var_or_elems_access(node.node, varname, key)

    if key:
        return _is_attr_access(node, varname, key)
    return _is_var_access(node, varname)
205
+
206
+
207
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """Yield ``(node, name)`` for every variable under ``root`` that
    aliases ``varname`` or an element/attribute of it.

    The first pair is ``(root, varname)`` itself; the rest come from
    ``{% set alias = varname... %}`` assignments, followed transitively
    (aliases of aliases) via breadth-first search.
    """
    # Global variable that is implicitly defined at the root
    yield root, varname

    # Iterative BFS
    related_varnames = deque([varname])
    while related_varnames:
        related_varname = related_varnames.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, related_varname):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                # Avoid infinite looping for self-assignment
                if lhs.name != related_varname:
                    related_varnames.append(lhs.name)
227
+
228
+
229
# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """Yield ``(for_node, loop_var)`` for each loop over the messages list."""
    messages_varnames = [
        varname
        for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    # Search for {%- for message in messages -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        loop_iter = loop_ast.iter

        # Match against the first aliased name only, mirroring a
        # first-match-wins scan of messages_varnames.
        hit = next((name for name in messages_varnames
                    if _is_var_or_elems_access(loop_iter, name)), None)
        if hit is not None:
            loop_target = loop_ast.target
            assert isinstance(loop_target, jinja2.nodes.Name)
            yield loop_ast, loop_target.name
247
+
248
+
249
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """Yield ``(for_node, loop_var)`` for loops over ``message['content']``."""
    message_varnames = [
        varname for _, varname in _iter_nodes_assign_messages_item(root)
    ]

    # Search for {%- for content in message['content'] -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        loop_iter = loop_ast.iter

        hit = next((name for name in message_varnames
                    if _is_var_or_elems_access(loop_iter, name, "content")),
                   None)
        if hit is not None:
            loop_target = loop_ast.target
            assert isinstance(loop_target, jinja2.nodes.Name)
            yield loop_ast, loop_target.name
264
+
265
+
266
def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
    """Parse ``chat_template`` into a Jinja AST; return None on any error."""
    try:
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        return compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None
273
+
274
+
275
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    """Inspect a chat template to decide whether message content is a plain
    string or an OpenAI-style list of parts; fall back to ``default`` on
    parse errors."""
    template_ast = _try_extract_ast(chat_template)
    if template_ast is None:
        return default

    try:
        # A loop over message['content'] implies OpenAI-style content parts.
        next(_iter_nodes_assign_content_item(template_ast))
    except StopIteration:
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default

    return "openai"
293
+
294
+
295
def _resolve_chat_template_content_format(
    chat_template: Optional[str],
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
) -> _ChatTemplateContentFormat:
    """Resolve the effective content format, auto-detecting from whichever
    Jinja template is actually in effect."""
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        tokenizer_chat_template = tokenizer.chat_template
    else:
        tokenizer_chat_template = None

    # Determine which Jinja text applies: the tokenizer default, a named
    # template from the tokenizer's template dict, or the user-supplied one.
    jinja_text: Optional[str]
    if chat_template is None and isinstance(tokenizer_chat_template, str):
        jinja_text = tokenizer_chat_template
    elif (isinstance(tokenizer_chat_template, dict)
          and chat_template in tokenizer_chat_template):
        jinja_text = tokenizer_chat_template[chat_template]
    else:
        jinja_text = load_chat_template(chat_template, is_literal=True)

    if jinja_text is None:
        detected = "string"
    else:
        detected = _detect_content_format(jinja_text, default="string")

    # An explicit user choice always wins over detection.
    return detected if given_format == "auto" else given_format
318
+
319
+
320
@lru_cache
def resolve_chat_template_content_format(
    chat_template: Optional[str],
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
) -> _ChatTemplateContentFormat:
    """Resolve (and memoize) the chat-template content format, logging the
    detection result and any mismatch with the user's explicit choice."""
    detected_format = _resolve_chat_template_content_format(
        chat_template,
        given_format,
        tokenizer,
    )

    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    explicit_mismatch = (given_format != "auto"
                         and given_format != detected_format)
    if explicit_mismatch:
        # NOTE(review): _resolve_chat_template_content_format returns
        # `given_format` whenever it is not "auto", so this branch appears
        # unreachable — confirm upstream intent.
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )

    return detected_format
350
+
351
+
352
# Modalities a multi-modal item can belong to.
ModalityStr = Literal["image", "audio", "video"]
# Per-item payload type tracked by BaseMultiModalItemTracker subclasses
# (a plain object in the sync tracker, an awaitable in the async one).
_T = TypeVar("_T")
354
+
355
+
356
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer
        # Per-modality count limits; an empty dict means no explicit limits
        # were configured (add() then falls back to a default of 1).
        self._allowed_items = (model_config.multimodal_config.limit_per_prompt
                               if model_config.multimodal_config else {})

        self._items_by_modality = defaultdict[str, list[_T]](list)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @staticmethod
    @cache
    def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
        # Decoding a single token index is deterministic per tokenizer,
        # so memoize across calls.
        return tokenizer.decode(token_index)

    def _placeholder_str(self, modality: ModalityStr,
                         current_count: int) -> Optional[str]:
        """Return the prompt placeholder for the ``current_count``-th item
        of ``modality``, or None for models that take the media without a
        prompt token. Raises TypeError for unknown model types/modalities.
        """
        # TODO: Let user specify how to insert image tokens into prompt
        # (similar to chat template)
        hf_config = self._model_config.hf_config
        model_type = hf_config.model_type

        if modality == "image":
            if model_type == "phi3_v":
                # Workaround since this token is not defined in the tokenizer
                return f"<|image_{current_count}|>"
            if model_type in ("minicpmo", "minicpmv"):
                return "(<image>./</image>)"
            if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
                              "pixtral"):
                # These models do not use image tokens in the prompt
                return None
            if model_type == "qwen":
                return f"Picture {current_count}: <img></img>"
            if model_type.startswith("llava"):
                # Covers llava, llava_next, llava_onevision, etc.
                return self._cached_token_str(self._tokenizer,
                                              hf_config.image_token_index)
            if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
                              "NVLM_D", "h2ovl_chat"):
                return "<image>"
            if model_type == "mllama":
                return "<|image|>"
            if model_type in ("qwen2_vl", "qwen2_5_vl"):
                return "<|vision_start|><|image_pad|><|vision_end|>"
            if model_type == "molmo":
                return ""
            if model_type == "idefics3":
                return "<image>"
            if model_type == "aria":
                return "<|fim_prefix|><|img|><|fim_suffix|>"

            raise TypeError(f"Unknown {modality} model type: {model_type}")
        elif modality == "audio":
            if model_type == "ultravox":
                return "<|audio|>"
            if model_type == "qwen2_audio":
                return (f"Audio {current_count}: "
                        f"<|audio_bos|><|AUDIO|><|audio_eos|>")
            if model_type == "minicpmo":
                return "(<audio>./</audio>)"
            raise TypeError(f"Unknown model type: {model_type}")
        elif modality == "video":
            if model_type in ("qwen2_vl", "qwen2_5_vl"):
                return "<|vision_start|><|video_pad|><|vision_end|>"
            if model_type in ("minicpmo", "minicpmv"):
                return "(<video>./</video>)"
            if model_type.startswith("llava"):
                return self._cached_token_str(self._tokenizer,
                                              hf_config.video_token_index)
            raise TypeError(f"Unknown {modality} model type: {model_type}")
        else:
            raise TypeError(f"Unknown modality: {modality}")

    def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
        """
        Add a multi-modal item to the current prompt and returns the
        placeholder string to use, if any.

        Raises ValueError if the per-prompt limit for ``modality`` (default
        1 when not configured) would be exceeded.
        """
        allowed_count = self._allowed_items.get(modality, 1)
        current_count = len(self._items_by_modality[modality]) + 1
        if current_count > allowed_count:
            raise ValueError(
                f"At most {allowed_count} {modality}(s) may be provided in "
                "one request.")

        self._items_by_modality[modality].append(item)

        return self._placeholder_str(modality, current_count)

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError
463
+
464
+
465
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Tracker whose items are already-materialized media objects."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        if not self._items_by_modality:
            return None
        return dict(self._items_by_modality)

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)
475
+
476
+
477
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Tracker whose items are awaitables, resolved in ``all_mm_data``."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        if not self._items_by_modality:
            return None

        # Await each modality's pending fetches in insertion order.
        resolved: dict = {}
        for modality, pending in self._items_by_modality.items():
            resolved[modality] = await asyncio.gather(*pending)
        return resolved

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)
490
+
491
+
492
class BaseMultiModalContentParser(ABC):
    """Parses media content parts and records how many times each
    placeholder string was inserted into the prompt."""

    def __init__(self) -> None:
        super().__init__()

        # multimodal placeholder_string : count
        self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)

    def _add_placeholder(self, placeholder: Optional[str]):
        # Falsy placeholders (None or "") come from models that embed media
        # without prompt tokens; they are not counted.
        if placeholder:
            self._placeholder_counts[placeholder] += 1

    def mm_placeholder_counts(self) -> Dict[str, int]:
        """Return a plain-dict snapshot of placeholder usage counts."""
        return dict(self._placeholder_counts)

    @abstractmethod
    def parse_image(self, image_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(self, input_audio: InputAudio) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str) -> None:
        raise NotImplementedError
522
+
523
+
524
class MultiModalContentParser(BaseMultiModalContentParser):
    """Synchronous parser: media are fetched eagerly on parse."""

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        self._connector = MediaConnector(
            allowed_local_media_path=tracker.allowed_local_media_path,
        )

    def parse_image(self, image_url: str) -> None:
        fetched = self._connector.fetch_image(image_url)
        self._add_placeholder(self._tracker.add("image", fetched))

    def parse_audio(self, audio_url: str) -> None:
        fetched = self._connector.fetch_audio(audio_url)
        self._add_placeholder(self._tracker.add("audio", fetched))

    def parse_input_audio(self, input_audio: InputAudio) -> None:
        # Wrap the raw base64 payload as a data URL and reuse the
        # regular audio path.
        audio_data = input_audio.get("data", "")
        audio_format = input_audio.get("format", "")
        audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url)

    def parse_video(self, video_url: str) -> None:
        fetched = self._connector.fetch_video(video_url)
        self._add_placeholder(self._tracker.add("video", fetched))
559
+
560
+
561
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Asynchronous content parser: registers fetch coroutines.

    The coroutines are awaited later by
    :meth:`AsyncMultiModalItemTracker.all_mm_data`.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        self._connector = MediaConnector(
            allowed_local_media_path=tracker.allowed_local_media_path,
        )

    def parse_image(self, image_url: str) -> None:
        """Schedule an image fetch for *image_url* and register it."""
        coro = self._connector.fetch_image_async(image_url)
        self._add_placeholder(self._tracker.add("image", coro))

    def parse_audio(self, audio_url: str) -> None:
        """Schedule an audio fetch for *audio_url* and register it."""
        coro = self._connector.fetch_audio_async(audio_url)
        self._add_placeholder(self._tracker.add("audio", coro))

    def parse_input_audio(self, input_audio: InputAudio) -> None:
        """Convert inline base64 audio into a data URL and parse it."""
        audio_data = input_audio.get("data", "")
        audio_format = input_audio.get("format", "")
        audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url)

    def parse_video(self, video_url: str) -> None:
        """Schedule a video fetch for *video_url* and register it."""
        coro = self._connector.fetch_video_async(video_url)
        self._add_placeholder(self._tracker.add("video", coro))
595
+
596
+
597
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raises if the provided chat template appears invalid.

    Raises:
        FileNotFoundError: If a ``Path`` template does not exist.
        ValueError: If a string template looks like a path but does not
            exist on disk.
        TypeError: If the argument is neither ``None``, ``Path`` nor ``str``.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        # A Path must point at an existing file.
        if not chat_template.exists():
            raise FileNotFoundError(
                "the supplied chat template path doesn't exist")
        return

    if isinstance(chat_template, str):
        # A string containing no Jinja-looking characters is treated as
        # a file path and must therefore exist on disk.
        looks_like_template = any(c in chat_template for c in "{}\n")
        if not looks_like_template and not Path(chat_template).exists():
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!")
        return

    raise TypeError(
        f"{type(chat_template)} is not a valid chat template type")
617
+
618
+
619
def load_chat_template(
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
    """Load a chat template from a file path or a literal string.

    Args:
        chat_template: A filesystem path to the template, or the template
            text itself; ``None`` is passed through unchanged.
        is_literal: When True, treat ``chat_template`` as the template
            text and decode escape sequences in it.

    Returns:
        The template text, or ``None`` if no template was supplied.
    """
    if chat_template is None:
        return None

    if is_literal:
        if isinstance(chat_template, Path):
            raise TypeError("chat_template is expected to be read directly "
                            "from its value")

        # Interpret escape sequences (e.g. "\\n") in the literal value.
        return codecs.decode(chat_template, "unicode_escape")

    try:
        with open(chat_template) as template_file:
            return template_file.read()
    except OSError as exc:
        # An explicit Path that cannot be opened is a hard error.
        if isinstance(chat_template, Path):
            raise

        looks_like_template = any(c in chat_template for c in "{}\n")
        if not looks_like_template:
            msg = (f"The supplied chat template ({chat_template}) "
                   f"looks like a file path, but it failed to be "
                   f"opened. Reason: {exc}")
            raise ValueError(msg) from exc

        # Not openable as a file but contains template syntax: treat it
        # as a literal so escape sequences are interpreted correctly.
        return load_chat_template(chat_template, is_literal=True)
651
+
652
+
653
+ # TODO: Let user specify how to insert multimodal tokens into prompt
654
+ # (similar to chat template)
655
+ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
656
+ text_prompt: str) -> str:
657
+ """Combine multimodal prompts for a multimodal language model."""
658
+
659
+ # Look through the text prompt to check for missing placeholders
660
+ missing_placeholders: List[str] = []
661
+ for placeholder in placeholder_counts:
662
+
663
+ # For any existing placeholder in the text prompt, we leave it as is
664
+ placeholder_counts[placeholder] -= text_prompt.count(placeholder)
665
+
666
+ if placeholder_counts[placeholder] < 0:
667
+ raise ValueError(
668
+ f"Found more '{placeholder}' placeholders in input prompt than "
669
+ "actual multimodal data items.")
670
+
671
+ missing_placeholders.extend([placeholder] *
672
+ placeholder_counts[placeholder])
673
+
674
+ # NOTE: For now we always add missing placeholders at the front of
675
+ # the prompt. This may change to be customizable in the future.
676
+ return "\n".join(missing_placeholders + [text_prompt])
677
+
678
+
679
# No need to validate using Pydantic again
# These "parsers" are typed casts over already-validated content dicts.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)

# A parsed content part: plain text, a URL string, or an inline-audio dict.
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]

# Define a mapping from part types to their corresponding parsing functions.
# Each extractor returns a falsy value ("" or {}) when the field is absent,
# which downstream code treats as "empty / unparsable" and skips.
MM_PARSER_MAP: Dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text":
    lambda part: _TextParser(part).get("text", ""),
    "image_url":
    lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
    "audio_url":
    lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
    "input_audio":
    lambda part: _InputAudioParser(part).get("input_audio", {}),
    "refusal":
    lambda part: _RefusalParser(part).get("refusal", ""),
    "video_url":
    lambda part: _VideoParser(part).get("video_url", {}).get("url", ""),
}
707
+
708
+
709
def _parse_chat_message_content_mm_part(
        part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL).

    Raises:
        ValueError: If the 'type' field is missing and no direct URL is found.
    """
    assert isinstance(
        part, dict)  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)

    if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
        # We only support 'auto', which is the default
        if part_type == "image_url" and part.get("detail", "auto") != "auto":
            logger.warning("'image_url.detail' is currently not supported "
                           "and will be ignored.")

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    # 'type' is required field by pydantic
    if part_type is None:
        if part.get("image_url") is not None:
            image_params = cast(CustomChatCompletionContentSimpleImageParam,
                                part)
            return "image_url", image_params.get("image_url", "")
        if part.get("audio_url") is not None:
            audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
                                part)
            return "audio_url", audio_params.get("audio_url", "")
        if part.get("input_audio") is not None:
            input_audio_params = cast(Dict[str, str], part)
            return "input_audio", input_audio_params
        if part.get("video_url") is not None:
            video_params = cast(CustomChatCompletionContentSimpleVideoParam,
                                part)
            return "video_url", video_params.get("video_url", "")
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    # Known string 'type' that has no parser: return a sentinel content so
    # the caller can decide how to handle it.
    return part_type, "unknown part_type content"
764
+
765
+
766
# Content-part types recognized by _parse_chat_message_content_part; parts
# of these types with empty content are skipped with a warning.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "audio_url", "input_audio", "video_url")
768
+
769
+
770
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
) -> List[ConversationMessage]:
    """Parse all content parts of one message into conversation messages.

    With ``wrap_dicts`` the parts stay as interleaved dicts (OpenAI
    format); otherwise text parts are joined into a single string with
    multimodal placeholders prepended as needed.
    """
    mm_parser = mm_tracker.create_parser()

    # Collect only the parts that produced content (parsers may return
    # None/empty for items handled purely via the tracker).
    content: List[_ContentPart] = []
    for part in parts:
        parsed = _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
        )
        if parsed:
            content.append(parsed)

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role,
                                    content=content)]  # type: ignore

    text_prompt = "\n".join(cast(List[str], content))
    mm_placeholder_counts = mm_parser.mm_placeholder_counts()
    if mm_placeholder_counts:
        text_prompt = _get_full_multimodal_text_prompt(
            mm_placeholder_counts, text_prompt)
    return [ConversationMessage(role=role, content=text_prompt)]
801
+
802
+
803
def _parse_chat_message_content_part(
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
) -> Optional[_ContentPart]:
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.
    """
    if isinstance(part, str):  # Handle plain text parts
        return part

    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)

    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio
    # but content is empty, log a warning and skip
    if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
        # FIX: the two fragments previously concatenated without a space,
        # logging "...(type: '%s')with empty...".
        logger.warning(
            "Skipping multimodal part (type: '%s') "
            "with empty / unparsable content.", part_type)
        return None

    if part_type in ("text", "refusal"):
        str_content = cast(str, content)
        if wrap_dicts:
            return {'type': 'text', 'text': str_content}
        else:
            return str_content

    if part_type == "image_url":
        str_content = cast(str, content)
        mm_parser.parse_image(str_content)
        return {'type': 'image'} if wrap_dicts else None

    if part_type == "audio_url":
        str_content = cast(str, content)
        mm_parser.parse_audio(str_content)
        return {'type': 'audio'} if wrap_dicts else None

    if part_type == "input_audio":
        dict_content = cast(InputAudio, content)
        mm_parser.parse_input_audio(dict_content)
        return {'type': 'audio'} if wrap_dicts else None

    if part_type == "video_url":
        str_content = cast(str, content)
        mm_parser.parse_video(str_content)
        return {'type': 'video'} if wrap_dicts else None

    raise NotImplementedError(f"Unknown part type: {part_type}")
858
+
859
+
860
# No need to validate using Pydantic again
# Typed casts over already-validated assistant/tool message dicts.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
863
+
864
+
865
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
) -> List[ConversationMessage]:
    """Parse one chat message into one or more conversation messages.

    Normalizes ``content`` to a list of parts, delegates the part parsing,
    then copies role-specific fields (assistant tool calls, tool
    ``tool_call_id``, optional ``name``) onto the resulting message(s).
    """
    role = message["role"]
    content = message.get("content")

    if content is None:
        content = []
    elif isinstance(content, str):
        # A bare string is treated as a single text part.
        content = [
            ChatCompletionContentPartTextParam(type="text", text=content)
        ]
    result = _parse_chat_message_content_parts(
        role,
        content,  # type: ignore
        mm_tracker,
        # "openai" format keeps parts as interleaved dicts.
        wrap_dicts=(content_format == "openai"),
    )

    for result_msg in result:
        if role == 'assistant':
            parsed_msg = _AssistantParser(message)

            if "tool_calls" in parsed_msg:
                result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
        elif role == "tool":
            parsed_msg = _ToolParser(message)
            if "tool_call_id" in parsed_msg:
                result_msg["tool_call_id"] = parsed_msg["tool_call_id"]

        if "name" in message and isinstance(message["name"], str):
            result_msg["name"] = message["name"]

    return result
901
+
902
+
903
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
    """Decode JSON-string tool-call arguments into dicts, in place.

    Per the Transformers docs & maintainers, tool call arguments in
    assistant-role messages with tool_calls need to be dicts (not JSON
    strings), which is what tool-use chat templates expect going forward.
    The OpenAI wire format delivers them as strings, so parse them here.
    """
    for message in messages:
        has_tool_calls = (message["role"] == "assistant"
                          and "tool_calls" in message
                          and isinstance(message["tool_calls"], list))
        if not has_tool_calls:
            continue

        for call in message["tool_calls"]:
            fn = call["function"]
            fn["arguments"] = json.loads(fn["arguments"])
916
+
917
+
918
def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
    """Parse OpenAI-style chat messages, fetching multimodal data eagerly.

    Returns:
        The parsed conversation, and the multimodal data collected from
        the messages (``None`` if there was none).
    """
    conversation: List[ConversationMessage] = []
    # Synchronous tracker: media is fetched as the messages are parsed.
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    for msg in messages:
        sub_messages = _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
        )

        conversation.extend(sub_messages)

    # Convert tool-call argument JSON strings to dicts, in place.
    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()
939
+
940
+
941
def parse_chat_messages_futures(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
    """Parse OpenAI-style chat messages, fetching multimodal data lazily.

    Like :func:`parse_chat_messages`, but the second return value is an
    awaitable that resolves the multimodal data when awaited.
    """
    conversation: List[ConversationMessage] = []
    # Async tracker: media fetches are scheduled, not awaited here.
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

    for msg in messages:
        sub_messages = _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
        )

        conversation.extend(sub_messages)

    # Convert tool-call argument JSON strings to dicts, in place.
    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()
962
+
963
+
964
def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: List[ConversationMessage],
    chat_template: Optional[str],
    *,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Render *conversation* through a HuggingFace chat template.

    Either an explicit ``chat_template`` or one defined on the tokenizer
    must be available; extra keyword arguments are forwarded to
    ``tokenizer.apply_chat_template``.
    """
    has_template = (chat_template is not None
                    or tokenizer.chat_template is not None)
    if not has_template:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one.")

    return tokenizer.apply_chat_template(
        conversation=conversation,  # type: ignore[arg-type]
        chat_template=chat_template,
        tokenize=tokenize,
        **kwargs,
    )
984
+
985
+
986
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: List[ChatCompletionMessageParam],
    chat_template: Optional[str] = None,
    **kwargs: Any,
) -> List[int]:
    """Tokenize *messages* using a Mistral tokenizer's built-in template.

    The Mistral tokenizer owns its template; several HF-style knobs are
    accepted for interface parity but warned about and ignored.
    """
    if chat_template is not None:
        logger.warning_once(
            "'chat_template' cannot be overridden for mistral tokenizer.")

    if "add_generation_prompt" in kwargs:
        logger.warning_once(
            "'add_generation_prompt' is not supported for mistral tokenizer, "
            "so it will be ignored.")

    if "continue_final_message" in kwargs:
        logger.warning_once(
            "'continue_final_message' is not supported for mistral tokenizer, "
            "so it will be ignored.")

    return tokenizer.apply_chat_template(messages=messages, **kwargs)
.venv/lib/python3.11/site-packages/vllm/entrypoints/launcher.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ import signal
5
+ from http import HTTPStatus
6
+ from typing import Any
7
+
8
+ import uvicorn
9
+ from fastapi import FastAPI, Request, Response
10
+
11
+ from vllm import envs
12
+ from vllm.engine.async_llm_engine import AsyncEngineDeadError
13
+ from vllm.engine.multiprocessing import MQEngineDeadError
14
+ from vllm.logger import init_logger
15
+ from vllm.utils import find_process_using_port
16
+
17
+ logger = init_logger(__name__)
18
+
19
+
20
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
    """Serve *app* with uvicorn until it exits or is signalled to stop.

    Installs SIGINT/SIGTERM handlers that cancel the server task, and
    returns an awaitable that performs (or skips) the actual shutdown;
    the caller is expected to await that result.
    """
    logger.info("Available routes are:")
    for route in app.routes:
        methods = getattr(route, "methods", None)
        path = getattr(route, "path", None)

        # Routes such as mounts may lack methods/path; skip them.
        if methods is None or path is None:
            continue

        logger.info("Route: %s, Methods: %s", path, ', '.join(methods))

    config = uvicorn.Config(app, **uvicorn_kwargs)
    server = uvicorn.Server(config)
    _add_shutdown_handlers(app, server)

    loop = asyncio.get_running_loop()

    server_task = loop.create_task(server.serve())

    def signal_handler() -> None:
        # prevents the uvicorn signal handler to exit early
        server_task.cancel()

    async def dummy_shutdown() -> None:
        # No-op shutdown for the case where the server exited on its own.
        pass

    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)

    try:
        await server_task
        return dummy_shutdown()
    except asyncio.CancelledError:
        # Cancellation comes from the signal handler above; log who else
        # holds the port (useful when shutdown hangs on a busy port).
        port = uvicorn_kwargs["port"]
        process = find_process_using_port(port)
        if process is not None:
            logger.debug(
                "port %s is used by process %s launched with command:\n%s",
                port, process, " ".join(process.cmdline()))
        logger.info("Shutting down FastAPI HTTP server.")
        return server.shutdown()
61
+
62
+
63
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    """Adds handlers for fatal errors that should crash the server"""

    @app.exception_handler(RuntimeError)
    async def runtime_error_handler(request: Request, __):
        """On generic runtime error, check to see if the engine has died.
        It probably has, in which case the server will no longer be able to
        handle requests. Trigger a graceful shutdown with a SIGTERM."""
        engine = request.app.state.engine_client
        # Only shut down when keep-alive is disabled and the engine is
        # actually in a failed, non-running state.
        if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
                and not engine.is_running):
            logger.fatal("AsyncLLMEngine has failed, terminating server "
                         "process")
            # See discussions here on shutting down a uvicorn server
            # https://github.com/encode/uvicorn/discussions/1103
            # In this case we cannot await the server shutdown here because
            # this handler must first return to close the connection for
            # this request.
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    @app.exception_handler(AsyncEngineDeadError)
    async def async_engine_dead_handler(_, __):
        """Kill the server if the async engine is already dead. It will
        not handle any further requests."""
        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
            logger.fatal("AsyncLLMEngine is already dead, terminating server "
                         "process")
            # Request a graceful uvicorn exit rather than killing outright.
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    @app.exception_handler(MQEngineDeadError)
    async def mq_engine_dead_handler(_, __):
        """Kill the server if the mq engine is already dead. It will
        not handle any further requests."""
        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
            logger.fatal("MQLLMEngine is already dead, terminating server "
                         "process")
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
.venv/lib/python3.11/site-packages/vllm/entrypoints/llm.py ADDED
@@ -0,0 +1,1414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import itertools
4
+ import warnings
5
+ from contextlib import contextmanager
6
+ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
7
+ Tuple, Type, Union, cast, overload)
8
+
9
+ import cloudpickle
10
+ import torch
11
+ import torch.nn as nn
12
+ from tqdm import tqdm
13
+ from typing_extensions import TypeVar, deprecated
14
+
15
+ from vllm import envs
16
+ from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
17
+ BeamSearchSequence, get_beam_search_score)
18
+ from vllm.config import CompilationConfig
19
+ from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
20
+ TaskOption)
21
+ from vllm.engine.llm_engine import LLMEngine
22
+ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
23
+ ChatTemplateContentFormatOption,
24
+ apply_hf_chat_template,
25
+ apply_mistral_chat_template,
26
+ parse_chat_messages,
27
+ resolve_chat_template_content_format)
28
+ from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
29
+ from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
30
+ from vllm.logger import init_logger
31
+ from vllm.lora.request import LoRARequest
32
+ from vllm.model_executor.guided_decoding.guided_fields import (
33
+ GuidedDecodingRequest, LLMGuidedOptions)
34
+ from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
35
+ PoolingRequestOutput, RequestOutput,
36
+ ScoringRequestOutput)
37
+ from vllm.pooling_params import PoolingParams
38
+ from vllm.prompt_adapter.request import PromptAdapterRequest
39
+ from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
40
+ RequestOutputKind, SamplingParams)
41
+ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
42
+ get_cached_tokenizer)
43
+ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
44
+ from vllm.usage.usage_lib import UsageContext
45
+ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
46
+
47
+ logger = init_logger(__name__)
48
+
49
+ _R = TypeVar("_R", default=Any)
50
+
51
+
52
+ class LLM:
53
+ """An LLM for generating texts from given prompts and sampling parameters.
54
+
55
+ This class includes a tokenizer, a language model (possibly distributed
56
+ across multiple GPUs), and GPU memory space allocated for intermediate
57
+ states (aka KV cache). Given a batch of prompts and sampling parameters,
58
+ this class generates texts from the model, using an intelligent batching
59
+ mechanism and efficient memory management.
60
+
61
+ Args:
62
+ model: The name or path of a HuggingFace Transformers model.
63
+ tokenizer: The name or path of a HuggingFace Transformers tokenizer.
64
+ tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
65
+ if available, and "slow" will always use the slow tokenizer.
66
+ skip_tokenizer_init: If true, skip initialization of tokenizer and
67
+ detokenizer. Expect valid prompt_token_ids and None for prompt
68
+ from the input.
69
+ trust_remote_code: Trust remote code (e.g., from HuggingFace) when
70
+ downloading the model and tokenizer.
71
+ allowed_local_media_path: Allowing API requests to read local images
72
+ or videos from directories specified by the server file system.
73
+ This is a security risk. Should only be enabled in trusted
74
+ environments.
75
+ tensor_parallel_size: The number of GPUs to use for distributed
76
+ execution with tensor parallelism.
77
+ dtype: The data type for the model weights and activations. Currently,
78
+ we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
79
+ the `torch_dtype` attribute specified in the model config file.
80
+ However, if the `torch_dtype` in the config is `float32`, we will
81
+ use `float16` instead.
82
+ quantization: The method used to quantize the model weights. Currently,
83
+ we support "awq", "gptq", and "fp8" (experimental).
84
+ If None, we first check the `quantization_config` attribute in the
85
+ model config file. If that is None, we assume the model weights are
86
+ not quantized and use `dtype` to determine the data type of
87
+ the weights.
88
+ revision: The specific model version to use. It can be a branch name,
89
+ a tag name, or a commit id.
90
+ tokenizer_revision: The specific tokenizer version to use. It can be a
91
+ branch name, a tag name, or a commit id.
92
+ seed: The seed to initialize the random number generator for sampling.
93
+ gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
94
+ reserve for the model weights, activations, and KV cache. Higher
95
+ values will increase the KV cache size and thus improve the model's
96
+ throughput. However, if the value is too high, it may cause out-of-
97
+ memory (OOM) errors.
98
+ swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
99
+ This can be used for temporarily storing the states of the requests
100
+ when their `best_of` sampling parameters are larger than 1. If all
101
+ requests will have `best_of=1`, you can safely set this to 0.
102
+ Otherwise, too small values may cause out-of-memory (OOM) errors.
103
+ cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
104
+ the model weights. This virtually increases the GPU memory space
105
+ you can use to hold the model weights, at the cost of CPU-GPU data
106
+ transfer for every forward pass.
107
+ enforce_eager: Whether to enforce eager execution. If True, we will
108
+ disable CUDA graph and always execute the model in eager mode.
109
+ If False, we will use CUDA graph and eager execution in hybrid.
110
+ max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
111
+ When a sequence has context length larger than this, we fall back
112
+ to eager mode. Additionally for encoder-decoder models, if the
113
+ sequence length of the encoder input is larger than this, we fall
114
+ back to the eager mode.
115
+ disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
116
+ disable_async_output_proc: Disable async output processing.
117
+ This may result in lower performance.
118
+ hf_overrides: If a dictionary, contains arguments to be forwarded to the
119
+ HuggingFace config. If a callable, it is called to update the
120
+ HuggingFace config.
121
+ compilation_config: Either an integer or a dictionary. If it is an
122
+ integer, it is used as the level of compilation optimization. If it
123
+ is a dictionary, it can specify the full compilation configuration.
124
+ **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
125
+ :ref:`engine-args`)
126
+
127
+ Note:
128
+ This class is intended to be used for offline inference. For online
129
+ serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
130
+ """
131
+
132
+ DEPRECATE_LEGACY: ClassVar[bool] = True
133
+ """A flag to toggle whether to deprecate the legacy generate/encode API."""
134
+
135
+ DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
136
+ """
137
+ A flag to toggle whether to deprecate positional arguments in
138
+ :meth:`LLM.__init__`.
139
+ """
140
+
141
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        # Context manager that enables the legacy generate/encode API
        # deprecation flag for the duration of the ``with`` block, and
        # clears it afterwards. Note this toggles class-level state, so
        # it affects all LLM instances while active.
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False
149
+
150
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        '''
        LLM constructor.

        Builds an :class:`~vllm.EngineArgs` from the keyword arguments and
        instantiates the (v0 or v1) engine from it.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.
        '''

        # Offline inference rarely wants periodic stat logging; default it
        # off unless the caller explicitly set it.
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # Normalize `compilation_config` (int level or dict) into a
        # CompilationConfig instance; pass through anything else unchanged.
        if compilation_config is not None:
            if isinstance(compilation_config, (int, dict)):
                compilation_config_instance = CompilationConfig.from_cli(
                    str(compilation_config))
            else:
                compilation_config_instance = compilation_config
        else:
            compilation_config_instance = None

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )
        # Logic to switch between engines is done at runtime instead of import
        # to avoid import order issues
        self.engine_class = self.get_engine_class()
        self.llm_engine = self.engine_class.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        # Monotonic counter used to mint unique request IDs.
        self.request_counter = Counter()
246
+
247
+ @staticmethod
248
+ def get_engine_class() -> Type[LLMEngine]:
249
+ if envs.VLLM_USE_V1:
250
+ # Lazy import: the v1 package isn't distributed
251
+ from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
252
+ return V1LLMEngine # type: ignore
253
+ return LLMEngine
254
+
255
+ def get_tokenizer(self) -> AnyTokenizer:
256
+ return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
257
+
258
+ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
259
+ tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
260
+
261
+ # While CachedTokenizer is dynamic, have no choice but
262
+ # compare class name. Misjudgment will arise from
263
+ # user-defined tokenizer started with 'Cached'
264
+ if tokenizer.__class__.__name__.startswith("Cached"):
265
+ tokenizer_group.tokenizer = tokenizer
266
+ else:
267
+ tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
268
+
269
+ def get_default_sampling_params(self) -> SamplingParams:
270
+ diff_sampling_param = (
271
+ self.llm_engine.model_config.get_diff_sampling_param())
272
+ if diff_sampling_param:
273
+ return SamplingParams.from_optional(**diff_sampling_param)
274
+ return SamplingParams()
275
+
276
    # Typing-only overloads for generate(). The first is the modern API;
    # the rest cover the deprecated `prompt_token_ids` call shapes.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...
372
+
373
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the "generate"
                runner, or if a guided-decoding dict specifies more than one
                guided decoding option.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        # Refuse early with an actionable message if the model was loaded
        # for a different runner (e.g. a pooling model).
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        # Legacy path: fold the deprecated `prompt_token_ids` argument into
        # the unified prompt representation.
        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        # A plain dict is accepted for guided decoding but may carry only a
        # single option.
        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
471
+
472
+ def collective_rpc(self,
473
+ method: Union[str, Callable[..., _R]],
474
+ timeout: Optional[float] = None,
475
+ args: Tuple = (),
476
+ kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
477
+ """
478
+ Execute an RPC call on all workers.
479
+
480
+ Args:
481
+ method: Name of the worker method to execute, or a callable that
482
+ is serialized and sent to all workers to execute.
483
+
484
+ If the method is a callable, it should accept an additional
485
+ `self` argument, in addition to the arguments passed in `args`
486
+ and `kwargs`. The `self` argument will be the worker object.
487
+ timeout: Maximum time in seconds to wait for execution. Raises a
488
+ :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
489
+ args: Positional arguments to pass to the worker method.
490
+ kwargs: Keyword arguments to pass to the worker method.
491
+
492
+ Returns:
493
+ A list containing the results from each worker.
494
+
495
+ Note:
496
+ It is recommended to use this API to only pass control messages,
497
+ and set up data-plane communication to pass data.
498
+ """
499
+ executor = self.llm_engine.model_executor
500
+ return executor.collective_rpc(method, timeout, args, kwargs)
501
+
502
+ def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
503
+ """
504
+ Run a function directly on the model inside each worker,
505
+ returning the result for each of them.
506
+ """
507
+ executor = self.llm_engine.model_executor
508
+ return executor.apply_model(func)
509
+
510
    def beam_search(
        self,
        prompts: List[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per input prompt, each holding the
            top ``beam_width`` decoded sequences.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        def sort_beams_key(x: BeamSearchSequence) -> float:
            # Rank beams by their length-penalized cumulative log-probability.
            # NOTE: `tokenizer` is bound below, before this closure is called.
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        # One BeamSearchInstance per input prompt, seeded with its token ids.
        for prompt in prompts:
            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten all live beams across instances into one batch, and
            # remember each instance's [start, end) slice of that batch.
            all_beams: List[BeamSearchSequence] = list(
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            # EOS ends the beam (unless ignored): move it to
                            # the completed pool instead of the live pool.
                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                # Keep only the best `beam_width` live beams per instance.
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Any still-live beams also count as candidates at the end.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs
619
+
620
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the :meth:`generate` method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is List[...]
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        # Resolve "auto" into a concrete content format ("string"/"openai")
        # based on the chat template and tokenizer.
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            # Mistral tokenizers render chat templates themselves (returning
            # token ids); HF tokenizers render to text or ids via HF templates.
            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
754
+
755
    # Typing-only overloads for encode(). The first is the modern API;
    # the rest cover the deprecated `prompt_token_ids` call shapes.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...
839
+
840
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the "pooling"
                runner.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        # Refuse early with an actionable message if the model was loaded
        # for a different runner (e.g. a text-generation model).
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        # Legacy path: fold the deprecated `prompt_token_ids` argument into
        # the unified prompt representation.
        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
921
+
922
+ def embed(
923
+ self,
924
+ prompts: Union[PromptType, Sequence[PromptType]],
925
+ /,
926
+ *,
927
+ use_tqdm: bool = True,
928
+ lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
929
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
930
+ ) -> List[EmbeddingRequestOutput]:
931
+ """
932
+ Generate an embedding vector for each prompt.
933
+
934
+ This class automatically batches the given prompts, considering
935
+ the memory constraint. For the best performance, put all of your prompts
936
+ into a single list and pass it to this method.
937
+
938
+ Args:
939
+ prompts: The prompts to the LLM. You may pass a sequence of prompts
940
+ for batch inference. See :class:`~vllm.inputs.PromptType`
941
+ for more details about the format of each prompts.
942
+ use_tqdm: Whether to use tqdm to display the progress bar.
943
+ lora_request: LoRA request to use for generation, if any.
944
+ prompt_adapter_request: Prompt Adapter request to use for
945
+ generation, if any.
946
+
947
+ Returns:
948
+ A list of ``EmbeddingRequestOutput`` objects containing the
949
+ embedding vectors in the same order as the input prompts.
950
+ """
951
+ if self.llm_engine.model_config.task != "embed":
952
+ raise ValueError(
953
+ "Embedding API is only enabled for `--task embed`")
954
+
955
+ items = self.encode(prompts,
956
+ use_tqdm=use_tqdm,
957
+ lora_request=lora_request,
958
+ prompt_adapter_request=prompt_adapter_request)
959
+
960
+ return [EmbeddingRequestOutput.from_base(item) for item in items]
961
+
962
+ def classify(
963
+ self,
964
+ prompts: Union[PromptType, Sequence[PromptType]],
965
+ /,
966
+ *,
967
+ use_tqdm: bool = True,
968
+ lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
969
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
970
+ ) -> List[ClassificationRequestOutput]:
971
+ """
972
+ Generate class logits for each prompt.
973
+
974
+ This class automatically batches the given prompts, considering
975
+ the memory constraint. For the best performance, put all of your prompts
976
+ into a single list and pass it to this method.
977
+
978
+ Args:
979
+ prompts: The prompts to the LLM. You may pass a sequence of prompts
980
+ for batch inference. See :class:`~vllm.inputs.PromptType`
981
+ for more details about the format of each prompts.
982
+ use_tqdm: Whether to use tqdm to display the progress bar.
983
+ lora_request: LoRA request to use for generation, if any.
984
+ prompt_adapter_request: Prompt Adapter request to use for
985
+ generation, if any.
986
+
987
+ Returns:
988
+ A list of ``ClassificationRequestOutput`` objects containing the
989
+ embedding vectors in the same order as the input prompts.
990
+ """
991
+ if self.llm_engine.model_config.task != "classify":
992
+ raise ValueError(
993
+ "Classification API is only enabled for `--task classify`")
994
+
995
+ items = self.encode(prompts,
996
+ use_tqdm=use_tqdm,
997
+ lora_request=lora_request,
998
+ prompt_adapter_request=prompt_adapter_request)
999
+
1000
+ return [ClassificationRequestOutput.from_base(item) for item in items]
1001
+
1002
def _embedding_score(
    self,
    tokenizer: AnyTokenizer,
    text_1: List[Union[str, TextPrompt, TokensPrompt]],
    text_2: List[Union[str, TextPrompt, TokensPrompt]],
    truncate_prompt_tokens: Optional[int] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]:
    """Score text pairs with a bi-encoder (embedding) model.

    All texts are embedded in one batch, then each pair's score is the
    cosine similarity of the two embeddings.

    NOTE(review): ``truncate_prompt_tokens`` is accepted but not applied
    on this path — confirm whether truncation should be forwarded to
    ``encode()``.
    """
    # Embed text_1 and text_2 together in a single batched encode() call.
    encoded_output = self.encode(
        text_1 + text_2,
        use_tqdm=use_tqdm,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request)
    encoded_output_1 = encoded_output[0:len(text_1)]
    encoded_output_2 = encoded_output[len(text_1):]

    # 1:N case — replicate the single query embedding for every candidate.
    if len(encoded_output_1) == 1:
        encoded_output_1 = encoded_output_1 * len(encoded_output_2)

    output_pairs = [(t1, t2)
                    for t1, t2 in zip(encoded_output_1, encoded_output_2)]

    scores = []
    # Cosine similarity over dimension 0 of the pooled embedding vectors.
    scorer = torch.nn.CosineSimilarity(0)

    for embed_1, embed_2 in output_pairs:
        pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data)

        # Join the two token-id sequences, separated by the pad token when
        # the tokenizer defines one, so the output reports a combined prompt.
        if (pad_token_id := getattr(tokenizer, "pad_token_id",
                                    None)) is not None:
            tokens = embed_1.prompt_token_ids + [
                pad_token_id
            ] + embed_2.prompt_token_ids
        else:
            tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids

        scores.append(
            PoolingRequestOutput(
                request_id=f"{embed_1.request_id}_{embed_2.request_id}",
                outputs=pair_score,
                prompt_token_ids=tokens,
                finished=True))

    items = self.engine_class.validate_outputs(scores,
                                               PoolingRequestOutput)
    return [ScoringRequestOutput.from_base(item) for item in items]
1051
+
1052
def _cross_encoding_score(
    self,
    tokenizer: Union[AnyTokenizer],
    text_1: List[Union[str, TextPrompt, TokensPrompt]],
    text_2: List[Union[str, TextPrompt, TokensPrompt]],
    truncate_prompt_tokens: Optional[int] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]:
    """Score text pairs with a cross-encoder model.

    Each ``(text_1[i], text_2[i])`` pair is tokenized jointly (query +
    candidate in one sequence) and scored by a single forward pass.
    """
    # MistralTokenizer does not support the text/text_pair joint-encoding
    # call used below.
    if isinstance(tokenizer, MistralTokenizer):
        raise ValueError(
            "Score API is only enabled for `--task embed or score`")

    # 1:N case — replicate the single query across all candidates.
    if len(text_1) == 1:
        text_1 = text_1 * len(text_2)

    input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]

    pooling_params = PoolingParams()

    tokenization_kwargs: Dict[str, Any] = {}
    if truncate_prompt_tokens is not None:
        tokenization_kwargs["truncation"] = True
        tokenization_kwargs["max_length"] = truncate_prompt_tokens

    parsed_prompts = []

    for q, t in input_pairs:
        # Joint tokenization inserts the model's separator tokens and
        # (for BERT-style models) token-type ids.
        prompt_inputs = tokenizer(text=q,
                                  text_pair=t,
                                  **tokenization_kwargs)
        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["input_ids"],
            token_type_ids=prompt_inputs.get("token_type_ids"))
        parsed_prompts.append(engine_prompt)

    self._validate_and_add_requests(
        prompts=parsed_prompts,
        params=pooling_params,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )

    outputs = self._run_engine(use_tqdm=use_tqdm)
    items = self.engine_class.validate_outputs(outputs,
                                               PoolingRequestOutput)

    return [ScoringRequestOutput.from_base(item) for item in items]
1102
+
1103
def score(
    self,
    text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
    text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
    /,
    *,
    truncate_prompt_tokens: Optional[int] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]:
    """Generate similarity scores for all pairs ``<text,text_pair>``.

    The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
    In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N``
    times to pair with the ``text_2`` sentences.
    The input pairs are used to build a list of prompts for the
    cross encoder model. This class automatically batches the prompts,
    considering the memory constraint. For the best performance, put all
    of your texts into a single list and pass it to this method.

    Args:
        text_1: can be a single prompt or a list of prompts, in which
            case it has to have the same length as the ``text_2`` list
        text_2: The texts to pair with the query to form the input
            to the LLM. See :class:`~vllm.inputs.PromptType` for
            more details about the format of each prompts.
        truncate_prompt_tokens: If set, truncate each tokenized pair to
            at most this many tokens (cross-encoder path only).
        use_tqdm: Whether to use tqdm to display the progress bar.
        lora_request: LoRA request to use for generation, if any.
        prompt_adapter_request: Prompt Adapter request to use for
            generation, if any.

    Returns:
        A list of ``ScoringRequestOutput`` objects containing the
        generated scores in the same order as the input prompts.
    """
    # Scoring requires a pooling runner; build a helpful error otherwise.
    runner_type = self.llm_engine.model_config.runner_type
    if runner_type != "pooling":
        messages = ["LLM.score() is only supported for pooling models."]

        supported_runner_types = self.llm_engine.model_config \
            .supported_runner_types
        if "pooling" in supported_runner_types:
            messages.append(
                "Your model supports the 'pooling' runner, but is "
                f"currently initialized for the '{runner_type}' runner. "
                "Please initialize vLLM using `--task embed`, "
                "`--task classify`, `--task score` etc.")

        raise ValueError(" ".join(messages))

    if self.llm_engine.model_config.task not in ("embed", "score"):
        raise ValueError(
            "Score API is only enabled for `--task embed or --task score`")

    # the tokenizer for models such as
    # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
    # lists of tokens to the `text` and `text_pair` kwargs
    tokenizer = self.llm_engine.get_tokenizer()

    def ensure_str(prompt: SingletonPrompt):
        # Normalize every accepted prompt form down to a plain string;
        # token prompts are decoded back to text.
        if isinstance(prompt, dict):
            if "multi_modal_data" in prompt:
                raise ValueError("Multi-modal prompt is not "
                                 "supported for scoring")
            elif "prompt_token_ids" in prompt:
                prompt = tokenizer.decode(
                    cast(TokensPrompt, prompt)["prompt_token_ids"])
            elif "prompt" in prompt:
                prompt = cast(TextPrompt, prompt)["prompt"]
        assert type(prompt) is str
        return prompt

    if isinstance(text_1, (str, dict)):
        # Convert a single prompt to a list.
        text_1 = [text_1]
    text_1 = [ensure_str(t) for t in text_1]

    if isinstance(text_2, (str, dict)):
        # Convert a single prompt to a list.
        text_2 = [text_2]
    text_2 = [ensure_str(t) for t in text_2]

    # Only 1:1, 1:N and N:N pairings are supported.
    if len(text_1) > 1 and len(text_1) != len(text_2):
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if len(text_1) == 0:
        raise ValueError("At least one text element must be given")
    if len(text_2) == 0:
        raise ValueError("At least one text_pair element must be given")

    # Dispatch based on model architecture: cross-encoders score the joint
    # sequence; bi-encoders fall back to cosine similarity of embeddings.
    if self.llm_engine.model_config.is_cross_encoder:
        return self._cross_encoding_score(tokenizer, text_1, text_2,
                                          truncate_prompt_tokens, use_tqdm,
                                          lora_request,
                                          prompt_adapter_request)
    else:
        return self._embedding_score(tokenizer, text_1, text_2,
                                     truncate_prompt_tokens, use_tqdm,
                                     lora_request, prompt_adapter_request)
1202
+
1203
def start_profile(self) -> None:
    """Ask the underlying engine to begin collecting profiler traces."""
    engine = self.llm_engine
    engine.start_profile()
1205
+
1206
def stop_profile(self) -> None:
    """Ask the underlying engine to stop collecting profiler traces."""
    engine = self.llm_engine
    engine.stop_profile()
1208
+
1209
def reset_prefix_cache(self) -> bool:
    """Clear the engine's prefix cache; returns the engine's success flag."""
    engine = self.llm_engine
    return engine.reset_prefix_cache()
1211
+
1212
def sleep(self, level: int = 1):
    """
    Put the engine to sleep. The engine should not process any requests.
    The caller should guarantee that no requests are being processed
    during the sleep period, before `wake_up` is called.

    :param level: The sleep level. Level 1 offloads the model weights to
        CPU memory and discards the kv cache (its content is forgotten);
        it suits sleeping and waking up to run the same model again —
        make sure there is enough CPU memory for the weights. Level 2
        discards both the model weights and the kv cache; it suits waking
        up to run a different model or an updated model, and reduces CPU
        memory pressure since nothing is backed up.
    """
    # The kv cache content is invalid after waking up, so drop the prefix
    # cache first.
    self.reset_prefix_cache()
    self.llm_engine.sleep(level=level)
1232
+
1233
def wake_up(self):
    """
    Wake the engine from sleep mode. See the :meth:`sleep` method
    for more details.
    """
    engine = self.llm_engine
    engine.wake_up()
1238
+
1239
+ # LEGACY
1240
def _convert_v1_inputs(
    self,
    prompts: Optional[Union[str, List[str]]],
    prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
):
    """Convert legacy (v1) prompt arguments into a list of PromptType.

    Exactly one of ``prompts``/``prompt_token_ids`` is used; when both are
    given their lengths must match and text prompts take precedence.
    """
    # skip_tokenizer_init is now checked in engine

    if prompts is not None:
        prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
    if prompt_token_ids is not None:
        prompt_token_ids = [
            p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
        ]

    # Validate that the two argument forms agree on the request count.
    num_requests = None
    if prompts is not None:
        num_requests = len(prompts)
    if prompt_token_ids is not None:
        if (num_requests is not None
                and num_requests != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        num_requests = len(prompt_token_ids)
    if num_requests is None:
        raise ValueError("Either prompts or prompt_token_ids must be "
                         "provided.")

    # Text prompts win when both forms were supplied.
    if prompts is not None:
        return [TextPrompt(prompt=text) for text in prompts]
    return [
        TokensPrompt(prompt_token_ids=token_ids)
        for token_ids in prompt_token_ids
    ]
1282
+
1283
def _validate_and_add_requests(
    self,
    prompts: Union[PromptType, Sequence[PromptType]],
    params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                  Sequence[PoolingParams]],
    lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
    prompt_adapter_request: Optional[PromptAdapterRequest],
    guided_options: Optional[GuidedDecodingRequest] = None,
    priority: Optional[List[int]] = None,
) -> None:
    """Validate batched inputs and enqueue one engine request per prompt.

    ``params``, ``lora_request`` and ``priority`` may each be a single
    value (applied to every prompt) or a per-prompt sequence of the same
    length as ``prompts``.
    """
    if guided_options is not None:
        warnings.warn(
            "guided_options_request is deprecated, use "
            "SamplingParams.guided_decoding instead",
            DeprecationWarning,
            stacklevel=2,
        )

    if isinstance(prompts, (str, dict)):
        # Convert a single prompt to a list.
        prompts = [prompts]

    num_requests = len(prompts)
    if isinstance(params, list) and len(params) != num_requests:
        raise ValueError("The lengths of prompts and params "
                         "must be the same.")
    if isinstance(lora_request,
                  list) and len(lora_request) != num_requests:
        raise ValueError("The lengths of prompts and lora_request "
                         "must be the same.")

    # NOTE: this mutates the caller-supplied SamplingParams objects in
    # place (guided decoding merge + output kind).
    for sp in params if isinstance(params, list) else (params, ):
        if isinstance(sp, SamplingParams):
            self._add_guided_params(sp, guided_options)

            # We only care about the final output
            sp.output_kind = RequestOutputKind.FINAL_ONLY

    # Add requests to the engine.
    for i, prompt in enumerate(prompts):
        self._add_request(
            prompt,
            params[i] if isinstance(params, Sequence) else params,
            lora_request=lora_request[i] if isinstance(
                lora_request, Sequence) else lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority[i] if priority else 0,
        )
1331
+
1332
+ def _add_request(
1333
+ self,
1334
+ prompt: PromptType,
1335
+ params: Union[SamplingParams, PoolingParams],
1336
+ lora_request: Optional[LoRARequest] = None,
1337
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
1338
+ priority: int = 0,
1339
+ ) -> None:
1340
+ request_id = str(next(self.request_counter))
1341
+ self.llm_engine.add_request(
1342
+ request_id,
1343
+ prompt,
1344
+ params,
1345
+ lora_request=lora_request,
1346
+ prompt_adapter_request=prompt_adapter_request,
1347
+ priority=priority,
1348
+ )
1349
+
1350
+ def _add_guided_params(
1351
+ self,
1352
+ params: SamplingParams,
1353
+ guided_options: Optional[GuidedDecodingRequest] = None):
1354
+ if guided_options is None:
1355
+ return params
1356
+
1357
+ if params.guided_decoding is not None:
1358
+ raise ValueError("Cannot set both guided_options_request and"
1359
+ "params.guided_decoding.")
1360
+
1361
+ params.guided_decoding = GuidedDecodingParams(
1362
+ json=guided_options.guided_json,
1363
+ regex=guided_options.guided_regex,
1364
+ choice=guided_options.guided_choice,
1365
+ grammar=guided_options.guided_grammar,
1366
+ json_object=guided_options.guided_json_object,
1367
+ backend=guided_options.guided_decoding_backend,
1368
+ whitespace_pattern=guided_options.guided_whitespace_pattern)
1369
+ return params
1370
+
1371
def _run_engine(
        self, *, use_tqdm: bool
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
    """Drive the engine step-by-step until all queued requests finish.

    Returns the finished outputs sorted by integer request id so results
    line up with the order requests were submitted.
    """
    # Initialize tqdm.
    if use_tqdm:
        num_requests = self.llm_engine.get_num_unfinished_requests()
        pbar = tqdm(
            total=num_requests,
            desc="Processed prompts",
            dynamic_ncols=True,
            postfix=(f"est. speed input: {0:.2f} toks/s, "
                     f"output: {0:.2f} toks/s"),
        )

    # Run the engine.
    outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
    total_in_toks = 0
    total_out_toks = 0
    while self.llm_engine.has_unfinished_requests():
        step_outputs = self.llm_engine.step()
        for output in step_outputs:
            if output.finished:
                outputs.append(output)
                if use_tqdm:
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        # (pooling outputs carry no generated tokens).
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        in_spd = total_in_toks / pbar.format_dict["elapsed"]
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        out_spd = (total_out_toks /
                                   pbar.format_dict["elapsed"])
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)

    if use_tqdm:
        pbar.close()
    # Sort the outputs by request ID.
    # This is necessary because some requests may be finished earlier than
    # its previous requests.
    return sorted(outputs, key=lambda x: int(x.request_id))
.venv/lib/python3.11/site-packages/vllm/entrypoints/logger.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import List, Optional, Union
4
+
5
+ from vllm.logger import init_logger
6
+ from vllm.lora.request import LoRARequest
7
+ from vllm.pooling_params import PoolingParams
8
+ from vllm.prompt_adapter.request import PromptAdapterRequest
9
+ from vllm.sampling_params import BeamSearchParams, SamplingParams
10
+
11
+ logger = init_logger(__name__)
12
+
13
+
14
class RequestLogger:
    """Logs a summary of each incoming request.

    When ``max_log_len`` is set, prompt text and token ids are truncated
    to that length before logging so huge payloads don't flood the logs.
    """

    def __init__(self, *, max_log_len: Optional[int]) -> None:
        super().__init__()

        self.max_log_len = max_log_len

    def log_inputs(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]],
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Emit one INFO line describing the request's inputs."""
        limit = self.max_log_len
        if limit is not None:
            # Truncate potentially huge prompt payloads before logging.
            if prompt is not None:
                prompt = prompt[:limit]
            if prompt_token_ids is not None:
                prompt_token_ids = prompt_token_ids[:limit]

        logger.info(
            "Received request %s: prompt: %r, "
            "params: %s, prompt_token_ids: %s, "
            "lora_request: %s, prompt_adapter_request: %s.", request_id,
            prompt, params, prompt_token_ids, lora_request,
            prompt_adapter_request)
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/api_server.cpython-311.pyc ADDED
Binary file (45.9 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/cli_args.cpython-311.pyc ADDED
Binary file (13.5 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/logits_processors.cpython-311.pyc ADDED
Binary file (5.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/protocol.cpython-311.pyc ADDED
Binary file (68.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/run_batch.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_chat.cpython-311.pyc ADDED
Binary file (34.4 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_completion.cpython-311.pyc ADDED
Binary file (18.7 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_embedding.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_engine.cpython-311.pyc ADDED
Binary file (22.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_models.cpython-311.pyc ADDED
Binary file (13.8 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_pooling.cpython-311.pyc ADDED
Binary file (9.88 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_rerank.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_score.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/__pycache__/serving_tokenization.cpython-311.pyc ADDED
Binary file (5.98 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/api_server.py ADDED
@@ -0,0 +1,911 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ import atexit
5
+ import gc
6
+ import importlib
7
+ import inspect
8
+ import multiprocessing
9
+ import os
10
+ import re
11
+ import signal
12
+ import socket
13
+ import sys
14
+ import tempfile
15
+ import uuid
16
+ from argparse import Namespace
17
+ from contextlib import asynccontextmanager
18
+ from functools import partial
19
+ from http import HTTPStatus
20
+ from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
21
+
22
+ import uvloop
23
+ from fastapi import APIRouter, FastAPI, HTTPException, Request
24
+ from fastapi.exceptions import RequestValidationError
25
+ from fastapi.middleware.cors import CORSMiddleware
26
+ from fastapi.responses import JSONResponse, Response, StreamingResponse
27
+ from starlette.datastructures import State
28
+ from starlette.routing import Mount
29
+ from typing_extensions import assert_never
30
+
31
+ import vllm.envs as envs
32
+ from vllm.config import ModelConfig
33
+ from vllm.engine.arg_utils import AsyncEngineArgs
34
+ from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore
35
+ from vllm.engine.multiprocessing.client import MQLLMEngineClient
36
+ from vllm.engine.multiprocessing.engine import run_mp_engine
37
+ from vllm.engine.protocol import EngineClient
38
+ from vllm.entrypoints.chat_utils import load_chat_template
39
+ from vllm.entrypoints.launcher import serve_http
40
+ from vllm.entrypoints.logger import RequestLogger
41
+ from vllm.entrypoints.openai.cli_args import (make_arg_parser,
42
+ validate_parsed_serve_args)
43
+ # yapf conflicts with isort for this block
44
+ # yapf: disable
45
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
46
+ ChatCompletionResponse,
47
+ CompletionRequest,
48
+ CompletionResponse,
49
+ DetokenizeRequest,
50
+ DetokenizeResponse,
51
+ EmbeddingChatRequest,
52
+ EmbeddingCompletionRequest,
53
+ EmbeddingRequest,
54
+ EmbeddingResponse,
55
+ EmbeddingResponseData,
56
+ ErrorResponse,
57
+ LoadLoraAdapterRequest,
58
+ PoolingChatRequest,
59
+ PoolingCompletionRequest,
60
+ PoolingRequest, PoolingResponse,
61
+ RerankRequest, RerankResponse,
62
+ ScoreRequest, ScoreResponse,
63
+ TokenizeRequest,
64
+ TokenizeResponse,
65
+ UnloadLoraAdapterRequest)
66
+ from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
67
+ # yapf: enable
68
+ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
69
+ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
70
+ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
71
+ from vllm.entrypoints.openai.serving_engine import OpenAIServing
72
+ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
73
+ OpenAIServingModels)
74
+ from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
75
+ from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
76
+ from vllm.entrypoints.openai.serving_score import OpenAIServingScores
77
+ from vllm.entrypoints.openai.serving_tokenization import (
78
+ OpenAIServingTokenization)
79
+ from vllm.entrypoints.openai.tool_parsers import ToolParserManager
80
+ from vllm.entrypoints.utils import with_cancellation
81
+ from vllm.logger import init_logger
82
+ from vllm.usage.usage_lib import UsageContext
83
+ from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
84
+ is_valid_ipv6_address, set_ulimit)
85
+ from vllm.version import __version__ as VLLM_VERSION
86
+
87
+ TIMEOUT_KEEP_ALIVE = 5 # seconds
88
+
89
+ prometheus_multiproc_dir: tempfile.TemporaryDirectory
90
+
91
+ # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
92
+ logger = init_logger('vllm.entrypoints.openai.api_server')
93
+
94
+ _running_tasks: Set[asyncio.Task] = set()
95
+
96
+
97
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: optional stats-logging task + GC tuning.

    Starts a background task that periodically logs engine stats (when
    enabled), freezes the startup heap to reduce GC pauses, and tears
    everything down — including the app state holding the engine ref —
    on shutdown.
    """
    try:
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                # Periodically flush engine stats to the logger.
                while True:
                    await asyncio.sleep(10.)
                    await engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            # Keep a strong reference so the task isn't garbage-collected
            # mid-flight; drop it automatically once the task finishes.
            _running_tasks.add(task)
            task.add_done_callback(_running_tasks.remove)
        else:
            task = None

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        gc.collect()
        gc.freeze()
        try:
            yield
        finally:
            if task is not None:
                task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
126
+
127
+
128
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[EngineClient]:
    """Build an :class:`EngineClient` from parsed CLI arguments.

    Context manager to handle engine_client lifecycle; ensures everything
    is shut down and cleaned up on error/exit.
    """
    engine_args = AsyncEngineArgs.from_cli_args(args)

    async with build_async_engine_client_from_engine_args(
            engine_args, args.disable_frontend_multiprocessing) as engine:
        yield engine
139
+
140
+
141
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[EngineClient]:
    """
    Create EngineClient, either:
        - in-process using the AsyncLLMEngine Directly
        - multiprocess using AsyncLLMEngine RPC

    Returns the Client or None if the creation failed.
    """

    # AsyncLLMEngine.
    if (MQLLMEngineClient.is_unsupported_config(engine_args)
            or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):

        engine_client: Optional[EngineClient] = None
        try:
            engine_client = AsyncLLMEngine.from_engine_args(
                engine_args=engine_args,
                usage_context=UsageContext.OPENAI_API_SERVER)
            yield engine_client
        finally:
            if engine_client and hasattr(engine_client, "shutdown"):
                engine_client.shutdown()

    # MQLLMEngine.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            # cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between vLLM runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and vLLM will properly handle cleanup.")

        # Select random path for IPC.
        ipc_path = get_open_zmq_ipc_path()
        logger.debug("Multiprocessing frontend to use %s for IPC Path.",
                     ipc_path)

        # Start RPCServer in separate process (holds the LLMEngine).
        # the current process might have CUDA context,
        # so we need to spawn a new process
        context = multiprocessing.get_context("spawn")

        # The Process can raise an exception during startup, which may
        # not actually result in an exitcode being reported. As a result
        # we use a shared variable to communicate the information.
        engine_alive = multiprocessing.Value('b', True, lock=False)
        engine_process = context.Process(target=run_mp_engine,
                                         args=(engine_args,
                                               UsageContext.OPENAI_API_SERVER,
                                               ipc_path, engine_alive))
        engine_process.start()
        engine_pid = engine_process.pid
        assert engine_pid is not None, "Engine process failed to start."
        logger.info("Started engine process with PID %d", engine_pid)

        def _cleanup_ipc_path():
            socket_path = ipc_path.replace("ipc://", "")
            if os.path.exists(socket_path):
                os.remove(socket_path)

        # Ensure we clean up the local IPC socket file on exit.
        atexit.register(_cleanup_ipc_path)

        # Build RPCClient, which conforms to EngineClient Protocol.
        engine_config = engine_args.create_engine_config()
        # Client construction can block, so run it in the default executor
        # to keep the event loop responsive.
        build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
                               engine_pid)
        mq_engine_client = await asyncio.get_running_loop().run_in_executor(
            None, build_client)
        try:
            # Retry setup until the engine comes up, but bail out as soon
            # as the engine process is observed dead.
            while True:
                try:
                    await mq_engine_client.setup()
                    break
                except TimeoutError:
                    if (not engine_process.is_alive()
                            or not engine_alive.value):
                        raise RuntimeError(
                            "Engine process failed to start. See stack "
                            "trace for the root cause.") from None

            yield mq_engine_client  # type: ignore[misc]
        finally:
            # Ensure rpc server process was terminated
            engine_process.terminate()

            # Close all open connections to the backend
            mq_engine_client.close()

            # Wait for engine process to join
            engine_process.join(4)
            if engine_process.exitcode is None:
                # Kill if taking longer than 5 seconds to stop
                engine_process.kill()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(engine_process.pid)
254
+
255
+
256
+ router = APIRouter()
257
+
258
+
259
def mount_metrics(app: FastAPI):
    """Mount a Prometheus ``/metrics`` ASGI app on *app*.

    Uses the multiprocess collector when PROMETHEUS_MULTIPROC_DIR is set.
    """
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is not None:
        logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                     multiproc_dir)
        # Aggregate metrics across worker processes.
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        metrics_app = make_asgi_app(registry=registry)
    else:
        metrics_app = make_asgi_app()

    metrics_route = Mount("/metrics", metrics_app)
    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
283
+
284
+
285
+ def base(request: Request) -> OpenAIServing:
286
+ # Reuse the existing instance
287
+ return tokenization(request)
288
+
289
+
290
+ def models(request: Request) -> OpenAIServingModels:
291
+ return request.app.state.openai_serving_models
292
+
293
+
294
+ def chat(request: Request) -> Optional[OpenAIServingChat]:
295
+ return request.app.state.openai_serving_chat
296
+
297
+
298
+ def completion(request: Request) -> Optional[OpenAIServingCompletion]:
299
+ return request.app.state.openai_serving_completion
300
+
301
+
302
+ def pooling(request: Request) -> Optional[OpenAIServingPooling]:
303
+ return request.app.state.openai_serving_pooling
304
+
305
+
306
+ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
307
+ return request.app.state.openai_serving_embedding
308
+
309
+
310
+ def score(request: Request) -> Optional[OpenAIServingScores]:
311
+ return request.app.state.openai_serving_scores
312
+
313
+
314
+ def rerank(request: Request) -> Optional[JinaAIServingRerank]:
315
+ return request.app.state.jinaai_serving_reranking
316
+
317
+
318
+ def tokenization(request: Request) -> OpenAIServingTokenization:
319
+ return request.app.state.openai_serving_tokenization
320
+
321
+
322
+ def engine_client(request: Request) -> EngineClient:
323
+ return request.app.state.engine_client
324
+
325
+
326
+ @router.get("/health")
327
+ async def health(raw_request: Request) -> Response:
328
+ """Health check."""
329
+ await engine_client(raw_request).check_health()
330
+ return Response(status_code=200)
331
+
332
+
333
+ @router.api_route("/ping", methods=["GET", "POST"])
334
+ async def ping(raw_request: Request) -> Response:
335
+ """Ping check. Endpoint required for SageMaker"""
336
+ return await health(raw_request)
337
+
338
+
339
+ @router.post("/tokenize")
340
+ @with_cancellation
341
+ async def tokenize(request: TokenizeRequest, raw_request: Request):
342
+ handler = tokenization(raw_request)
343
+
344
+ generator = await handler.create_tokenize(request, raw_request)
345
+ if isinstance(generator, ErrorResponse):
346
+ return JSONResponse(content=generator.model_dump(),
347
+ status_code=generator.code)
348
+ elif isinstance(generator, TokenizeResponse):
349
+ return JSONResponse(content=generator.model_dump())
350
+
351
+ assert_never(generator)
352
+
353
+
354
+ @router.post("/detokenize")
355
+ @with_cancellation
356
+ async def detokenize(request: DetokenizeRequest, raw_request: Request):
357
+ handler = tokenization(raw_request)
358
+
359
+ generator = await handler.create_detokenize(request, raw_request)
360
+ if isinstance(generator, ErrorResponse):
361
+ return JSONResponse(content=generator.model_dump(),
362
+ status_code=generator.code)
363
+ elif isinstance(generator, DetokenizeResponse):
364
+ return JSONResponse(content=generator.model_dump())
365
+
366
+ assert_never(generator)
367
+
368
+
369
+ @router.get("/v1/models")
370
+ async def show_available_models(raw_request: Request):
371
+ handler = models(raw_request)
372
+
373
+ models_ = await handler.show_available_models()
374
+ return JSONResponse(content=models_.model_dump())
375
+
376
+
377
+ @router.get("/version")
378
+ async def show_version():
379
+ ver = {"version": VLLM_VERSION}
380
+ return JSONResponse(content=ver)
381
+
382
+
383
+ @router.post("/v1/chat/completions")
384
+ @with_cancellation
385
+ async def create_chat_completion(request: ChatCompletionRequest,
386
+ raw_request: Request):
387
+ handler = chat(raw_request)
388
+ if handler is None:
389
+ return base(raw_request).create_error_response(
390
+ message="The model does not support Chat Completions API")
391
+
392
+ generator = await handler.create_chat_completion(request, raw_request)
393
+
394
+ if isinstance(generator, ErrorResponse):
395
+ return JSONResponse(content=generator.model_dump(),
396
+ status_code=generator.code)
397
+
398
+ elif isinstance(generator, ChatCompletionResponse):
399
+ return JSONResponse(content=generator.model_dump())
400
+
401
+ return StreamingResponse(content=generator, media_type="text/event-stream")
402
+
403
+
404
+ @router.post("/v1/completions")
405
+ @with_cancellation
406
+ async def create_completion(request: CompletionRequest, raw_request: Request):
407
+ handler = completion(raw_request)
408
+ if handler is None:
409
+ return base(raw_request).create_error_response(
410
+ message="The model does not support Completions API")
411
+
412
+ generator = await handler.create_completion(request, raw_request)
413
+ if isinstance(generator, ErrorResponse):
414
+ return JSONResponse(content=generator.model_dump(),
415
+ status_code=generator.code)
416
+ elif isinstance(generator, CompletionResponse):
417
+ return JSONResponse(content=generator.model_dump())
418
+
419
+ return StreamingResponse(content=generator, media_type="text/event-stream")
420
+
421
+
422
+ @router.post("/v1/embeddings")
423
+ @with_cancellation
424
+ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
425
+ handler = embedding(raw_request)
426
+ if handler is None:
427
+ fallback_handler = pooling(raw_request)
428
+ if fallback_handler is None:
429
+ return base(raw_request).create_error_response(
430
+ message="The model does not support Embeddings API")
431
+
432
+ logger.warning(
433
+ "Embeddings API will become exclusive to embedding models "
434
+ "in a future release. To return the hidden states directly, "
435
+ "use the Pooling API (`/pooling`) instead.")
436
+
437
+ res = await fallback_handler.create_pooling(request, raw_request)
438
+
439
+ generator: Union[ErrorResponse, EmbeddingResponse]
440
+ if isinstance(res, PoolingResponse):
441
+ generator = EmbeddingResponse(
442
+ id=res.id,
443
+ object=res.object,
444
+ created=res.created,
445
+ model=res.model,
446
+ data=[
447
+ EmbeddingResponseData(
448
+ index=d.index,
449
+ embedding=d.data, # type: ignore
450
+ ) for d in res.data
451
+ ],
452
+ usage=res.usage,
453
+ )
454
+ else:
455
+ generator = res
456
+ else:
457
+ generator = await handler.create_embedding(request, raw_request)
458
+
459
+ if isinstance(generator, ErrorResponse):
460
+ return JSONResponse(content=generator.model_dump(),
461
+ status_code=generator.code)
462
+ elif isinstance(generator, EmbeddingResponse):
463
+ return JSONResponse(content=generator.model_dump())
464
+
465
+ assert_never(generator)
466
+
467
+
468
+ @router.post("/pooling")
469
+ @with_cancellation
470
+ async def create_pooling(request: PoolingRequest, raw_request: Request):
471
+ handler = pooling(raw_request)
472
+ if handler is None:
473
+ return base(raw_request).create_error_response(
474
+ message="The model does not support Pooling API")
475
+
476
+ generator = await handler.create_pooling(request, raw_request)
477
+ if isinstance(generator, ErrorResponse):
478
+ return JSONResponse(content=generator.model_dump(),
479
+ status_code=generator.code)
480
+ elif isinstance(generator, PoolingResponse):
481
+ return JSONResponse(content=generator.model_dump())
482
+
483
+ assert_never(generator)
484
+
485
+
486
+ @router.post("/score")
487
+ @with_cancellation
488
+ async def create_score(request: ScoreRequest, raw_request: Request):
489
+ handler = score(raw_request)
490
+ if handler is None:
491
+ return base(raw_request).create_error_response(
492
+ message="The model does not support Score API")
493
+
494
+ generator = await handler.create_score(request, raw_request)
495
+ if isinstance(generator, ErrorResponse):
496
+ return JSONResponse(content=generator.model_dump(),
497
+ status_code=generator.code)
498
+ elif isinstance(generator, ScoreResponse):
499
+ return JSONResponse(content=generator.model_dump())
500
+
501
+ assert_never(generator)
502
+
503
+
504
+ @router.post("/v1/score")
505
+ @with_cancellation
506
+ async def create_score_v1(request: ScoreRequest, raw_request: Request):
507
+ logger.warning(
508
+ "To indicate that Score API is not part of standard OpenAI API, we "
509
+ "have moved it to `/score`. Please update your client accordingly.")
510
+
511
+ return await create_score(request, raw_request)
512
+
513
+
514
+ @router.post("/rerank")
515
+ @with_cancellation
516
+ async def do_rerank(request: RerankRequest, raw_request: Request):
517
+ handler = rerank(raw_request)
518
+ if handler is None:
519
+ return base(raw_request).create_error_response(
520
+ message="The model does not support Rerank (Score) API")
521
+ generator = await handler.do_rerank(request, raw_request)
522
+ if isinstance(generator, ErrorResponse):
523
+ return JSONResponse(content=generator.model_dump(),
524
+ status_code=generator.code)
525
+ elif isinstance(generator, RerankResponse):
526
+ return JSONResponse(content=generator.model_dump())
527
+
528
+ assert_never(generator)
529
+
530
+
531
+ @router.post("/v1/rerank")
532
+ @with_cancellation
533
+ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
534
+ logger.warning_once(
535
+ "To indicate that the rerank API is not part of the standard OpenAI"
536
+ " API, we have located it at `/rerank`. Please update your client"
537
+ "accordingly. (Note: Conforms to JinaAI rerank API)")
538
+
539
+ return await do_rerank(request, raw_request)
540
+
541
+
542
+ @router.post("/v2/rerank")
543
+ @with_cancellation
544
+ async def do_rerank_v2(request: RerankRequest, raw_request: Request):
545
+ return await do_rerank(request, raw_request)
546
+
547
+
548
+ TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
549
+ "generate": {
550
+ "messages": (ChatCompletionRequest, create_chat_completion),
551
+ "default": (CompletionRequest, create_completion),
552
+ },
553
+ "embed": {
554
+ "messages": (EmbeddingChatRequest, create_embedding),
555
+ "default": (EmbeddingCompletionRequest, create_embedding),
556
+ },
557
+ "score": {
558
+ "default": (RerankRequest, do_rerank)
559
+ },
560
+ "rerank": {
561
+ "default": (RerankRequest, do_rerank)
562
+ },
563
+ "reward": {
564
+ "messages": (PoolingChatRequest, create_pooling),
565
+ "default": (PoolingCompletionRequest, create_pooling),
566
+ },
567
+ "classify": {
568
+ "messages": (PoolingChatRequest, create_pooling),
569
+ "default": (PoolingCompletionRequest, create_pooling),
570
+ },
571
+ }
572
+
573
+ if envs.VLLM_SERVER_DEV_MODE:
574
+
575
+ @router.post("/reset_prefix_cache")
576
+ async def reset_prefix_cache(raw_request: Request):
577
+ """
578
+ Reset the prefix cache. Note that we currently do not check if the
579
+ prefix cache is successfully reset in the API server.
580
+ """
581
+ logger.info("Resetting prefix cache...")
582
+ await engine_client(raw_request).reset_prefix_cache()
583
+ return Response(status_code=200)
584
+
585
+
586
+ @router.post("/invocations")
587
+ async def invocations(raw_request: Request):
588
+ """
589
+ For SageMaker, routes requests to other handlers based on model `task`.
590
+ """
591
+ body = await raw_request.json()
592
+ task = raw_request.app.state.task
593
+
594
+ if task not in TASK_HANDLERS:
595
+ raise HTTPException(
596
+ status_code=400,
597
+ detail=f"Unsupported task: '{task}' for '/invocations'. "
598
+ f"Expected one of {set(TASK_HANDLERS.keys())}")
599
+
600
+ handler_config = TASK_HANDLERS[task]
601
+ if "messages" in body:
602
+ request_model, handler = handler_config["messages"]
603
+ else:
604
+ request_model, handler = handler_config["default"]
605
+
606
+ # this is required since we lose the FastAPI automatic casting
607
+ request = request_model.model_validate(body)
608
+ return await handler(request, raw_request)
609
+
610
+
611
+ if envs.VLLM_TORCH_PROFILER_DIR:
612
+ logger.warning(
613
+ "Torch Profiler is enabled in the API server. This should ONLY be "
614
+ "used for local development!")
615
+
616
+ @router.post("/start_profile")
617
+ async def start_profile(raw_request: Request):
618
+ logger.info("Starting profiler...")
619
+ await engine_client(raw_request).start_profile()
620
+ logger.info("Profiler started.")
621
+ return Response(status_code=200)
622
+
623
+ @router.post("/stop_profile")
624
+ async def stop_profile(raw_request: Request):
625
+ logger.info("Stopping profiler...")
626
+ await engine_client(raw_request).stop_profile()
627
+ logger.info("Profiler stopped.")
628
+ return Response(status_code=200)
629
+
630
+
631
+ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
632
+ logger.warning(
633
+ "Lora dynamic loading & unloading is enabled in the API server. "
634
+ "This should ONLY be used for local development!")
635
+
636
+ @router.post("/v1/load_lora_adapter")
637
+ async def load_lora_adapter(request: LoadLoraAdapterRequest,
638
+ raw_request: Request):
639
+ handler = models(raw_request)
640
+ response = await handler.load_lora_adapter(request)
641
+ if isinstance(response, ErrorResponse):
642
+ return JSONResponse(content=response.model_dump(),
643
+ status_code=response.code)
644
+
645
+ return Response(status_code=200, content=response)
646
+
647
+ @router.post("/v1/unload_lora_adapter")
648
+ async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
649
+ raw_request: Request):
650
+ handler = models(raw_request)
651
+ response = await handler.unload_lora_adapter(request)
652
+ if isinstance(response, ErrorResponse):
653
+ return JSONResponse(content=response.model_dump(),
654
+ status_code=response.code)
655
+
656
+ return Response(status_code=200, content=response)
657
+
658
+
659
+ def build_app(args: Namespace) -> FastAPI:
660
+ if args.disable_fastapi_docs:
661
+ app = FastAPI(openapi_url=None,
662
+ docs_url=None,
663
+ redoc_url=None,
664
+ lifespan=lifespan)
665
+ else:
666
+ app = FastAPI(lifespan=lifespan)
667
+ app.include_router(router)
668
+ app.root_path = args.root_path
669
+
670
+ mount_metrics(app)
671
+
672
+ app.add_middleware(
673
+ CORSMiddleware,
674
+ allow_origins=args.allowed_origins,
675
+ allow_credentials=args.allow_credentials,
676
+ allow_methods=args.allowed_methods,
677
+ allow_headers=args.allowed_headers,
678
+ )
679
+
680
+ @app.exception_handler(RequestValidationError)
681
+ async def validation_exception_handler(_, exc):
682
+ err = ErrorResponse(message=str(exc),
683
+ type="BadRequestError",
684
+ code=HTTPStatus.BAD_REQUEST)
685
+ return JSONResponse(err.model_dump(),
686
+ status_code=HTTPStatus.BAD_REQUEST)
687
+
688
+ if token := envs.VLLM_API_KEY or args.api_key:
689
+
690
+ @app.middleware("http")
691
+ async def authentication(request: Request, call_next):
692
+ if request.method == "OPTIONS":
693
+ return await call_next(request)
694
+ url_path = request.url.path
695
+ if app.root_path and url_path.startswith(app.root_path):
696
+ url_path = url_path[len(app.root_path):]
697
+ if not url_path.startswith("/v1"):
698
+ return await call_next(request)
699
+ if request.headers.get("Authorization") != "Bearer " + token:
700
+ return JSONResponse(content={"error": "Unauthorized"},
701
+ status_code=401)
702
+ return await call_next(request)
703
+
704
+ if args.enable_request_id_headers:
705
+ logger.warning(
706
+ "CAUTION: Enabling X-Request-Id headers in the API Server. "
707
+ "This can harm performance at high QPS.")
708
+
709
+ @app.middleware("http")
710
+ async def add_request_id(request: Request, call_next):
711
+ request_id = request.headers.get(
712
+ "X-Request-Id") or uuid.uuid4().hex
713
+ response = await call_next(request)
714
+ response.headers["X-Request-Id"] = request_id
715
+ return response
716
+
717
+ for middleware in args.middleware:
718
+ module_path, object_name = middleware.rsplit(".", 1)
719
+ imported = getattr(importlib.import_module(module_path), object_name)
720
+ if inspect.isclass(imported):
721
+ app.add_middleware(imported) # type: ignore[arg-type]
722
+ elif inspect.iscoroutinefunction(imported):
723
+ app.middleware("http")(imported)
724
+ else:
725
+ raise ValueError(f"Invalid middleware {middleware}. "
726
+ f"Must be a function or a class.")
727
+
728
+ return app
729
+
730
+
731
+ async def init_app_state(
732
+ engine_client: EngineClient,
733
+ model_config: ModelConfig,
734
+ state: State,
735
+ args: Namespace,
736
+ ) -> None:
737
+ if args.served_model_name is not None:
738
+ served_model_names = args.served_model_name
739
+ else:
740
+ served_model_names = [args.model]
741
+
742
+ if args.disable_log_requests:
743
+ request_logger = None
744
+ else:
745
+ request_logger = RequestLogger(max_log_len=args.max_log_len)
746
+
747
+ base_model_paths = [
748
+ BaseModelPath(name=name, model_path=args.model)
749
+ for name in served_model_names
750
+ ]
751
+
752
+ state.engine_client = engine_client
753
+ state.log_stats = not args.disable_log_stats
754
+
755
+ resolved_chat_template = load_chat_template(args.chat_template)
756
+ logger.info("Using supplied chat template:\n%s", resolved_chat_template)
757
+
758
+ state.openai_serving_models = OpenAIServingModels(
759
+ engine_client=engine_client,
760
+ model_config=model_config,
761
+ base_model_paths=base_model_paths,
762
+ lora_modules=args.lora_modules,
763
+ prompt_adapters=args.prompt_adapters,
764
+ )
765
+ await state.openai_serving_models.init_static_loras()
766
+ state.openai_serving_chat = OpenAIServingChat(
767
+ engine_client,
768
+ model_config,
769
+ state.openai_serving_models,
770
+ args.response_role,
771
+ request_logger=request_logger,
772
+ chat_template=resolved_chat_template,
773
+ chat_template_content_format=args.chat_template_content_format,
774
+ return_tokens_as_token_ids=args.return_tokens_as_token_ids,
775
+ enable_auto_tools=args.enable_auto_tool_choice,
776
+ tool_parser=args.tool_call_parser,
777
+ enable_reasoning=args.enable_reasoning,
778
+ reasoning_parser=args.reasoning_parser,
779
+ enable_prompt_tokens_details=args.enable_prompt_tokens_details,
780
+ ) if model_config.runner_type == "generate" else None
781
+ state.openai_serving_completion = OpenAIServingCompletion(
782
+ engine_client,
783
+ model_config,
784
+ state.openai_serving_models,
785
+ request_logger=request_logger,
786
+ return_tokens_as_token_ids=args.return_tokens_as_token_ids,
787
+ ) if model_config.runner_type == "generate" else None
788
+ state.openai_serving_pooling = OpenAIServingPooling(
789
+ engine_client,
790
+ model_config,
791
+ state.openai_serving_models,
792
+ request_logger=request_logger,
793
+ chat_template=resolved_chat_template,
794
+ chat_template_content_format=args.chat_template_content_format,
795
+ ) if model_config.runner_type == "pooling" else None
796
+ state.openai_serving_embedding = OpenAIServingEmbedding(
797
+ engine_client,
798
+ model_config,
799
+ state.openai_serving_models,
800
+ request_logger=request_logger,
801
+ chat_template=resolved_chat_template,
802
+ chat_template_content_format=args.chat_template_content_format,
803
+ ) if model_config.task == "embed" else None
804
+ state.openai_serving_scores = OpenAIServingScores(
805
+ engine_client,
806
+ model_config,
807
+ state.openai_serving_models,
808
+ request_logger=request_logger
809
+ ) if model_config.task == "score" else None
810
+ state.jinaai_serving_reranking = JinaAIServingRerank(
811
+ engine_client,
812
+ model_config,
813
+ state.openai_serving_models,
814
+ request_logger=request_logger
815
+ ) if model_config.task == "score" else None
816
+ state.openai_serving_tokenization = OpenAIServingTokenization(
817
+ engine_client,
818
+ model_config,
819
+ state.openai_serving_models,
820
+ request_logger=request_logger,
821
+ chat_template=resolved_chat_template,
822
+ chat_template_content_format=args.chat_template_content_format,
823
+ )
824
+ state.task = model_config.task
825
+
826
+
827
+ def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
828
+ family = socket.AF_INET
829
+ if is_valid_ipv6_address(addr[0]):
830
+ family = socket.AF_INET6
831
+
832
+ sock = socket.socket(family=family, type=socket.SOCK_STREAM)
833
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
834
+ sock.bind(addr)
835
+
836
+ return sock
837
+
838
+
839
+ async def run_server(args, **uvicorn_kwargs) -> None:
840
+ logger.info("vLLM API server version %s", VLLM_VERSION)
841
+ logger.info("args: %s", args)
842
+
843
+ if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
844
+ ToolParserManager.import_tool_parser(args.tool_parser_plugin)
845
+
846
+ valid_tool_parses = ToolParserManager.tool_parsers.keys()
847
+ if args.enable_auto_tool_choice \
848
+ and args.tool_call_parser not in valid_tool_parses:
849
+ raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
850
+ f"(chose from {{ {','.join(valid_tool_parses)} }})")
851
+
852
+ valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
853
+ if args.enable_reasoning \
854
+ and args.reasoning_parser not in valid_reasoning_parses:
855
+ raise KeyError(
856
+ f"invalid reasoning parser: {args.reasoning_parser} "
857
+ f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
858
+
859
+ # workaround to make sure that we bind the port before the engine is set up.
860
+ # This avoids race conditions with ray.
861
+ # see https://github.com/vllm-project/vllm/issues/8204
862
+ sock_addr = (args.host or "", args.port)
863
+ sock = create_server_socket(sock_addr)
864
+
865
+ # workaround to avoid footguns where uvicorn drops requests with too
866
+ # many concurrent requests active
867
+ set_ulimit()
868
+
869
+ def signal_handler(*_) -> None:
870
+ # Interrupt server on sigterm while initializing
871
+ raise KeyboardInterrupt("terminated")
872
+
873
+ signal.signal(signal.SIGTERM, signal_handler)
874
+
875
+ async with build_async_engine_client(args) as engine_client:
876
+ app = build_app(args)
877
+
878
+ model_config = await engine_client.get_model_config()
879
+ await init_app_state(engine_client, model_config, app.state, args)
880
+
881
+ shutdown_task = await serve_http(
882
+ app,
883
+ host=args.host,
884
+ port=args.port,
885
+ log_level=args.uvicorn_log_level,
886
+ timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
887
+ ssl_keyfile=args.ssl_keyfile,
888
+ ssl_certfile=args.ssl_certfile,
889
+ ssl_ca_certs=args.ssl_ca_certs,
890
+ ssl_cert_reqs=args.ssl_cert_reqs,
891
+ # Workaround to work on macOS
892
+ fd=sock.fileno() if sys.platform.startswith("darwin") else None,
893
+ **uvicorn_kwargs,
894
+ )
895
+
896
+ # NB: Await server shutdown only after the backend context is exited
897
+ await shutdown_task
898
+
899
+ sock.close()
900
+
901
+
902
+ if __name__ == "__main__":
903
+ # NOTE(simon):
904
+ # This section should be in sync with vllm/scripts.py for CLI entrypoints.
905
+ parser = FlexibleArgumentParser(
906
+ description="vLLM OpenAI-Compatible RESTful API server.")
907
+ parser = make_arg_parser(parser)
908
+ args = parser.parse_args()
909
+ validate_parsed_serve_args(args)
910
+
911
+ uvloop.run(run_server(args))
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/cli_args.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ This file contains the command line arguments for the vLLM's
4
+ OpenAI-compatible server. It is kept in a separate file for documentation
5
+ purposes.
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+ import ssl
11
+ from typing import List, Optional, Sequence, Union, get_args
12
+
13
+ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
14
+ from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
15
+ validate_chat_template)
16
+ from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
17
+ from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
18
+ PromptAdapterPath)
19
+ from vllm.entrypoints.openai.tool_parsers import ToolParserManager
20
+ from vllm.utils import FlexibleArgumentParser
21
+
22
+
23
+ class LoRAParserAction(argparse.Action):
24
+
25
+ def __call__(
26
+ self,
27
+ parser: argparse.ArgumentParser,
28
+ namespace: argparse.Namespace,
29
+ values: Optional[Union[str, Sequence[str]]],
30
+ option_string: Optional[str] = None,
31
+ ):
32
+ if values is None:
33
+ values = []
34
+ if isinstance(values, str):
35
+ raise TypeError("Expected values to be a list")
36
+
37
+ lora_list: List[LoRAModulePath] = []
38
+ for item in values:
39
+ if item in [None, '']: # Skip if item is None or empty string
40
+ continue
41
+ if '=' in item and ',' not in item: # Old format: name=path
42
+ name, path = item.split('=')
43
+ lora_list.append(LoRAModulePath(name, path))
44
+ else: # Assume JSON format
45
+ try:
46
+ lora_dict = json.loads(item)
47
+ lora = LoRAModulePath(**lora_dict)
48
+ lora_list.append(lora)
49
+ except json.JSONDecodeError:
50
+ parser.error(
51
+ f"Invalid JSON format for --lora-modules: {item}")
52
+ except TypeError as e:
53
+ parser.error(
54
+ f"Invalid fields for --lora-modules: {item} - {str(e)}"
55
+ )
56
+ setattr(namespace, self.dest, lora_list)
57
+
58
+
59
+ class PromptAdapterParserAction(argparse.Action):
60
+
61
+ def __call__(
62
+ self,
63
+ parser: argparse.ArgumentParser,
64
+ namespace: argparse.Namespace,
65
+ values: Optional[Union[str, Sequence[str]]],
66
+ option_string: Optional[str] = None,
67
+ ):
68
+ if values is None:
69
+ values = []
70
+ if isinstance(values, str):
71
+ raise TypeError("Expected values to be a list")
72
+
73
+ adapter_list: List[PromptAdapterPath] = []
74
+ for item in values:
75
+ name, path = item.split('=')
76
+ adapter_list.append(PromptAdapterPath(name, path))
77
+ setattr(namespace, self.dest, adapter_list)
78
+
79
+
80
+ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
81
+ parser.add_argument("--host",
82
+ type=nullable_str,
83
+ default=None,
84
+ help="Host name.")
85
+ parser.add_argument("--port", type=int, default=8000, help="Port number.")
86
+ parser.add_argument(
87
+ "--uvicorn-log-level",
88
+ type=str,
89
+ default="info",
90
+ choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
91
+ help="Log level for uvicorn.")
92
+ parser.add_argument("--allow-credentials",
93
+ action="store_true",
94
+ help="Allow credentials.")
95
+ parser.add_argument("--allowed-origins",
96
+ type=json.loads,
97
+ default=["*"],
98
+ help="Allowed origins.")
99
+ parser.add_argument("--allowed-methods",
100
+ type=json.loads,
101
+ default=["*"],
102
+ help="Allowed methods.")
103
+ parser.add_argument("--allowed-headers",
104
+ type=json.loads,
105
+ default=["*"],
106
+ help="Allowed headers.")
107
+ parser.add_argument("--api-key",
108
+ type=nullable_str,
109
+ default=None,
110
+ help="If provided, the server will require this key "
111
+ "to be presented in the header.")
112
+ parser.add_argument(
113
+ "--lora-modules",
114
+ type=nullable_str,
115
+ default=None,
116
+ nargs='+',
117
+ action=LoRAParserAction,
118
+ help="LoRA module configurations in either 'name=path' format"
119
+ "or JSON format. "
120
+ "Example (old format): ``'name=path'`` "
121
+ "Example (new format): "
122
+ "``{\"name\": \"name\", \"path\": \"lora_path\", "
123
+ "\"base_model_name\": \"id\"}``")
124
+ parser.add_argument(
125
+ "--prompt-adapters",
126
+ type=nullable_str,
127
+ default=None,
128
+ nargs='+',
129
+ action=PromptAdapterParserAction,
130
+ help="Prompt adapter configurations in the format name=path. "
131
+ "Multiple adapters can be specified.")
132
+ parser.add_argument("--chat-template",
133
+ type=nullable_str,
134
+ default=None,
135
+ help="The file path to the chat template, "
136
+ "or the template in single-line form "
137
+ "for the specified model.")
138
+ parser.add_argument(
139
+ '--chat-template-content-format',
140
+ type=str,
141
+ default="auto",
142
+ choices=get_args(ChatTemplateContentFormatOption),
143
+ help='The format to render message content within a chat template.'
144
+ '\n\n'
145
+ '* "string" will render the content as a string. '
146
+ 'Example: ``"Hello World"``\n'
147
+ '* "openai" will render the content as a list of dictionaries, '
148
+ 'similar to OpenAI schema. '
149
+ 'Example: ``[{"type": "text", "text": "Hello world!"}]``')
150
+ parser.add_argument("--response-role",
151
+ type=nullable_str,
152
+ default="assistant",
153
+ help="The role name to return if "
154
+ "``request.add_generation_prompt=true``.")
155
+ parser.add_argument("--ssl-keyfile",
156
+ type=nullable_str,
157
+ default=None,
158
+ help="The file path to the SSL key file.")
159
+ parser.add_argument("--ssl-certfile",
160
+ type=nullable_str,
161
+ default=None,
162
+ help="The file path to the SSL cert file.")
163
+ parser.add_argument("--ssl-ca-certs",
164
+ type=nullable_str,
165
+ default=None,
166
+ help="The CA certificates file.")
167
+ parser.add_argument(
168
+ "--ssl-cert-reqs",
169
+ type=int,
170
+ default=int(ssl.CERT_NONE),
171
+ help="Whether client certificate is required (see stdlib ssl module's)."
172
+ )
173
+ parser.add_argument(
174
+ "--root-path",
175
+ type=nullable_str,
176
+ default=None,
177
+ help="FastAPI root_path when app is behind a path based routing proxy."
178
+ )
179
+ parser.add_argument(
180
+ "--middleware",
181
+ type=nullable_str,
182
+ action="append",
183
+ default=[],
184
+ help="Additional ASGI middleware to apply to the app. "
185
+ "We accept multiple --middleware arguments. "
186
+ "The value should be an import path. "
187
+ "If a function is provided, vLLM will add it to the server "
188
+ "using ``@app.middleware('http')``. "
189
+ "If a class is provided, vLLM will add it to the server "
190
+ "using ``app.add_middleware()``. ")
191
+ parser.add_argument(
192
+ "--return-tokens-as-token-ids",
193
+ action="store_true",
194
+ help="When ``--max-logprobs`` is specified, represents single tokens "
195
+ " as strings of the form 'token_id:{token_id}' so that tokens "
196
+ "that are not JSON-encodable can be identified.")
197
+ parser.add_argument(
198
+ "--disable-frontend-multiprocessing",
199
+ action="store_true",
200
+ help="If specified, will run the OpenAI frontend server in the same "
201
+ "process as the model serving engine.")
202
+ parser.add_argument(
203
+ "--enable-request-id-headers",
204
+ action="store_true",
205
+ help="If specified, API server will add X-Request-Id header to "
206
+ "responses. Caution: this hurts performance at high QPS.")
207
+ parser.add_argument(
208
+ "--enable-auto-tool-choice",
209
+ action="store_true",
210
+ default=False,
211
+ help="Enable auto tool choice for supported models. Use "
212
+ "``--tool-call-parser`` to specify which parser to use.")
213
+ parser.add_argument(
214
+ "--enable-reasoning",
215
+ action="store_true",
216
+ default=False,
217
+ help="Whether to enable reasoning_content for the model. "
218
+ "If enabled, the model will be able to generate reasoning content.")
219
+
220
+ valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
221
+ parser.add_argument(
222
+ "--reasoning-parser",
223
+ type=str,
224
+ metavar="{" + ",".join(valid_reasoning_parsers) + "}",
225
+ default=None,
226
+ help=
227
+ "Select the reasoning parser depending on the model that you're using."
228
+ " This is used to parse the reasoning content into OpenAI API "
229
+ "format. Required for ``--enable-reasoning``.")
230
+
231
+ valid_tool_parsers = ToolParserManager.tool_parsers.keys()
232
+ parser.add_argument(
233
+ "--tool-call-parser",
234
+ type=str,
235
+ metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
236
+ "--tool-parser-plugin",
237
+ default=None,
238
+ help=
239
+ "Select the tool call parser depending on the model that you're using."
240
+ " This is used to parse the model-generated tool call into OpenAI API "
241
+ "format. Required for ``--enable-auto-tool-choice``.")
242
+
243
+ parser.add_argument(
244
+ "--tool-parser-plugin",
245
+ type=str,
246
+ default="",
247
+ help=
248
+ "Special the tool parser plugin write to parse the model-generated tool"
249
+ " into OpenAI API format, the name register in this plugin can be used "
250
+ "in ``--tool-call-parser``.")
251
+
252
+ parser = AsyncEngineArgs.add_cli_args(parser)
253
+
254
+ parser.add_argument('--max-log-len',
255
+ type=int,
256
+ default=None,
257
+ help='Max number of prompt characters or prompt '
258
+ 'ID numbers being printed in log.'
259
+ '\n\nDefault: Unlimited')
260
+
261
+ parser.add_argument(
262
+ "--disable-fastapi-docs",
263
+ action='store_true',
264
+ default=False,
265
+ help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."
266
+ )
267
+ parser.add_argument(
268
+ "--enable-prompt-tokens-details",
269
+ action='store_true',
270
+ default=False,
271
+ help="If set to True, enable prompt_tokens_details in usage.")
272
+
273
+ return parser
274
+
275
+
276
+ def validate_parsed_serve_args(args: argparse.Namespace):
277
+ """Quick checks for model serve args that raise prior to loading."""
278
+ if hasattr(args, "subparser") and args.subparser != "serve":
279
+ return
280
+
281
+ # Ensure that the chat template is valid; raises if it likely isn't
282
+ validate_chat_template(args.chat_template)
283
+
284
+ # Enable auto tool needs a tool call parser to be valid
285
+ if args.enable_auto_tool_choice and not args.tool_call_parser:
286
+ raise TypeError("Error: --enable-auto-tool-choice requires "
287
+ "--tool-call-parser")
288
+
289
+ # Enable reasoning needs a reasoning parser to be valid
290
+ if args.enable_reasoning and not args.reasoning_parser:
291
+ raise TypeError("Error: --enable-reasoning requires "
292
+ "--reasoning-parser")
293
+
294
+ # Ref https://api-docs.deepseek.com/guides/reasoning_model
295
+ # tool call and reasoning cannot be enabled at the same time.
296
+ if args.enable_auto_tool_choice and args.enable_reasoning:
297
+ raise TypeError(
298
+ "Error: --enable-auto-tool-choice and "
299
+ "--enable-reasoning cannot be enabled at the same time")
300
+
301
+
302
+ def create_parser_for_docs() -> FlexibleArgumentParser:
303
+ parser_for_docs = FlexibleArgumentParser(
304
+ prog="-m vllm.entrypoints.openai.api_server")
305
+ return make_arg_parser(parser_for_docs)
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/logits_processors.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from functools import lru_cache, partial
4
+ from typing import Dict, FrozenSet, Iterable, List, Optional, Union
5
+
6
+ import torch
7
+
8
+ from vllm.sampling_params import LogitsProcessor
9
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
10
+
11
+
12
class AllowedTokenIdsLogitsProcessor:
    """Logits processor for constraining generated tokens to a
    specific set of token ids."""

    def __init__(self, allowed_ids: Iterable[int]):
        # The id list is only needed until the mask is built on first call;
        # it is dropped afterwards so just the boolean mask is retained.
        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
        self.mask: Optional[torch.Tensor] = None

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        mask = self.mask
        if mask is None:
            # Lazily build the mask on the logits' own device: True marks
            # positions to suppress, allowed ids are cleared to False.
            mask = torch.ones((logits.shape[-1], ),
                              dtype=torch.bool,
                              device=logits.device)
            mask[self.allowed_ids] = False
            self.mask = mask
            self.allowed_ids = None
        # In-place fill; the (possibly modified) logits tensor is returned.
        return logits.masked_fill_(self.mask, float("-inf"))
30
+
31
+
32
+ @lru_cache(maxsize=32)
33
+ def _get_allowed_token_ids_logits_processor(
34
+ allowed_token_ids: FrozenSet[int],
35
+ vocab_size: int,
36
+ ) -> LogitsProcessor:
37
+ if not allowed_token_ids:
38
+ raise ValueError("Empty allowed_token_ids provided")
39
+ if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
40
+ raise ValueError("allowed_token_ids contains "
41
+ "out-of-vocab token id")
42
+ return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
43
+
44
+
45
+ def logit_bias_logits_processor(
46
+ logit_bias: Dict[int, float],
47
+ token_ids: List[int],
48
+ logits: torch.Tensor,
49
+ ) -> torch.Tensor:
50
+ for token_id, bias in logit_bias.items():
51
+ logits[token_id] += bias
52
+ return logits
53
+
54
+
55
+ def get_logits_processors(
56
+ logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
57
+ allowed_token_ids: Optional[List[int]],
58
+ tokenizer: AnyTokenizer,
59
+ ) -> List[LogitsProcessor]:
60
+ logits_processors: List[LogitsProcessor] = []
61
+ if logit_bias:
62
+ try:
63
+ # Convert token_id to integer
64
+ # Clamp the bias between -100 and 100 per OpenAI API spec
65
+ clamped_logit_bias: Dict[int, float] = {
66
+ int(token_id): min(100.0, max(-100.0, bias))
67
+ for token_id, bias in logit_bias.items()
68
+ }
69
+ except ValueError as exc:
70
+ raise ValueError(
71
+ "Found token_id in logit_bias that is not "
72
+ "an integer or string representing an integer") from exc
73
+
74
+ # Check if token_id is within the vocab size
75
+ for token_id, bias in clamped_logit_bias.items():
76
+ if token_id < 0 or token_id >= len(tokenizer):
77
+ raise ValueError(f"token_id {token_id} in logit_bias contains "
78
+ "out-of-vocab token id")
79
+
80
+ logits_processors.append(
81
+ partial(logit_bias_logits_processor, clamped_logit_bias))
82
+
83
+ if allowed_token_ids is not None:
84
+ logits_processors.append(
85
+ _get_allowed_token_ids_logits_processor(
86
+ frozenset(allowed_token_ids), len(tokenizer)))
87
+
88
+ return logits_processors
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/protocol.py ADDED
@@ -0,0 +1,1428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
5
+ import re
6
+ import time
7
+ from argparse import Namespace
8
+ from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
9
+
10
+ import torch
11
+ from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
12
+ ValidationInfo, field_validator, model_validator)
13
+ from typing_extensions import Annotated
14
+
15
+ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
16
+ from vllm.logger import init_logger
17
+ from vllm.pooling_params import PoolingParams
18
+ from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
19
+ RequestOutputKind, SamplingParams)
20
+ from vllm.sequence import Logprob
21
+ from vllm.utils import random_uuid, resolve_obj_by_qualname
22
+
23
+ logger = init_logger(__name__)
24
+
25
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]

try:
    # Sphinx replaces heavy deps (torch) with mock modules during docs
    # builds; detect that and fall back to the literal int64 limits.
    from sphinx.ext.autodoc.mock import _MockModule

    if isinstance(torch, _MockModule):
        _LONG_INFO = _MOCK_LONG_INFO
    else:
        _LONG_INFO = torch.iinfo(torch.long)
except ModuleNotFoundError:
    # Sphinx not installed: normal runtime, use the real torch limits.
    _LONG_INFO = torch.iinfo(torch.long)

# Sanity check: the hard-coded mock values must match torch's int64 range.
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
42
+
43
+
44
class OpenAIBaseModel(BaseModel):
    """Base model for all OpenAI-compatible request/response schemas.

    Accepts (but warns about) fields not declared on the model, matching
    the OpenAI API's tolerance for extra fields.
    """

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Cache class field names (including aliases), built lazily on first
    # validation; one cache per subclass via the ClassVar.
    field_names: ClassVar[Optional[Set[str]]] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        # Run normal validation first; only dict payloads are inspected
        # for unknown keys (other input shapes pass through untouched).
        result = handler(data)
        if not isinstance(data, dict):
            return result
        field_names = cls.field_names
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, 'alias', None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        if any(k not in field_names for k in data):
            logger.warning(
                "The following fields were present in the request "
                "but ignored: %s",
                data.keys() - field_names)
        return result
74
+
75
+
76
# Error payload returned by the OpenAI-compatible endpoints.
class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int
82
+
83
+
84
# Permission entry attached to a ModelCard (OpenAI "model_permission").
class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False
97
+
98
+
99
# Single entry in the /v1/models listing.
class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)
108
+
109
+
110
# Response body for /v1/models.
class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)
113
+
114
+
115
# Breakdown of prompt-token usage (e.g. prefix-cache hits).
class PromptTokenUsageInfo(OpenAIBaseModel):
    cached_tokens: Optional[int] = None
117
+
118
+
119
# Token accounting attached to completion/chat responses.
class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
124
+
125
+
126
# Internal carrier (plain BaseModel, not OpenAI-shaped) for per-request
# metadata such as the final usage totals.
class RequestResponseMetadata(BaseModel):
    request_id: str
    final_usage_info: Optional[UsageInfo] = None
129
+
130
+
131
# The "json_schema" variant of response_format.
class JsonSchemaResponseFormat(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
    strict: Optional[bool] = None
138
+
139
+
140
# OpenAI response_format parameter.
class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object" or "text"
    type: Literal["text", "json_object", "json_schema"]
    # Only consulted when type == "json_schema".
    json_schema: Optional[JsonSchemaResponseFormat] = None
144
+
145
+
146
# Options controlling usage reporting while streaming.
class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = False
149
+
150
+
151
# A callable tool's function signature (name + JSON-schema parameters).
class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None
155
+
156
+
157
# One entry of the request's `tools` list.
class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition
160
+
161
+
162
# Function reference inside a named tool_choice.
class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str
164
+
165
+
166
# tool_choice form that forces a specific named tool.
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"
169
+
170
+
171
# JSON constructor spec for a logits processor: qualified name plus
# optional positional/keyword arguments.
class LogitsProcessorConstructor(BaseModel):
    qualname: str
    args: Optional[List[Any]] = None
    kwargs: Optional[Dict[str, Any]] = None
175
+
176
+
177
# Request-level list of processors: each item is either a bare qualified
# name or a full constructor spec.
LogitsProcessors = List[Union[str, LogitsProcessorConstructor]]
178
+
179
+
180
def get_logits_processors(processors: Optional[LogitsProcessors],
                          pattern: Optional[str]) -> Optional[List[Any]]:
    """Resolve request-supplied logits processors against the server policy.

    Args:
        processors: Names or constructor specs from the request, or None.
        pattern: Regex (from --logits-processor-pattern) that every
            qualified name must match; None disables the feature.

    Returns:
        The resolved (and, for constructor specs, instantiated) processors,
        or None when none were requested.

    Raises:
        ValueError: If processors were requested but the feature is
            disabled, a name fails the pattern, or resolution fails.
    """
    if processors and pattern:
        logits_processors = []
        for processor in processors:
            qualname = processor if isinstance(processor,
                                               str) else processor.qualname
            if not re.match(pattern, qualname):
                raise ValueError(
                    f"Logits processor '{qualname}' is not allowed by this "
                    "server. See --logits-processor-pattern engine argument "
                    "for more information.")
            try:
                logits_processor = resolve_obj_by_qualname(qualname)
            except Exception as e:
                raise ValueError(
                    f"Logits processor '{qualname}' could not be resolved: {e}"
                ) from e
            # A constructor spec names a factory: call it with the
            # supplied args/kwargs to get the actual processor.
            if isinstance(processor, LogitsProcessorConstructor):
                logits_processor = logits_processor(*processor.args or [],
                                                    **processor.kwargs or {})
            logits_processors.append(logits_processor)
        return logits_processors
    elif processors:
        # Fix: error message previously misspelled "argument" as "argugment".
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information.")
    return None
209
+
210
+
211
# Request body for /v1/chat/completions, plus vLLM-specific extensions.
# (No class docstring: pydantic would surface it as the schema description.)
class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    # TODO(#9845): remove max_tokens when field is removed from OpenAI API
    max_tokens: Optional[int] = Field(
        default=None,
        deprecated=
        'max_tokens is deprecated in favor of the max_completion_tokens field')
    max_completion_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                ChatCompletionNamedToolChoiceParam]] = "none"

    # NOTE this will be ignored by VLLM -- the model determines the behavior
    parallel_tool_calls: Optional[bool] = False
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: bool = False
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    length_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: bool = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."))
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."))

    # doc: end-chat-completion-extra-params

    # Default sampling parameters for chat completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": -1,
        "min_p": 0.0,
    }

    def to_beam_search_params(
            self,
            default_max_tokens: int,
            default_sampling_params: Optional[dict] = None
    ) -> BeamSearchParams:
        """Convert this request into BeamSearchParams.

        default_max_tokens is the server-side cap (context window minus
        prompt); default_sampling_params carries server-configured defaults.
        """
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        max_tokens = self.max_completion_tokens or self.max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        # Use minimum of context window, user request & server limit.
        max_tokens = min(
            val for val in (default_max_tokens, max_tokens,
                            default_sampling_params.get("max_tokens", None))
            if val is not None)

        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output)

    def to_sampling_params(
            self,
            default_max_tokens: int,
            logits_processor_pattern: Optional[str],
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        """Convert this request into SamplingParams.

        Resolution order for each knob: request value, then server default
        (default_sampling_params), then _DEFAULT_SAMPLING_PARAMS.
        NOTE: may mutate self (guided_json / guided_decoding_backend) when
        response_format is "json_schema".
        """
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        max_tokens = self.max_completion_tokens or self.max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Use minimum of context window, user request & server limit.
        max_tokens = min(
            val for val in (default_max_tokens, max_tokens,
                            default_sampling_params.get("max_tokens", None))
            if val is not None)

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])

        # Echoing the prompt implies returning prompt logprobs too.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        guided_json_object = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                json_schema = self.response_format.json_schema
                assert json_schema is not None
                self.guided_json = json_schema.json_schema
                if self.guided_decoding_backend is None:
                    self.guided_decoding_backend = "xgrammar"

        # A named tool choice takes precedence over explicit guided_json.
        guided_decoding = GuidedDecodingParams.from_optional(
            json=self._get_guided_json_from_tool() or self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern)

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(self.logits_processors,
                                                    logits_processor_pattern),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA if self.stream \
                else RequestOutputKind.FINAL_ONLY,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias)

    def _get_guided_json_from_tool(
            self) -> Optional[Union[str, dict, BaseModel]]:
        """Return the JSON schema of the named tool, if one was forced.

        Returns None unless tool_choice is a named-tool object; raises
        ValueError if the named tool is not in `tools`.
        """
        # user has chosen to not use any tool
        if self.tool_choice == "none" or self.tools is None:
            return None

        # user has chosen to use a named tool
        if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
            tool_name = self.tool_choice.function.name
            tools = {tool.function.name: tool.function for tool in self.tools}
            if tool_name not in tools:
                raise ValueError(
                    f"Tool '{tool_name}' has not been passed in `tools`.")
            tool = tools[tool_name]
            return tool.parameters

        return None

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        # stream_options is only meaningful while streaming.
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        # Validate prompt_logprobs / top_logprobs combinations up front.
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (top_logprobs := data.get("top_logprobs")) is not None:
            if top_logprobs < 0:
                raise ValueError("`top_logprobs` must be a positive value.")

            if not data.get("logprobs"):
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        # NOTE(review): a ValueError instance may be passed through here by
        # upstream parsing; re-raise it — confirm against callers.
        if isinstance(data, ValueError):
            raise data

        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 1 and data.get("tool_choice",
                                        "none") not in ("none", "auto"):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):

        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is "none" -- ignore tools if present
        if "tool_choice" in data and data["tool_choice"] == "none":
            # ensure that no tools are present
            data.pop("tools", None)
            return data

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data:

            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto"
            if data["tool_choice"] != "auto" and not isinstance(
                    data["tool_choice"], dict):
                raise ValueError(
                    "`tool_choice` must either be a named tool, \"auto\", "
                    "or \"none\".")

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            if isinstance(data["tool_choice"], dict):
                valid_tool = False
                specified_function = data["tool_choice"].get("function")
                if not specified_function:
                    raise ValueError(
                        "Expected field `function` in `tool_choice`."
                        " Correct usage: `{\"type\": \"function\","
                        " \"function\": {\"name\": \"my_function\"}}`")
                specified_function_name = specified_function.get("name")
                if not specified_function_name:
                    raise ValueError(
                        "Expected field `name` in `function` in `tool_choice`."
                        "Correct usage: `{\"type\": \"function\", "
                        "\"function\": {\"name\": \"my_function\"}}`")
                for tool in data["tools"]:
                    if tool["function"]["name"] == specified_function_name:
                        valid_tool = True
                        break
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        # These two templating modes are mutually exclusive.
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data
635
+
636
+
637
class CompletionRequest(OpenAIBaseModel):
    """Request body for the legacy OpenAI-compatible /v1/completions endpoint.

    Fields are ordered to mirror the official OpenAI API documentation;
    vLLM-specific sampling and extra parameters follow in the marked
    ``doc:`` sections below.
    """
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    # Prompt may be raw text, a list of texts, or pre-tokenized id lists.
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    # Seed is constrained to the machine's signed-long range.
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: bool = False
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    length_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
         "{'type': 'text' } is supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description="If specified, the output will follow the JSON schema.",
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."))

    # doc: end-completion-extra-params

    # Default sampling parameters for completion requests
    # (used only when both the request and the server defaults leave a
    # parameter unset).
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": -1,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        default_max_tokens: int,
        default_sampling_params: Optional[dict] = None
    ) -> BeamSearchParams:
        """Convert this request into :class:`BeamSearchParams`.

        ``default_max_tokens`` is the server-side cap (context window minus
        prompt length); ``default_sampling_params`` holds server-configured
        defaults that fill in any unset request fields.
        """
        max_tokens = self.max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        # Use minimum of context window, user request & server limit.
        max_tokens = min(
            val for val in (default_max_tokens, max_tokens,
                            default_sampling_params.get("max_tokens", None))
            if val is not None)

        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get("temperature", 1.0)

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output)

    def to_sampling_params(
            self,
            default_max_tokens: int,
            logits_processor_pattern: Optional[str],
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        """Convert this request into :class:`SamplingParams`.

        Resolution order for each sampling knob: request value, then
        server default (``default_sampling_params``), then the class-level
        ``_DEFAULT_SAMPLING_PARAMS`` fallback.
        """
        max_tokens = self.max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Use minimum of context window, user request & server limit.
        max_tokens = min(
            val for val in (default_max_tokens, max_tokens,
                            default_sampling_params.get("max_tokens", None))
            if val is not None)

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])

        # `echo` implies the client wants logprobs for the prompt too.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        # echo with max_tokens == 0 means "return only the prompt"; a
        # max_tokens of 1 is still passed below because 0 is not accepted.
        echo_without_generation = self.echo and self.max_tokens == 0

        guided_json_object = None
        if (self.response_format is not None
                and self.response_format.type == "json_object"):
            guided_json_object = True

        guided_decoding = GuidedDecodingParams.from_optional(
            json=self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern)

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(self.logits_processors,
                                                    logits_processor_pattern),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA if self.stream \
                else RequestOutputKind.FINAL_ONLY,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids)

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        # At most one guided-decoding mode may be active per request.
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        # prompt_logprobs cannot be streamed, and counts must be >= 0.
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        # stream_options only makes sense for streaming responses.
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data
901
+
902
class EmbeddingCompletionRequest(OpenAIBaseModel):
    """Request body for /v1/embeddings with completion-style (raw) input."""
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    # Raw text, list of texts, or pre-tokenized id lists.
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    # doc: begin-embedding-extra-params
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-embedding-extra-params

    def to_pooling_params(self):
        """Build the PoolingParams passed to the pooling engine."""
        return PoolingParams(additional_data=self.additional_data)
935
+
936
class EmbeddingChatRequest(OpenAIBaseModel):
    """Request body for /v1/embeddings with chat-style message input."""
    model: str
    messages: List[ChatCompletionMessageParam]

    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-chat-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-chat-embedding-pooling-params

    # doc: begin-chat-embedding-extra-params
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))
    # doc: end-chat-embedding-extra-params

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        # NOTE(review): this model declares neither `continue_final_message`
        # nor `add_generation_prompt` as fields, so this check only fires when
        # callers pass those keys explicitly — confirm this is intentional.
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data

    def to_pooling_params(self):
        """Build the PoolingParams passed to the pooling engine."""
        return PoolingParams(additional_data=self.additional_data)
991
+
992
+
993
# Either form of embedding request is accepted by the serving layer.
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]

# Pooling requests reuse the embedding request schemas verbatim.
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
998
+
999
+
1000
class ScoreRequest(OpenAIBaseModel):
    """Request body for the /v1/score endpoint (cross-encoder scoring)."""
    model: str
    # text_1 and text_2 are paired; each may be a single string or a list.
    text_1: Union[List[str], str]
    text_2: Union[List[str], str]
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-score-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-score-pooling-params

    # doc: begin-score-extra-params
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-score-extra-params

    def to_pooling_params(self):
        """Build the PoolingParams passed to the pooling engine."""
        return PoolingParams(additional_data=self.additional_data)
1022
+
1023
+
1024
class RerankRequest(OpenAIBaseModel):
    """Request body for the rerank endpoint: score `documents` against `query`."""
    model: str
    query: str
    documents: List[str]
    # 0 means "return all documents". `default=0` replaces the original
    # `default_factory=lambda: 0`: a factory is only needed for mutable
    # defaults, and an immutable int default is behaviorally identical.
    top_n: int = Field(default=0)
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-rerank-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-rerank-pooling-params

    # doc: begin-rerank-extra-params
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-rerank-extra-params

    def to_pooling_params(self):
        """Build the PoolingParams passed to the pooling engine."""
        return PoolingParams(additional_data=self.additional_data)
1047
+
1048
+
1049
class RerankDocument(BaseModel):
    # The original document text echoed back in the response.
    text: str


class RerankResult(BaseModel):
    # Index of the document in the request's `documents` list.
    index: int
    document: RerankDocument
    relevance_score: float


class RerankUsage(BaseModel):
    # Total tokens consumed scoring all query/document pairs.
    total_tokens: int


class RerankResponse(OpenAIBaseModel):
    """Response body for the rerank endpoint."""
    id: str
    model: str
    usage: RerankUsage
    results: List[RerankResult]
1068
+
1069
+
1070
class CompletionLogProbs(OpenAIBaseModel):
    """Per-token logprob data in the OpenAI completions format.

    The four lists are parallel: entry i describes token i of the text.
    """
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    """One generated choice in a (non-streaming) completion response."""
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
1091
+
1092
+
1093
class CompletionResponse(OpenAIBaseModel):
    """Full (non-streaming) /v1/completions response envelope."""
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    """One incremental choice chunk in a streaming completion response."""
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    """Streaming /v1/completions chunk envelope (SSE payload)."""
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    # Only populated on the final chunk when stream_options request it.
    usage: Optional[UsageInfo] = Field(default=None)
1123
+
1124
+
1125
class EmbeddingResponseData(OpenAIBaseModel):
    """One embedding entry; `embedding` is floats or a base64 string."""
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    """Response envelope for /v1/embeddings."""
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class PoolingResponseData(OpenAIBaseModel):
    """One pooling entry; `data` may be a matrix, vector, or base64 string."""
    index: int
    object: str = "pooling"
    data: Union[List[List[float]], List[float], str]


class PoolingResponse(OpenAIBaseModel):
    """Response envelope for the pooling endpoint."""
    id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[PoolingResponseData]
    usage: UsageInfo


class ScoreResponseData(OpenAIBaseModel):
    """One score entry for a query/document (or text_1/text_2) pair."""
    index: int
    object: str = "score"
    score: float


class ScoreResponse(OpenAIBaseModel):
    """Response envelope for /v1/score."""
    # NOTE(review): the id prefix is "embd-" rather than e.g. "score-";
    # presumably copied from EmbeddingResponse — confirm before changing.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[ScoreResponseData]
    usage: UsageInfo
1168
+
1169
+
1170
class FunctionCall(OpenAIBaseModel):
    """A fully-formed function call: name plus JSON-encoded arguments."""
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    """A tool invocation emitted by the model (function calls only)."""
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class DeltaFunctionCall(BaseModel):
    """Partial function-call data for streaming; both fields optional."""
    name: Optional[str] = None
    arguments: Optional[str] = None


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    # Position of this tool call within the message's tool_calls list.
    index: int
    function: Optional[DeltaFunctionCall] = None


class ExtractedToolCallInformation(BaseModel):
    """Result of parsing tool calls out of raw model output."""
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: List[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: Optional[str] = None
1204
+
1205
+
1206
class ChatMessage(OpenAIBaseModel):
    """A complete chat message in a (non-streaming) chat response."""
    role: str
    # Populated when a reasoning parser separates chain-of-thought output.
    reasoning_content: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    """Logprob info for one token in the chat-completions format."""
    token: str
    # Sentinel used when no finite logprob is available.
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # The chosen token's data plus the top alternatives at that position.
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None
1225
+
1226
+
1227
class ChatCompletionResponseChoice(OpenAIBaseModel):
    """One generated choice in a (non-streaming) chat completion."""
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    """Full (non-streaming) /v1/chat/completions response envelope."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    """Incremental message fragment for streaming chat responses."""
    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: List[DeltaToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    """One incremental choice chunk in a streaming chat completion."""
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    """Streaming /v1/chat/completions chunk envelope (SSE payload)."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    # Only populated on the final chunk when stream_options request it.
    usage: Optional[UsageInfo] = Field(default=None)
1269
+
1270
+
1271
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]

    @field_validator('body', mode='plain')
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        """Validate `body` against the schema implied by `url`.

        Uses a 'plain' validator so pydantic does not attempt union
        resolution on its own; falls back to trying every schema for
        unrecognized URLs.
        """
        # Use url to disambiguate models
        url = info.data['url']
        if url == "/v1/chat/completions":
            return ChatCompletionRequest.model_validate(value)
        if url == "/v1/embeddings":
            return TypeAdapter(EmbeddingRequest).validate_python(value)
        if url == "/v1/score":
            return ScoreRequest.model_validate(value)
        return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
                                 ScoreRequest]).validate_python(value)
1306
+
1307
+
1308
class BatchResponseData(OpenAIBaseModel):
    """HTTP-level response data recorded for one batch request."""
    # HTTP status code of the response.
    status_code: int = 200

    # An unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
                         ScoreResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]
1336
+
1337
+
1338
class TokenizeCompletionRequest(OpenAIBaseModel):
    """Request body for /tokenize with a raw text prompt."""
    model: str
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
1348
+
1349
+
1350
class TokenizeChatRequest(OpenAIBaseModel):
    """Request body for /tokenize with chat-style message input."""
    model: str
    messages: List[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        # The two options are mutually exclusive: one appends a fresh
        # assistant turn, the other keeps the last message open-ended.
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data
1401
+
1402
+
1403
# Either form of tokenize request is accepted by the serving layer.
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    """Response body for /tokenize."""
    count: int
    max_model_len: int
    tokens: List[int]


class DetokenizeRequest(OpenAIBaseModel):
    """Request body for /detokenize."""
    model: str
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    """Response body for /detokenize."""
    prompt: str
1419
+
1420
+
1421
class LoadLoraAdapterRequest(BaseModel):
    """Request body for dynamically loading a LoRA adapter."""
    lora_name: str
    lora_path: str


class UnloadLoraAdapterRequest(BaseModel):
    """Request body for unloading a LoRA adapter by name or integer id."""
    lora_name: str
    lora_int_id: Optional[int] = Field(default=None)
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/run_batch.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ from http import HTTPStatus
5
+ from io import StringIO
6
+ from typing import Awaitable, Callable, List, Optional
7
+
8
+ import aiohttp
9
+ import torch
10
+ from prometheus_client import start_http_server
11
+ from tqdm import tqdm
12
+
13
+ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
14
+ from vllm.engine.async_llm_engine import AsyncLLMEngine
15
+ from vllm.entrypoints.logger import RequestLogger, logger
16
+ # yapf: disable
17
+ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
18
+ BatchRequestOutput,
19
+ BatchResponseData,
20
+ ChatCompletionResponse,
21
+ EmbeddingResponse, ErrorResponse,
22
+ ScoreResponse)
23
+ # yapf: enable
24
+ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
25
+ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
26
+ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
27
+ OpenAIServingModels)
28
+ from vllm.entrypoints.openai.serving_score import OpenAIServingScores
29
+ from vllm.usage.usage_lib import UsageContext
30
+ from vllm.utils import FlexibleArgumentParser, random_uuid
31
+ from vllm.version import __version__ as VLLM_VERSION
32
+
33
+
34
def parse_args():
    """Build and parse CLI arguments for the OpenAI-compatible batch runner.

    Combines runner-specific flags (I/O paths, response role, metrics
    options) with the full set of vLLM AsyncEngineArgs CLI options.
    """
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible batch runner.")
    parser.add_argument(
        "-i",
        "--input-file",
        required=True,
        type=str,
        help=
        "The path or url to a single input file. Currently supports local file "
        "paths, or the http protocol (http or https). If a URL is specified, "
        "the file should be available via HTTP GET.")
    parser.add_argument(
        "-o",
        "--output-file",
        required=True,
        type=str,
        help="The path or url to a single output file. Currently supports "
        "local file paths, or web (http or https) urls. If a URL is specified,"
        " the file should be available via HTTP PUT.")
    parser.add_argument("--response-role",
                        type=nullable_str,
                        default="assistant",
                        help="The role name to return if "
                        "`request.add_generation_prompt=True`.")

    # Inject all engine-level CLI options (model, parallelism, etc.).
    parser = AsyncEngineArgs.add_cli_args(parser)

    parser.add_argument('--max-log-len',
                        type=int,
                        default=None,
                        help='Max number of prompt characters or prompt '
                        'ID numbers being printed in log.'
                        '\n\nDefault: Unlimited')

    # Prometheus metrics server options (only used with --enable-metrics).
    parser.add_argument("--enable-metrics",
                        action="store_true",
                        help="Enable Prometheus metrics")
    parser.add_argument(
        "--url",
        type=str,
        default="0.0.0.0",
        help="URL to the Prometheus metrics server "
        "(only needed if enable-metrics is set).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port number for the Prometheus metrics server "
        "(only needed if enable-metrics is set).",
    )
    parser.add_argument(
        "--enable-prompt-tokens-details",
        action='store_true',
        default=False,
        help="If set to True, enable prompt_tokens_details in usage.")

    return parser.parse_args()
93
+
94
+
95
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501
100
+
101
+
102
class BatchProgressTracker:
    """Tracks how many batch requests were submitted/finished via a tqdm bar."""

    def __init__(self):
        # Number of requests registered via submitted().
        self._total = 0
        # Lazily created by pbar(); None until then.
        self._pbar: Optional[tqdm] = None

    def submitted(self):
        """Register one more request toward the progress-bar total."""
        self._total += 1

    def completed(self):
        """Advance the bar by one finished request (no-op before pbar())."""
        if self._pbar:
            self._pbar.update()

    def pbar(self) -> tqdm:
        """Create and return the progress bar.

        Under torch.distributed the bar is shown only on rank 0 so that
        multi-process runs don't print duplicate bars.
        """
        show_bar = (not torch.distributed.is_initialized()
                    or torch.distributed.get_rank() == 0)
        self._pbar = tqdm(
            total=self._total,
            unit="req",
            desc="Running batch",
            mininterval=5,
            disable=not show_bar,
            bar_format=_BAR_FORMAT,
        )
        return self._pbar
125
+
126
+
127
async def read_file(path_or_url: str) -> str:
    """Return the text contents of a local file or an HTTP(S) resource."""
    # Remote source: fetch the body with an HTTP GET.
    if path_or_url.startswith(("http://", "https://")):
        async with aiohttp.ClientSession() as session:
            async with session.get(path_or_url) as resp:
                return await resp.text()
    # Local source: a plain blocking read is acceptable because this runs
    # as a standalone batch program.
    with open(path_or_url, encoding="utf-8") as f:
        return f.read()
135
+
136
+
137
async def write_file(path_or_url: str, data: str) -> None:
    """Write *data* to a local file, or upload it via HTTP PUT for URLs."""
    if path_or_url.startswith(("http://", "https://")):
        # Remote destination: stream the UTF-8 payload with HTTP PUT.
        async with aiohttp.ClientSession() as session:
            async with session.put(path_or_url, data=data.encode("utf-8")):
                pass
        return
    # We should make this async, but as long as this is always run as a
    # standalone program, blocking the event loop won't effect performance
    # in this particular case.
    with open(path_or_url, "w", encoding="utf-8") as f:
        f.write(data)
148
+
149
+
150
def make_error_request_output(request: BatchRequestInput,
                              error_msg: str) -> BatchRequestOutput:
    """Build a BatchRequestOutput describing a failed batch request.

    The response carries HTTP 400 and fresh request ids; ``error_msg`` is
    surfaced in the output's ``error`` field.
    """
    return BatchRequestOutput(
        id=f"vllm-{random_uuid()}",
        custom_id=request.custom_id,
        response=BatchResponseData(
            status_code=HTTPStatus.BAD_REQUEST,
            request_id=f"vllm-batch-{random_uuid()}",
        ),
        error=error_msg,
    )
162
+
163
+
164
async def make_async_error_request_output(
    request: BatchRequestInput,
    error_msg: str,
) -> BatchRequestOutput:
    """Awaitable wrapper around :func:`make_error_request_output`.

    Lets error placeholders be gathered alongside real request coroutines.
    """
    return make_error_request_output(request, error_msg)
167
+
168
+
169
async def run_request(serving_engine_func: Callable,
                      request: BatchRequestInput,
                      tracker: BatchProgressTracker) -> BatchRequestOutput:
    """Invoke one serving handler on a batch request and wrap its outcome.

    Marks the request as completed on ``tracker`` before returning.
    """
    response = await serving_engine_func(request.body)

    if isinstance(response,
                  (ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
        # Successful, non-streaming response body.
        result = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                body=response, request_id=f"vllm-batch-{random_uuid()}"),
            error=None,
        )
    elif isinstance(response, ErrorResponse):
        # Propagate the handler's error together with its status code.
        result = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                status_code=response.code,
                request_id=f"vllm-batch-{random_uuid()}"),
            error=response,
        )
    else:
        # Any other return (e.g. an async generator) means the request was
        # submitted in stream mode, which the batch runner does not support.
        result = make_error_request_output(
            request, error_msg="Request must not be sent in stream mode")

    tracker.completed()
    return result
198
+
199
+
200
async def main(args):
    """Run the OpenAI-compatible batch job end to end.

    Builds the async engine and serving objects, reads the input JSONL,
    dispatches each request to the matching endpoint handler, and writes
    one JSON result per line to the output location.
    """
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)

    model_config = await engine.get_model_config()
    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    # Create the openai serving objects. Each serving object is only built
    # when the model actually supports that task; otherwise it stays None.
    openai_serving_models = OpenAIServingModels(
        engine_client=engine,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=None,
        prompt_adapters=None,
    )
    openai_serving_chat = OpenAIServingChat(
        engine,
        model_config,
        openai_serving_models,
        args.response_role,
        request_logger=request_logger,
        chat_template=None,
        chat_template_content_format="auto",
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    ) if model_config.runner_type == "generate" else None
    openai_serving_embedding = OpenAIServingEmbedding(
        engine,
        model_config,
        openai_serving_models,
        request_logger=request_logger,
        chat_template=None,
        chat_template_content_format="auto",
    ) if model_config.task == "embed" else None
    openai_serving_scores = (OpenAIServingScores(
        engine,
        model_config,
        openai_serving_models,
        request_logger=request_logger,
    ) if model_config.task == "score" else None)

    # Dispatch table: endpoint URL -> (handler or None, error message to emit
    # when the model does not support that endpoint).
    endpoint_handlers = {
        "/v1/chat/completions":
        (None if openai_serving_chat is None else
         openai_serving_chat.create_chat_completion,
         "The model does not support Chat Completions API"),
        "/v1/embeddings":
        (None if openai_serving_embedding is None else
         openai_serving_embedding.create_embedding,
         "The model does not support Embeddings API"),
        "/v1/score":
        (None if openai_serving_scores is None else
         openai_serving_scores.create_score,
         "The model does not support Scores API"),
    }

    tracker = BatchProgressTracker()
    logger.info("Reading batch from %s...", args.input_file)

    # Submit all requests in the file to the engine "concurrently".
    response_futures: List[Awaitable[BatchRequestOutput]] = []
    for request_json in (await read_file(args.input_file)).strip().split("\n"):
        # Skip empty lines.
        request_json = request_json.strip()
        if not request_json:
            continue

        request = BatchRequestInput.model_validate_json(request_json)

        if request.url not in endpoint_handlers:
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg=
                    "Only /v1/chat/completions, /v1/embeddings, and /v1/score "
                    "are supported in the batch endpoint.",
                ))
            continue

        handler_fn, unsupported_msg = endpoint_handlers[request.url]
        if handler_fn is None:
            response_futures.append(
                make_async_error_request_output(request,
                                                error_msg=unsupported_msg))
            continue

        response_futures.append(run_request(handler_fn, request, tracker))
        tracker.submitted()

    with tracker.pbar():
        responses = await asyncio.gather(*response_futures)

    output_buffer = StringIO()
    for response in responses:
        print(response.model_dump_json(), file=output_buffer)

    output_buffer.seek(0)
    await write_file(args.output_file, output_buffer.read().strip())
326
+
327
+
328
if __name__ == "__main__":
    args = parse_args()

    logger.info("vLLM batch processing API version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Start the Prometheus metrics server. LLMEngine uses the Prometheus client
    # to publish metrics at the /metrics endpoint.
    if args.enable_metrics:
        logger.info("Prometheus metrics enabled")
        # --url / --port only matter when metrics are enabled (see parse_args).
        start_http_server(port=args.port, addr=args.url)
    else:
        logger.info("Prometheus metrics disabled")

    # Drive the whole batch job on a fresh event loop.
    asyncio.run(main(args))
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_chat.py ADDED
@@ -0,0 +1,955 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ import json
5
+ import time
6
+ from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
7
+ Optional)
8
+ from typing import Sequence as GenericSequence
9
+ from typing import Union
10
+
11
+ from fastapi import Request
12
+
13
+ from vllm.config import ModelConfig
14
+ from vllm.engine.protocol import EngineClient
15
+ from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
16
+ ConversationMessage)
17
+ from vllm.entrypoints.logger import RequestLogger
18
+ from vllm.entrypoints.openai.protocol import (
19
+ ChatCompletionLogProb, ChatCompletionLogProbs,
20
+ ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
21
+ ChatCompletionRequest, ChatCompletionResponse,
22
+ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
23
+ ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
24
+ DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
25
+ RequestResponseMetadata, ToolCall, UsageInfo)
26
+ from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
27
+ ReasoningParserManager)
28
+ from vllm.entrypoints.openai.serving_engine import OpenAIServing
29
+ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
30
+ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
31
+ from vllm.logger import init_logger
32
+ from vllm.outputs import CompletionOutput, RequestOutput
33
+ from vllm.sampling_params import BeamSearchParams, SamplingParams
34
+ from vllm.sequence import Logprob
35
+ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
36
+ from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
37
+
38
+ logger = init_logger(__name__)
39
+
40
+
41
+ class OpenAIServingChat(OpenAIServing):
42
+
43
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        response_role: str,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        return_tokens_as_token_ids: bool = False,
        enable_reasoning: bool = False,
        reasoning_parser: Optional[str] = None,
        enable_auto_tools: bool = False,
        tool_parser: Optional[str] = None,
        enable_prompt_tokens_details: bool = False,
    ) -> None:
        """Configure the chat-completions serving layer.

        Resolves the reasoning parser and tool parser classes by name at
        construction time so misconfiguration fails fast (as TypeError)
        instead of at request time.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

        # set up tool use
        self.enable_auto_tools: bool = enable_auto_tools
        if self.enable_auto_tools:
            logger.info(
                "\"auto\" tool choice has been enabled please note that while"
                " the parallel_tool_calls client option is preset for "
                "compatibility reasons, it will be ignored.")

        self.enable_reasoning: bool = enable_reasoning
        # Factory (tokenizer -> ReasoningParser); instantiated per request.
        self.reasoning_parser: Optional[Callable[[AnyTokenizer],
                                                 ReasoningParser]] = None
        if self.enable_reasoning:
            try:
                self.reasoning_parser = (
                    ReasoningParserManager.get_reasoning_parser(
                        reasoning_parser))
            except Exception as e:
                raise TypeError("Error: --enable-reasoning requires "
                                f"reasoning_parser:'{reasoning_parser}' "
                                "which has not been registered") from e
        # Factory (tokenizer -> ToolParser); only set for "auto" tool choice.
        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
        if self.enable_auto_tools:
            try:
                if (tool_parser == "pythonic" and
                        model_config.model.startswith("meta-llama/Llama-3.2")):
                    logger.warning(
                        "Llama3.2 models may struggle to emit valid pythonic"
                        " tool calls")
                self.tool_parser = ToolParserManager.get_tool_parser(
                    tool_parser)
            except Exception as e:
                raise TypeError("Error: --enable-auto-tool-choice requires "
                                f"tool_parser:'{tool_parser}' which has not "
                                "been registered") from e

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        diff_sampling_param = self.model_config.get_diff_sampling_param()
        if diff_sampling_param:
            # Informational only: the overrides are re-fetched per request
            # in create_chat_completion.
            logger.info("Overwriting default chat sampling param with: %s",
                        diff_sampling_param)
110
+
111
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """
        Chat Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        Chat Completion API.

        Returns an SSE async generator when ``request.stream`` is set,
        otherwise a full ChatCompletionResponse, or an ErrorResponse on
        validation/preprocessing failure.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_name = self.models.model_name(lora_request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            tool_parser = self.tool_parser

            # validation for OpenAI tools
            # tool_choice = "required" is not supported
            if request.tool_choice == "required":
                return self.create_error_response(
                    "tool_choice = \"required\" is not supported!")

            # because of issues with pydantic we need to potentially
            # re-serialize the tool_calls field of the request
            # for more info: see comment in `maybe_serialize_tool_calls`
            if isinstance(tokenizer, MistralTokenizer):
                maybe_serialize_tool_calls(request)

            if (request.tool_choice == "auto" and
                    not (self.enable_auto_tools and tool_parser is not None)
                    and not isinstance(tokenizer, MistralTokenizer)):
                # for hf tokenizers, "auto" tools requires
                # --enable-auto-tool-choice and --tool-call-parser
                return self.create_error_response(
                    "\"auto\" tool choice requires "
                    "--enable-auto-tool-choice and --tool-call-parser to be set"
                )

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            # Render the chat template into tokenized engine prompts.
            (
                conversation,
                request_prompts,
                engine_prompts,
            ) = await self._preprocess_chat(
                request,
                tokenizer,
                request.messages,
                chat_template=request.chat_template or self.chat_template,
                chat_template_content_format=self.chat_template_content_format,
                add_generation_prompt=request.add_generation_prompt,
                continue_final_message=request.continue_final_message,
                tool_dicts=tool_dicts,
                documents=request.documents,
                chat_template_kwargs=request.chat_template_kwargs,
                tool_parser=tool_parser,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        request_id = "chatcmpl-" \
            f"{self._base_request_id(raw_request, request.request_id)}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            # Expose usage info to FastAPI middleware via request state.
            raw_request.state.request_metadata = request_metadata

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                # Cap generation at whatever context budget remains after
                # the prompt.
                default_max_tokens = self.max_model_len - len(
                    engine_prompt["prompt_token_ids"])
                # Build default sampling params
                default_sampling_params = (
                    self.model_config.get_diff_sampling_param())
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens, default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
                        default_max_tokens,
                        self.model_config.logits_processor_pattern,
                        default_sampling_params)

                self._log_inputs(request_id,
                                 request_prompts[i],
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        prompt_adapter_request=prompt_adapter_request,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # Chat preprocessing currently yields exactly one engine prompt.
        assert len(generators) == 1
        result_generator, = generators

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, model_name,
                conversation, tokenizer, request_metadata)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, model_name,
                conversation, tokenizer, request_metadata)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
269
+
270
+ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
271
+ if request.add_generation_prompt:
272
+ return self.response_role
273
+ return request.messages[-1]["role"]
274
+
275
    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Yield Server-Sent-Events chunks for a streaming chat completion.

        Emits: an initial role chunk per choice, optional echo of the last
        input message, per-token delta chunks (plain content, named tool
        call, "auto" tool calls, or reasoning content), a finish chunk per
        choice, an optional final usage chunk, and a terminating
        ``data: [DONE]`` sentinel. Errors are serialized into the stream
        rather than raised.
        """
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        num_prompt_tokens = 0
        num_cached_tokens = None

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
            and self._should_stream_with_auto_tool_parsing(request))

        should_stream_with_reasoning_parsing = (
            self._should_stream_with_reasoning_parsing(request))

        all_previous_token_ids: Optional[List[List[int]]]

        # Only one of these will be used, thus previous_texts and
        # all_previous_token_ids will not be used twice in the same iteration.
        if tool_choice_auto or should_stream_with_reasoning_parsing:
            # These are only required in "auto" tool choice case
            previous_texts = [""] * num_choices
            # NOTE(review): [[]] * n aliases one list, but entries are only
            # ever rebound (never mutated in place) below, so this is safe.
            all_previous_token_ids = [[]] * num_choices
        else:
            previous_texts, all_previous_token_ids = None, None

        try:
            # There is no need to check if the reasoning_parser is None
            # because the should_stream_with_reasoning_parsing check
            # already ensures that the reasoning_parser is not None.
            # but the pre-commit hook requires it.
            if should_stream_with_reasoning_parsing and \
                    self.reasoning_parser is not None:
                reasoning_parser = self.reasoning_parser(tokenizer)
        except RuntimeError as e:
            logger.exception("Error in reasoning parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

        # Prepare the tool parser if it's needed
        try:
            if tool_choice_auto and self.tool_parser:
                # NOTE(review): list-multiplication shares a single parser
                # instance across all n choices; looks intended for n == 1 —
                # confirm stateful parsers are safe when request.n > 1.
                tool_parsers: List[Optional[ToolParser]] = [
                    self.tool_parser(tokenizer)
                ] * num_choices
            else:
                tool_parsers = [None] * num_choices
        except Exception as e:
            logger.exception("Error in tool parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage
            include_continuous_usage = include_usage and \
                stream_options.continuous_usage_stats
        else:
            include_usage, include_continuous_usage = False, False

        try:
            async for res in result_generator:
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)
                    if res.encoder_prompt_token_ids is not None:
                        num_prompt_tokens += len(res.encoder_prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # if continuous usage stats are requested, add it
                        if include_continuous_usage:
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=0,
                                total_tokens=num_prompt_tokens)

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content: Union[str, List[Dict[str, str]]] = ""
                        if conversation and "content" in conversation[
                                -1] and conversation[-1].get("role") == role:
                            last_msg_content = conversation[-1]["content"] or ""

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if include_continuous_usage:
                                    chunk.usage = UsageInfo(
                                        prompt_tokens=num_prompt_tokens,
                                        completion_tokens=0,
                                        total_tokens=num_prompt_tokens)

                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index
                    tool_parser = tool_parsers[i]

                    if finish_reason_sent[i]:
                        continue

                    if request.logprobs and request.top_logprobs is not None:
                        assert output.logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text

                    if not delta_text and not output.token_ids and \
                            not previous_num_tokens[i]:
                        # Chunked prefill case, don't return empty chunks
                        continue

                    delta_message: Optional[DeltaMessage]

                    # handle streaming deltas for tools with named tool_choice
                    if tool_choice_function_name:
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(function=DeltaFunctionCall(
                                name=tool_choice_function_name,
                                arguments=delta_text),
                                          index=i)
                        ])

                    # handle streaming deltas for tools with "auto" tool choice
                    elif tool_choice_auto:
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        assert tool_parser is not None
                        #TODO optimize manipulation of these lists
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)

                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_text,
                                current_text=current_text,
                                delta_text=delta_text,
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
                                delta_token_ids=output.token_ids,
                                request=request))

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids
                    # reasoning_content cannot be enabled with tool_choice.
                    # If it is, the tool_choice will be used instead.
                    elif self.enable_reasoning:
                        # handle reasoning_content delta
                        assert reasoning_parser is not None
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)

                        delta_message = (reasoning_parser.
                                         extract_reasoning_content_streaming(
                                             previous_text,
                                             current_text,
                                             delta_text,
                                             previous_token_ids,
                                             current_token_ids,
                                             output.token_ids,
                                         ))

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids

                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    # set the previous values for the next iteration
                    previous_num_tokens[i] += len(output.token_ids)

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls or the parser otherwise
                    # wasn't ready to send a token, then
                    # get the next token without streaming a chunk
                    if delta_message is None:
                        continue

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)

                    # if the model is finished generating
                    else:
                        # check to make sure we haven't "forgotten" to stream
                        # any tokens that were generated but previously
                        # matched by partial json parsing
                        # only happens if we are NOT using guided decoding
                        auto_tools_called = False
                        if tool_parser:
                            auto_tools_called = len(
                                tool_parser.prev_tool_call_arr) > 0
                            index = len(tool_parser.prev_tool_call_arr
                                        ) - 1 if auto_tools_called else 0
                        else:
                            index = 0

                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
                            latest_delta_len = 0
                            if ((isinstance(
                                    delta_message.tool_calls[0].function,
                                    DeltaFunctionCall)) and isinstance(
                                        delta_message.tool_calls[0].function.
                                        arguments, str)):
                                latest_delta_len = len(
                                    delta_message.tool_calls[0].function.
                                    arguments)

                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
                                    "arguments", {}),
                                ensure_ascii=False)

                            # get what we've streamed so far for arguments
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]
                            if (latest_delta_len > 0):
                                actual_call = actual_call[:-latest_delta_len]

                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)
                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])

                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason
                            if not auto_tools_called else "tool_calls",
                            stop_reason=output.stop_reason)

                        finish_reason_sent[i] = True

                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)

                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=num_prompt_tokens + completion_tokens,
                        )

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
            if include_usage:
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                                        completion_tokens=completion_tokens,
                                        total_tokens=num_prompt_tokens +
                                        completion_tokens)
                if self.enable_prompt_tokens_details and num_cached_tokens:
                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
                        cached_tokens=num_cached_tokens)

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            num_completion_tokens = sum(previous_num_tokens)
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
                total_tokens=num_prompt_tokens + num_completion_tokens)

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in chat completion stream generator.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"
666
+
667
    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Build the non-streaming chat completion response.

        Drains ``result_generator`` down to the final ``RequestOutput``,
        then, per generated choice: attaches logprobs, optionally runs the
        configured reasoning and/or tool-call parsers over the generated
        text, applies ``request.echo``, and records aggregate token usage
        on ``request_metadata``.

        Returns an ``ErrorResponse`` on client disconnect, validation
        failure, or parser-construction failure; otherwise a complete
        ``ChatCompletionResponse``.
        """

        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            # Only the last item matters: each iteration supersedes the
            # previous partial RequestOutput.
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            should_stream_with_reasoning_parsing = (
                self._should_stream_with_reasoning_parsing(request))

            # In the OpenAI API the finish_reason is "tools_called"
            # if the tool choice is auto and the model produced a tool
            # call. The same is not true for named function calls
            auto_tools_called = False

            if should_stream_with_reasoning_parsing and \
                    self.reasoning_parser is not None:
                try:
                    reasoning_parser = self.reasoning_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in reasoning parser creation.")
                    return self.create_error_response(str(e))

                reasoning_content, content = (
                    reasoning_parser.extract_reasoning_content(
                        output.text, request=request))

                if reasoning_content:
                    message = ChatMessage(role=role,
                                          content=content,
                                          reasoning_content=reasoning_content)
                else:
                    message = ChatMessage(role=role, content=output.text)

            # if auto tools are not enabled, and a named tool choice using
            # outlines is not being used
            elif (not self.enable_auto_tools
                  or not self.tool_parser) and not isinstance(
                      request.tool_choice, ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(role=role, content=output.text)

            # if the request uses tools and specified a tool choice
            elif request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:

                # Named tool choice: the model's whole output is the
                # argument payload for the single requested function.
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":

                message = ChatMessage(role=role, content=output.text)

            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:

                try:
                    tool_parser = self.tool_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in tool parser creation.")
                    return self.create_error_response(str(e))

                tool_call_info = tool_parser.extract_tool_calls(
                    output.text, request=request)
                # In the OpenAI API the finish_reason is "tools_called"
                # if the tool choice is auto and the model produced a tool
                # call. The same is not true for named function calls
                auto_tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)

                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later.
                    message = ChatMessage(role=role, content=output.text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls" if auto_tools_called else
                output.finish_reason if output.finish_reason else "stop",
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            # Prepend the last same-role message's content to every choice,
            # mirroring OpenAI's echo semantics.
            last_msg_content: Union[str, List[Dict[str, str]]] = ""
            if conversation and "content" in conversation[-1] and conversation[
                    -1].get("role") == role:
                last_msg_content = conversation[-1]["content"] or ""
            if isinstance(last_msg_content, list):
                # NOTE(review): assumes list-form content parts each carry a
                # 'text' key — confirm against the chat content schema.
                last_msg_content = "\n".join(msg['text']
                                             for msg in last_msg_content)

            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        # Encoder-decoder models count encoder prompt tokens as well.
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                          completion_tokens=num_generated_tokens,
                          total_tokens=num_prompt_tokens +
                          num_generated_tokens)
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens)

        # Expose aggregate usage to FastAPI middleware.
        request_metadata.final_usage_info = usage

        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response
848
+
849
+ def _get_top_logprobs(
850
+ self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
851
+ tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
852
+ return [
853
+ ChatCompletionLogProb(token=(token := self._get_decoded_token(
854
+ p[1],
855
+ p[0],
856
+ tokenizer,
857
+ return_as_token_id=self.return_tokens_as_token_ids)),
858
+ logprob=max(p[1].logprob, -9999.0),
859
+ bytes=list(
860
+ token.encode("utf-8", errors="replace")))
861
+ for i, p in enumerate(logprobs.items())
862
+ if top_logprobs and i < top_logprobs
863
+ ]
864
+
865
    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        ``token_ids`` and ``top_logprobs`` are parallel sequences: for each
        sampled token id, the matching entry holds that step's candidate
        logprob map (or ``None`` when the engine produced no logprob info).
        """
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprob info for this step: decode the id directly and
                # emit an entry without a logprob value.
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
            else:
                step_token = step_top_logprobs[token_id]
                step_decoded = step_token.decoded_token

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_token,
                            token_id,
                            tokenizer,
                            self.return_tokens_as_token_ids,
                        ),
                        # Clamp -inf so the value stays JSON-serializable.
                        logprob=max(step_token.logprob, -9999.0),
                        bytes=None if step_decoded is None else list(
                            step_decoded.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs,
                            num_output_top_logprobs,
                            tokenizer,
                        ),
                    ))

        return ChatCompletionLogProbs(content=logprobs_content)
910
+
911
+ def _should_stream_with_auto_tool_parsing(self,
912
+ request: ChatCompletionRequest):
913
+ """
914
+ Utility function to check if streamed tokens should go through the tool
915
+ call parser that was configured.
916
+
917
+ We only want to do this IF user-provided tools are set, a tool parser
918
+ is configured, "auto" tool choice is enabled, and the request's tool
919
+ choice field indicates that "auto" tool choice should be used.
920
+ """
921
+ return (request.tools and self.tool_parser and self.enable_auto_tools
922
+ and request.tool_choice in ['auto', None])
923
+
924
+ def _should_stream_with_reasoning_parsing(self,
925
+ request: ChatCompletionRequest):
926
+ """
927
+ Utility function to check if streamed tokens should go through the
928
+ reasoning parser that was configured.
929
+
930
+ We only want to do this IF reasoning is enabled and a reasoning
931
+ parser is configured.
932
+ """
933
+ return self.enable_reasoning and self.reasoning_parser is not None
934
+
935
+ def _should_check_for_unstreamed_tool_arg_tokens(
936
+ self,
937
+ delta_message: Optional[DeltaMessage],
938
+ output: CompletionOutput,
939
+ ) -> bool:
940
+ """
941
+ Check to see if we should check for unstreamed tool arguments tokens.
942
+ This is only applicable when auto tool parsing is enabled, the delta
943
+ is a tool call with arguments.
944
+ """
945
+
946
+ # yapf: disable
947
+ return bool(
948
+ # if there is a delta message that includes tool calls which
949
+ # include a function that has arguments
950
+ output.finish_reason is not None
951
+ and self.enable_auto_tools and self.tool_parser and delta_message
952
+ and delta_message.tool_calls and delta_message.tool_calls[0]
953
+ and delta_message.tool_calls[0].function
954
+ and delta_message.tool_calls[0].function.arguments is not None
955
+ )
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_completion.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ import time
5
+ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
6
+ from typing import Sequence as GenericSequence
7
+ from typing import Tuple, Union, cast
8
+
9
+ from fastapi import Request
10
+
11
+ from vllm.config import ModelConfig
12
+ from vllm.engine.protocol import EngineClient
13
+ from vllm.entrypoints.logger import RequestLogger
14
+ # yapf conflicts with isort for this block
15
+ # yapf: disable
16
+ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
17
+ CompletionRequest,
18
+ CompletionResponse,
19
+ CompletionResponseChoice,
20
+ CompletionResponseStreamChoice,
21
+ CompletionStreamResponse,
22
+ ErrorResponse,
23
+ RequestResponseMetadata,
24
+ UsageInfo)
25
+ # yapf: enable
26
+ from vllm.entrypoints.openai.serving_engine import OpenAIServing
27
+ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
28
+ from vllm.logger import init_logger
29
+ from vllm.outputs import RequestOutput
30
+ from vllm.sampling_params import BeamSearchParams, SamplingParams
31
+ from vllm.sequence import Logprob
32
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
33
+ from vllm.utils import merge_async_iterators
34
+
35
+ logger = init_logger(__name__)
36
+
37
+
38
class OpenAIServingCompletion(OpenAIServing):
    """OpenAI-compatible handler for the legacy /v1/completions endpoint,
    built on top of :class:`OpenAIServing`."""

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)
        # Log once at startup if the model config overrides any default
        # sampling parameters for completions.
        diff_sampling_param = self.model_config.get_diff_sampling_param()
        if diff_sampling_param:
            logger.info(
                "Overwriting default completion sampling param with: %s",
                diff_sampling_param)

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        request_id = f"cmpl-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            request_prompts, engine_prompts = await self._preprocess_completion(
                request,
                tokenizer,
                request.prompt,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                # Cap generation at whatever context remains after the prompt.
                default_max_tokens = self.max_model_len - len(
                    engine_prompt["prompt_token_ids"])
                # Build default sampling params
                default_sampling_params = (
                    self.model_config.get_diff_sampling_param())
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens, default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
                        default_max_tokens,
                        self.model_config.logits_processor_pattern,
                        default_sampling_params)

                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                if isinstance(sampling_params, BeamSearchParams):
                    # NOTE(review): beam search receives the batch-level
                    # request_id while generate() receives the per-prompt
                    # request_id_item — confirm this asymmetry is intended.
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        prompt_adapter_request=prompt_adapter_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when use
        # beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata)

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    final_res.prompt = request_prompts[i]["prompt"]

            final_res_batch_checked = cast(List[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Yield SSE ``data:`` chunks for a streaming completion.

        Interleaves outputs from all prompts/choices; choice index ``i`` is
        flattened as ``output.index + prompt_idx * num_choices``. Always
        terminates with ``data: [DONE]``, even after an error chunk.
        """
        num_choices = 1 if request.n is None else request.n
        # Flat per-(prompt, choice) accumulators, indexed by i above.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts

        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage
            include_continuous_usage = include_usage and \
                                       stream_options.continuous_usage_stats
        else:
            include_usage, include_continuous_usage = False, False

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs
                prompt_text = res.prompt

                # Prompt details are excluded from later streamed outputs
                if res.prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: Optional[GenericSequence[Optional[Dict[
                    int, Logprob]]]]

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            assert prompt_logprobs is not None
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids, *output.token_ids
                            ]
                            out_logprobs = [
                                *prompt_logprobs,
                                *(output.logprobs or []),
                            ]
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                    if not delta_text and not delta_token_ids \
                        and not previous_num_tokens[i]:
                        # Chunked prefill case, don't return empty chunks
                        continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                        )
                    else:
                        logprobs = None

                    previous_text_lens[i] += len(output.text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens)

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Assemble the non-streaming ``CompletionResponse`` from finished
        engine outputs, applying echo/logprobs and recording usage on
        ``request_metadata``."""
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = final_res.prompt_logprobs
            if prompt_logprobs:
                # Clamp -inf prompt logprobs so the response stays
                # JSON-serializable.
                for logprob_dict in prompt_logprobs:
                    if logprob_dict:
                        for logprob_values in logprob_dict.values():
                            if logprob_values.logprob == float('-inf'):
                                logprob_values.logprob = -9999.0
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: Optional[GenericSequence[Optional[Dict[int,
                                                                 Logprob]]]]

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo:
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        # Echo-only: no generated tokens requested.
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        # Expose aggregate usage to FastAPI middleware.
        request_metadata.final_usage_info = usage

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: AnyTokenizer,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API."""
        out_text_offset: List[int] = []
        out_token_logprobs: List[Optional[float]] = []
        out_tokens: List[str] = []
        out_top_logprobs: List[Optional[Dict[str, float]]] = []

        last_token_len = 0

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprob info for this step: decode the id directly.
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=self.return_tokens_as_token_ids,
                )
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append({
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    self._get_decoded_token(top_lp[1],
                                            top_lp[0],
                                            tokenizer,
                                            return_as_token_id=self.return_tokens_as_token_ids):
                    max(top_lp[1].logprob, -9999.0)
                    for i, top_lp in enumerate(step_top_logprobs.items())
                    if num_output_top_logprobs >= i
                })

            # Character offsets are cumulative across all emitted tokens.
            if len(out_text_offset) == 0:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_embedding.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ import base64
5
+ import time
6
+ from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast
7
+
8
+ import numpy as np
9
+ from fastapi import Request
10
+ from typing_extensions import assert_never
11
+
12
+ from vllm.config import ModelConfig
13
+ from vllm.engine.protocol import EngineClient
14
+ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
15
+ from vllm.entrypoints.logger import RequestLogger
16
+ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
17
+ EmbeddingRequest,
18
+ EmbeddingResponse,
19
+ EmbeddingResponseData,
20
+ ErrorResponse, UsageInfo)
21
+ from vllm.entrypoints.openai.serving_engine import OpenAIServing
22
+ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
23
+ from vllm.logger import init_logger
24
+ from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
25
+ PoolingRequestOutput)
26
+ from vllm.utils import merge_async_iterators
27
+
28
+ logger = init_logger(__name__)
29
+
30
+
31
+ def _get_embedding(
32
+ output: EmbeddingOutput,
33
+ encoding_format: Literal["float", "base64"],
34
+ ) -> Union[List[float], str]:
35
+ if encoding_format == "float":
36
+ return output.embedding
37
+ elif encoding_format == "base64":
38
+ # Force to use float32 for base64 encoding
39
+ # to match the OpenAI python client behavior
40
+ embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
41
+ return base64.b64encode(embedding_bytes).decode("utf-8")
42
+
43
+ assert_never(encoding_format)
44
+
45
+
46
class OpenAIServingEmbedding(OpenAIServing):
    """Handler for the OpenAI-compatible ``/v1/embeddings`` endpoint.

    Supports both plain-text/token inputs and chat-style message inputs
    (the latter rendered through a chat template before pooling).
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
    ) -> None:
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

        # Server-level default template; a request may override it.
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = request.encoding_format
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"embd-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        truncate_prompt_tokens = None

        # A truncation size larger than the context window is rejected
        # rather than silently clamped.
        if request.truncate_prompt_tokens is not None:
            if request.truncate_prompt_tokens <= self.max_model_len:
                truncate_prompt_tokens = request.truncate_prompt_tokens
            else:
                return self.create_error_response(
                    "truncate_prompt_tokens value is "
                    "greater than max_model_len."
                    " Please, select a smaller truncation size.")

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for embedding models")

            if isinstance(request, EmbeddingChatRequest):
                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    # In embedding requests, we are not generating tokens,
                    # so there is no need to append extra tokens to the input
                    add_generation_prompt=False,
                    continue_final_message=False,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                (request_prompts,
                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.input,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
        try:
            pooling_params = request.to_pooling_params()

            # One engine request per input prompt; each gets an indexed id
            # derived from the batch-level request id.
            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts
        try:
            # Results may arrive out of order; the yielded index i puts
            # each one back in its original batch position.
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_embedding_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
                encoding_format,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_embedding_response(
        self,
        final_res_batch: List[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
    ) -> EmbeddingResponse:
        """Convert the ordered batch of pooling outputs into an
        OpenAI-style ``EmbeddingResponse``, accumulating prompt-token
        usage across all inputs."""
        items: List[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            embedding_res = EmbeddingRequestOutput.from_base(final_res)

            item = EmbeddingResponseData(
                index=idx,
                embedding=_get_embedding(embedding_res.outputs,
                                         encoding_format),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        # Embeddings produce no completion tokens, so total == prompt.
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_engine.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ from concurrent.futures.thread import ThreadPoolExecutor
5
+ from http import HTTPStatus
6
+ from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
7
+ Optional, Sequence, Tuple, TypedDict, Union)
8
+
9
+ from fastapi import Request
10
+ from pydantic import Field
11
+ from starlette.datastructures import Headers
12
+ from typing_extensions import Annotated
13
+
14
+ from vllm.config import ModelConfig
15
+ from vllm.engine.protocol import EngineClient
16
+ # yapf conflicts with isort for this block
17
+ # yapf: disable
18
+ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
19
+ ChatTemplateContentFormatOption,
20
+ ConversationMessage,
21
+ apply_hf_chat_template,
22
+ apply_mistral_chat_template,
23
+ parse_chat_messages_futures,
24
+ resolve_chat_template_content_format)
25
+ from vllm.entrypoints.logger import RequestLogger
26
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
27
+ CompletionRequest,
28
+ DetokenizeRequest,
29
+ EmbeddingChatRequest,
30
+ EmbeddingCompletionRequest,
31
+ ErrorResponse, RerankRequest,
32
+ ScoreRequest,
33
+ TokenizeChatRequest,
34
+ TokenizeCompletionRequest)
35
+ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
36
+ from vllm.entrypoints.openai.tool_parsers import ToolParser
37
+ # yapf: enable
38
+ from vllm.inputs import TokensPrompt
39
+ from vllm.inputs.parse import parse_and_batch_prompt
40
+ from vllm.logger import init_logger
41
+ from vllm.lora.request import LoRARequest
42
+ from vllm.pooling_params import PoolingParams
43
+ from vllm.prompt_adapter.request import PromptAdapterRequest
44
+ from vllm.sampling_params import BeamSearchParams, SamplingParams
45
+ from vllm.sequence import Logprob
46
+ from vllm.tracing import (contains_trace_headers, extract_trace_headers,
47
+ log_tracing_disabled_warning)
48
+ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
49
+ from vllm.utils import is_list_of, make_async, random_uuid
50
+
51
+ logger = init_logger(__name__)
52
+
53
+ CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
54
+ EmbeddingCompletionRequest, ScoreRequest,
55
+ TokenizeCompletionRequest]
56
+
57
+ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
58
+ TokenizeChatRequest]
59
+
60
+ AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest]
61
+
62
+
63
+ class TextTokensPrompt(TypedDict):
64
+ prompt: str
65
+ prompt_token_ids: List[int]
66
+
67
+
68
+ RequestPrompt = Union[List[int], str, TextTokensPrompt]
69
+
70
+
71
class OpenAIServing:
    """Base class for the OpenAI-compatible API handlers.

    Provides shared machinery: error-response construction, model/adapter
    resolution, prompt tokenization and length validation, chat-template
    preprocessing, request logging, and tracing-header extraction.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__()

        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # Tokenization runs on a single-worker thread pool so it does not
        # block the event loop.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input, executor=self._tokenizer_executor)
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=self._tokenizer_executor)

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an OpenAI-style error payload."""
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Like :meth:`create_error_response`, but serialized as the JSON
        string used inside an SSE stream."""
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """Return ``None`` if ``request.model`` names the base model, a
        loaded LoRA, or a prompt adapter; otherwise a 404 error response."""
        if self._is_model_supported(request.model):
            return None
        if request.model in [
                lora.lora_name for lora in self.models.lora_requests
        ]:
            return None
        if request.model in [
                prompt_adapter.prompt_adapter_name
                for prompt_adapter in self.models.prompt_adapter_requests
        ]:
            return None
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
        """Resolve ``request.model`` to (lora_request, prompt_adapter_request);
        both are ``None`` when the base model is requested directly."""
        if self._is_model_supported(request.model):
            return None, None
        for lora in self.models.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.models.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt (optionally truncating) and validate its
        length against the model context window."""
        # Some encoder models are configured as lower-cased.
        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids

        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        """Accept pre-tokenized input, truncating (keeping the tail) if
        requested, and decode it back to text for logging/validation."""
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            # Keep the last ``truncate_prompt_tokens`` ids.
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check prompt length against ``max_model_len``, with rules that
        depend on the request type; raises ``ValueError`` on overflow."""
        token_num = len(input_ids)

        # Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):

            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens
        if max_tokens is None:
            # Generation length unspecified: the prompt must leave at least
            # one token of headroom.
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> List[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        # Although our type checking is based on mypy,
        # VSCode Pyright extension should still work properly
        # "is True" is required for Pyright to perform type narrowing
        # See: https://github.com/microsoft/pyright/issues/7672
        return [
            self._normalize_prompt_text_to_input(
                request,
                tokenizer,
                prompt=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens)
            if prompt_input["is_tokens"] is False else
            self._normalize_prompt_tokens_to_input(
                request,
                tokenizer,
                prompt_ids=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens)
            for prompt_input in parse_and_batch_prompt(input_or_inputs)
        ]

    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]:
        """Tokenize completion-style input(s) and package them as engine
        prompts, returning both the logging-friendly and engine forms."""
        request_prompts = await self._tokenize_prompt_input_or_inputs_async(
            request,
            tokenizer,
            input_or_inputs,
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )

        engine_prompts = [
            TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
            for request_prompt in request_prompts
        ]

        return request_prompts, engine_prompts

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest,
        tokenizer: AnyTokenizer,
        messages: List[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[List[Dict[str, Any]]] = None,
        documents: Optional[List[Dict[str, str]]] = None,
        chat_template_kwargs: Optional[Dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
               List[TokensPrompt]]:
        """Render chat messages through the chat template, tokenize the
        result, and attach any (async-fetched) multi-modal data.

        Returns (conversation, request prompts for logging, engine prompts).
        """
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )
        # mm_data_future resolves later so media fetching overlaps with
        # template rendering/tokenization.
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            self.model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: Dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        # Caller-supplied kwargs take precedence over the defaults above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, List[int]]
        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
        if is_mistral_tokenizer:
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the LLM
        should_parse_tools = tool_parser is not None and (hasattr(
            request, "tool_choice") and request.tool_choice != "none")

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)

        if isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)

        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        return conversation, [request_prompt], [engine_prompt]

    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Forward request inputs to the configured logger; no-op when
        request logging is disabled."""
        if self.request_logger is None:
            return

        # Normalize the three RequestPrompt shapes into (text, token ids).
        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Extract trace-propagation headers when tracing is enabled;
        warn (once) if trace headers arrive while tracing is disabled."""
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    @staticmethod
    def _base_request_id(raw_request: Optional[Request],
                         default: Optional[str] = None) -> Optional[str]:
        """Pulls the request id to use from a header, if provided"""
        default = default or random_uuid()
        if raw_request is None:
            return default

        return raw_request.headers.get("X-Request-Id", default)

    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        """Return the display form of a token: its id string, its cached
        decoded text, or a fresh decode as a last resort."""
        if return_as_token_id:
            return f"token_id:{token_id}"

        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)

    def _is_model_supported(self, model_name):
        # True iff ``model_name`` matches one of the served base models.
        return self.models.is_base_model(model_name)
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_models.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ import pathlib
5
+ from dataclasses import dataclass
6
+ from http import HTTPStatus
7
+ from typing import List, Optional, Union
8
+
9
+ from vllm.config import ModelConfig
10
+ from vllm.engine.protocol import EngineClient
11
+ from vllm.entrypoints.openai.protocol import (ErrorResponse,
12
+ LoadLoraAdapterRequest,
13
+ ModelCard, ModelList,
14
+ ModelPermission,
15
+ UnloadLoraAdapterRequest)
16
+ from vllm.logger import init_logger
17
+ from vllm.lora.request import LoRARequest
18
+ from vllm.prompt_adapter.request import PromptAdapterRequest
19
+ from vllm.utils import AtomicCounter
20
+
21
+ logger = init_logger(__name__)
22
+
23
+
24
@dataclass
class BaseModelPath:
    """A served base model: its public name and the path it was loaded from."""
    name: str        # name exposed via /v1/models
    model_path: str  # filesystem/hub path of the model


@dataclass
class PromptAdapterPath:
    """A statically configured prompt adapter."""
    name: str        # name exposed via /v1/models
    local_path: str  # directory containing adapter_config.json


@dataclass
class LoRAModulePath:
    """A statically configured LoRA adapter."""
    name: str  # name exposed via /v1/models
    path: str  # path to the LoRA weights
    # Optional record of which base model the adapter targets; when unset,
    # the first served base model is reported as the parent.
    base_model_name: Optional[str] = None
41
+
42
+
43
+ class OpenAIServingModels:
44
+ """Shared instance to hold data about the loaded base model(s) and adapters.
45
+
46
+ Handles the routes:
47
+ - /v1/models
48
+ - /v1/load_lora_adapter
49
+ - /v1/unload_lora_adapter
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ engine_client: EngineClient,
55
+ model_config: ModelConfig,
56
+ base_model_paths: List[BaseModelPath],
57
+ *,
58
+ lora_modules: Optional[List[LoRAModulePath]] = None,
59
+ prompt_adapters: Optional[List[PromptAdapterPath]] = None,
60
+ ):
61
+ super().__init__()
62
+
63
+ self.base_model_paths = base_model_paths
64
+ self.max_model_len = model_config.max_model_len
65
+ self.engine_client = engine_client
66
+
67
+ self.static_lora_modules = lora_modules
68
+ self.lora_requests: List[LoRARequest] = []
69
+ self.lora_id_counter = AtomicCounter(0)
70
+
71
+ self.prompt_adapter_requests = []
72
+ if prompt_adapters is not None:
73
+ for i, prompt_adapter in enumerate(prompt_adapters, start=1):
74
+ with pathlib.Path(prompt_adapter.local_path,
75
+ "adapter_config.json").open() as f:
76
+ adapter_config = json.load(f)
77
+ num_virtual_tokens = adapter_config["num_virtual_tokens"]
78
+ self.prompt_adapter_requests.append(
79
+ PromptAdapterRequest(
80
+ prompt_adapter_name=prompt_adapter.name,
81
+ prompt_adapter_id=i,
82
+ prompt_adapter_local_path=prompt_adapter.local_path,
83
+ prompt_adapter_num_virtual_tokens=num_virtual_tokens))
84
+
85
+ async def init_static_loras(self):
86
+ """Loads all static LoRA modules.
87
+ Raises if any fail to load"""
88
+ if self.static_lora_modules is None:
89
+ return
90
+ for lora in self.static_lora_modules:
91
+ load_request = LoadLoraAdapterRequest(lora_path=lora.path,
92
+ lora_name=lora.name)
93
+ load_result = await self.load_lora_adapter(
94
+ request=load_request, base_model_name=lora.base_model_name)
95
+ if isinstance(load_result, ErrorResponse):
96
+ raise ValueError(load_result.message)
97
+
98
+ def is_base_model(self, model_name):
99
+ return any(model.name == model_name for model in self.base_model_paths)
100
+
101
+ def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
102
+ """Returns the appropriate model name depending on the availability
103
+ and support of the LoRA or base model.
104
+ Parameters:
105
+ - lora: LoRARequest that contain a base_model_name.
106
+ Returns:
107
+ - str: The name of the base model or the first available model path.
108
+ """
109
+ if lora_request is not None:
110
+ return lora_request.lora_name
111
+ return self.base_model_paths[0].name
112
+
113
+ async def show_available_models(self) -> ModelList:
114
+ """Show available models. This includes the base model and all
115
+ adapters"""
116
+ model_cards = [
117
+ ModelCard(id=base_model.name,
118
+ max_model_len=self.max_model_len,
119
+ root=base_model.model_path,
120
+ permission=[ModelPermission()])
121
+ for base_model in self.base_model_paths
122
+ ]
123
+ lora_cards = [
124
+ ModelCard(id=lora.lora_name,
125
+ root=lora.local_path,
126
+ parent=lora.base_model_name if lora.base_model_name else
127
+ self.base_model_paths[0].name,
128
+ permission=[ModelPermission()])
129
+ for lora in self.lora_requests
130
+ ]
131
+ prompt_adapter_cards = [
132
+ ModelCard(id=prompt_adapter.prompt_adapter_name,
133
+ root=self.base_model_paths[0].name,
134
+ permission=[ModelPermission()])
135
+ for prompt_adapter in self.prompt_adapter_requests
136
+ ]
137
+ model_cards.extend(lora_cards)
138
+ model_cards.extend(prompt_adapter_cards)
139
+ return ModelList(data=model_cards)
140
+
141
    async def load_lora_adapter(
            self,
            request: LoadLoraAdapterRequest,
            base_model_name: Optional[str] = None
    ) -> Union[ErrorResponse, str]:
        """Validate, register and pre-load a LoRA adapter.

        Returns a success message string on success, or an ErrorResponse
        when validation or engine-side loading fails.
        """
        error_check_ret = await self._check_load_lora_adapter_request(request)
        if error_check_ret is not None:
            return error_check_ret

        lora_name, lora_path = request.lora_name, request.lora_path
        # Atomic counter keeps ids unique across concurrent load requests.
        unique_id = self.lora_id_counter.inc(1)
        lora_request = LoRARequest(lora_name=lora_name,
                                   lora_int_id=unique_id,
                                   lora_path=lora_path)
        # Only record the parent when it is actually a served base model.
        if base_model_name is not None and self.is_base_model(base_model_name):
            lora_request.base_model_name = base_model_name

        # Validate that the adapter can be loaded into the engine
        # This will also pre-load it for incoming requests
        try:
            await self.engine_client.add_lora(lora_request)
        except BaseException as e:
            # NOTE(review): BaseException also catches KeyboardInterrupt /
            # CancelledError — confirm this breadth is intentional.
            error_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            if isinstance(e, ValueError) and "No adapter found" in str(e):
                error_type = "NotFoundError"
                status_code = HTTPStatus.NOT_FOUND

            return create_error_response(message=str(e),
                                         err_type=error_type,
                                         status_code=status_code)

        # Only register the adapter after the engine accepted it.
        self.lora_requests.append(lora_request)
        logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
                    lora_path)
        return f"Success: LoRA adapter '{lora_name}' added successfully."
177
+
178
+ async def unload_lora_adapter(
179
+ self,
180
+ request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
181
+ error_check_ret = await self._check_unload_lora_adapter_request(request
182
+ )
183
+ if error_check_ret is not None:
184
+ return error_check_ret
185
+
186
+ lora_name = request.lora_name
187
+ self.lora_requests = [
188
+ lora_request for lora_request in self.lora_requests
189
+ if lora_request.lora_name != lora_name
190
+ ]
191
+ logger.info("Removed LoRA adapter: name '%s'", lora_name)
192
+ return f"Success: LoRA adapter '{lora_name}' removed successfully."
193
+
194
+ async def _check_load_lora_adapter_request(
195
+ self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
196
+ # Check if both 'lora_name' and 'lora_path' are provided
197
+ if not request.lora_name or not request.lora_path:
198
+ return create_error_response(
199
+ message="Both 'lora_name' and 'lora_path' must be provided.",
200
+ err_type="InvalidUserInput",
201
+ status_code=HTTPStatus.BAD_REQUEST)
202
+
203
+ # Check if the lora adapter with the given name already exists
204
+ if any(lora_request.lora_name == request.lora_name
205
+ for lora_request in self.lora_requests):
206
+ return create_error_response(
207
+ message=
208
+ f"The lora adapter '{request.lora_name}' has already been "
209
+ "loaded.",
210
+ err_type="InvalidUserInput",
211
+ status_code=HTTPStatus.BAD_REQUEST)
212
+
213
+ return None
214
+
215
+ async def _check_unload_lora_adapter_request(
216
+ self,
217
+ request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
218
+ # Check if either 'lora_name' or 'lora_int_id' is provided
219
+ if not request.lora_name and not request.lora_int_id:
220
+ return create_error_response(
221
+ message=
222
+ "either 'lora_name' and 'lora_int_id' needs to be provided.",
223
+ err_type="InvalidUserInput",
224
+ status_code=HTTPStatus.BAD_REQUEST)
225
+
226
+ # Check if the lora adapter with the given name exists
227
+ if not any(lora_request.lora_name == request.lora_name
228
+ for lora_request in self.lora_requests):
229
+ return create_error_response(
230
+ message=
231
+ f"The lora adapter '{request.lora_name}' cannot be found.",
232
+ err_type="NotFoundError",
233
+ status_code=HTTPStatus.NOT_FOUND)
234
+
235
+ return None
236
+
237
+
238
def create_error_response(
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
    """Build an ErrorResponse carrying the message, error type and HTTP code."""
    return ErrorResponse(code=status_code.value,
                         type=err_type,
                         message=message)
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_pooling.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ import base64
5
+ import time
6
+ from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast
7
+
8
+ import numpy as np
9
+ from fastapi import Request
10
+ from typing_extensions import assert_never
11
+
12
+ from vllm.config import ModelConfig
13
+ from vllm.engine.protocol import EngineClient
14
+ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
15
+ from vllm.entrypoints.logger import RequestLogger
16
+ from vllm.entrypoints.openai.protocol import (ErrorResponse,
17
+ PoolingChatRequest,
18
+ PoolingRequest, PoolingResponse,
19
+ PoolingResponseData, UsageInfo)
20
+ from vllm.entrypoints.openai.serving_engine import OpenAIServing
21
+ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
22
+ from vllm.logger import init_logger
23
+ from vllm.outputs import PoolingOutput, PoolingRequestOutput
24
+ from vllm.utils import merge_async_iterators
25
+
26
+ logger = init_logger(__name__)
27
+
28
+
29
def _get_data(
    output: PoolingOutput,
    encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
    """Serialize a pooling output as a float list or a base64 string."""
    if encoding_format == "base64":
        # Force to use float32 for base64 encoding
        # to match the OpenAI python client behavior
        raw = np.asarray(output.data, dtype="float32").tobytes()
        return base64.b64encode(raw).decode("utf-8")
    if encoding_format == "float":
        return output.data.tolist()

    assert_never(encoding_format)
42
+
43
+
44
class OpenAIServingPooling(OpenAIServing):
    """OpenAI-compatible handler for the pooling endpoint.

    Mimics the OpenAI Embedding API shape but returns raw pooled outputs.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
    ) -> None:
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

        # Fallback chat template used when a request does not supply one.
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

    async def create_pooling(
        self,
        request: PoolingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[PoolingResponse, ErrorResponse]:
        """
        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Returns a PoolingResponse, or an ErrorResponse on validation or
        engine failure.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = request.encoding_format
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"pool-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        truncate_prompt_tokens = None

        # Truncation beyond the model context length is rejected up front.
        if request.truncate_prompt_tokens is not None:
            if request.truncate_prompt_tokens <= self.max_model_len:
                truncate_prompt_tokens = request.truncate_prompt_tokens
            else:
                return self.create_error_response(
                    "truncate_prompt_tokens value is "
                    "greater than max_model_len."
                    " Please, select a smaller truncation size.")

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for pooling models")

            # Chat-style requests are rendered through the chat template;
            # plain requests are tokenized directly.
            if isinstance(request, PoolingChatRequest):
                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    # In pooling requests, we are not generating tokens,
                    # so there is no need to append extra tokens to the input
                    add_generation_prompt=False,
                    continue_final_message=False,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                (request_prompts,
                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.input,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
        try:
            pooling_params = request.to_pooling_params()

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts
        try:
            # Results may arrive out of order; slot them by prompt index.
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_pooling_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
                encoding_format,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_pooling_response(
        self,
        final_res_batch: List[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
    ) -> PoolingResponse:
        """Convert per-prompt pooling outputs into a PoolingResponse."""
        items: List[PoolingResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            item = PoolingResponseData(
                index=idx,
                data=_get_data(final_res.outputs, encoding_format),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        # Pooling consumes only prompt tokens; nothing is generated.
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return PoolingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_rerank.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
5
+
6
+ from fastapi import Request
7
+
8
+ from vllm.config import ModelConfig
9
+ from vllm.engine.protocol import EngineClient
10
+ from vllm.entrypoints.logger import RequestLogger
11
+ from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
12
+ RerankRequest, RerankResponse,
13
+ RerankResult, RerankUsage)
14
+ from vllm.entrypoints.openai.serving_engine import OpenAIServing
15
+ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
16
+ from vllm.inputs.data import TokensPrompt
17
+ from vllm.logger import init_logger
18
+ from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
19
+ from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
20
+ from vllm.utils import make_async, merge_async_iterators
21
+
22
+ logger = init_logger(__name__)
23
+
24
+
25
class JinaAIServingRerank(OpenAIServing):
    """Serves a JinaAI-compatible rerank endpoint for cross-encoder models."""

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
    ) -> None:
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

    async def do_rerank(
        self,
        request: RerankRequest,
        raw_request: Optional[Request] = None
    ) -> Union[RerankResponse, ErrorResponse]:
        """
        Rerank API based on JinaAI's rerank API; implements the same
        API interface. Designed for compatibility with off-the-shelf
        tooling, since this is a common standard for reranking APIs

        See example client implementations at
        https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
        numerous clients use this standard.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        model_name = request.model
        request_id = f"rerank-{self._base_request_id(raw_request)}"
        truncate_prompt_tokens = request.truncate_prompt_tokens
        query = request.query
        documents = request.documents
        request_prompts = []
        engine_prompts = []
        # A non-positive top_n means "return all documents".
        top_n = request.top_n if request.top_n > 0 else len(documents)

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for scoring models")

            if isinstance(tokenizer, MistralTokenizer):
                raise ValueError(
                    "MistralTokenizer not supported for cross-encoding")

            if not self.model_config.is_cross_encoder:
                raise ValueError("Model is not cross encoder.")

            if truncate_prompt_tokens is not None and \
                    truncate_prompt_tokens > self.max_model_len:
                raise ValueError(
                    f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
                    f"is greater than max_model_len ({self.max_model_len})."
                    f" Please, select a smaller truncation size.")
            for doc in documents:
                # Cross-encoders score a (query, document) pair joined by the
                # tokenizer's separator token.
                request_prompt = f"{query}{tokenizer.sep_token}{doc}"
                tokenization_kwargs: Dict[str, Any] = {}
                if truncate_prompt_tokens is not None:
                    tokenization_kwargs["truncation"] = True
                    tokenization_kwargs["max_length"] = truncate_prompt_tokens

                # Run the blocking tokenizer call off the event loop.
                tokenize_async = make_async(tokenizer.__call__,
                                            executor=self._tokenizer_executor)
                prompt_inputs = await tokenize_async(text=query,
                                                     text_pair=doc,
                                                     **tokenization_kwargs)

                input_ids = prompt_inputs["input_ids"]
                text_token_prompt = \
                    self._validate_input(request, input_ids, request_prompt)
                engine_prompt = TokensPrompt(
                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
                    token_type_ids=prompt_inputs.get("token_type_ids"))

                request_prompts.append(request_prompt)
                engine_prompts.append(engine_prompt)

        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []

        try:
            pooling_params = request.to_pooling_params()

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts

        try:
            # Results may complete out of order; slot them by prompt index.
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_rerank_response(
                final_res_batch_checked, request_id, model_name, documents,
                top_n)
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_rerank_response(
            self, final_res_batch: List[PoolingRequestOutput], request_id: str,
            model_name: str, documents: List[str],
            top_n: int) -> RerankResponse:
        """
        Convert the output of do_rank to a RerankResponse
        """
        results: List[RerankResult] = []
        num_prompt_tokens = 0
        for idx, final_res in enumerate(final_res_batch):
            classify_res = ScoringRequestOutput.from_base(final_res)

            # `index` records the document's position in the original request,
            # so it stays meaningful after the sort below.
            result = RerankResult(
                index=idx,
                document=RerankDocument(text=documents[idx]),
                relevance_score=classify_res.outputs.score,
            )
            results.append(result)
            prompt_token_ids = final_res.prompt_token_ids
            num_prompt_tokens += len(prompt_token_ids)

        # sort by relevance, then return the top n if set
        results.sort(key=lambda x: x.relevance_score, reverse=True)
        if top_n < len(documents):
            results = results[:top_n]

        return RerankResponse(
            id=request_id,
            model=model_name,
            results=results,
            usage=RerankUsage(total_tokens=num_prompt_tokens))
+ usage=RerankUsage(total_tokens=num_prompt_tokens))
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_score.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ import time
5
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
6
+
7
+ from fastapi import Request
8
+
9
+ from vllm.config import ModelConfig
10
+ from vllm.engine.protocol import EngineClient
11
+ from vllm.entrypoints.logger import RequestLogger
12
+ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
13
+ ScoreResponse, ScoreResponseData,
14
+ UsageInfo)
15
+ from vllm.entrypoints.openai.serving_engine import OpenAIServing
16
+ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
17
+ from vllm.inputs.data import TokensPrompt
18
+ from vllm.logger import init_logger
19
+ from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
20
+ from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
21
+ from vllm.utils import make_async, merge_async_iterators
22
+
23
+ logger = init_logger(__name__)
24
+
25
+
26
def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str],
                                                            str]) -> List:
    """Pair up ``text_1`` and ``text_2`` as (t1, t2) tuples.

    Accepts single items or lists and supports 1:1, 1:N and N:N pairings
    (a single ``text_1`` is broadcast across all of ``text_2``).

    Raises:
        ValueError: when either input is empty or lengths are incompatible.
    """
    # Normalize both sides into lists, wrapping single items.
    left = [text_1] if isinstance(text_1, (str, dict)) else list(text_1)
    right = [text_2] if isinstance(text_2, (str, dict)) else list(text_2)

    if len(left) > 1 and len(left) != len(right):
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if not left:
        raise ValueError("At least one text element must be given")
    if not right:
        raise ValueError("At least one text_pair element must be given")

    # Broadcast a single left-hand item across every right-hand item.
    if len(left) == 1:
        left = left * len(right)

    return list(zip(left, right))
48
+
49
+
50
class OpenAIServingScores(OpenAIServing):
    """Serves a score endpoint for cross-encoder (sentence-pair) models."""

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
    ) -> None:
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

    async def create_score(
        self,
        request: ScoreRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[ScoreResponse, ErrorResponse]:
        """
        Score API similar to Sentence Transformers cross encoder

        See https://sbert.net/docs/package_reference/cross_encoder

        Returns a ScoreResponse, or an ErrorResponse on validation or
        engine failure.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        model_name = request.model
        request_id = f"score-{self._base_request_id(raw_request)}"
        created_time = int(time.time())
        truncate_prompt_tokens = request.truncate_prompt_tokens

        request_prompts = []
        engine_prompts = []

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for scoring models")

            if isinstance(tokenizer, MistralTokenizer):
                raise ValueError(
                    "MistralTokenizer not supported for cross-encoding")

            if not self.model_config.is_cross_encoder:
                raise ValueError("Model is not cross encoder.")

            if truncate_prompt_tokens is not None and \
                    truncate_prompt_tokens > self.max_model_len:
                raise ValueError(
                    f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
                    f"is greater than max_model_len ({self.max_model_len})."
                    f" Please, select a smaller truncation size.")

            # Expand text_1/text_2 into 1:1 (query, text) pairs.
            input_pairs = make_pairs(request.text_1, request.text_2)
            for q, t in input_pairs:
                # Cross-encoders score a pair joined by the separator token.
                request_prompt = f"{q}{tokenizer.sep_token}{t}"

                tokenization_kwargs: Dict[str, Any] = {}
                if truncate_prompt_tokens is not None:
                    tokenization_kwargs["truncation"] = True
                    tokenization_kwargs["max_length"] = truncate_prompt_tokens

                # Run the blocking tokenizer call off the event loop.
                tokenize_async = make_async(tokenizer.__call__,
                                            executor=self._tokenizer_executor)
                prompt_inputs = await tokenize_async(text=q,
                                                     text_pair=t,
                                                     **tokenization_kwargs)

                input_ids = prompt_inputs["input_ids"]
                text_token_prompt = \
                    self._validate_input(request, input_ids, request_prompt)
                engine_prompt = TokensPrompt(
                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
                    token_type_ids=prompt_inputs.get("token_type_ids"))

                request_prompts.append(request_prompt)
                engine_prompts.append(engine_prompt)

        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []

        try:
            pooling_params = request.to_pooling_params()

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts

        try:
            # Results may complete out of order; slot them by prompt index.
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_score_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_score_response(
        self,
        final_res_batch: List[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
    ) -> ScoreResponse:
        """Convert per-pair scoring outputs into a ScoreResponse."""
        items: List[ScoreResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            classify_res = ScoringRequestOutput.from_base(final_res)

            item = ScoreResponseData(
                index=idx,
                score=classify_res.outputs.score,
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        # Scoring consumes only prompt tokens; nothing is generated.
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return ScoreResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/serving_tokenization.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Final, List, Optional, Union
4
+
5
+ from fastapi import Request
6
+
7
+ from vllm.config import ModelConfig
8
+ from vllm.engine.protocol import EngineClient
9
+ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
10
+ from vllm.entrypoints.logger import RequestLogger
11
+ # yapf conflicts with isort for this block
12
+ # yapf: disable
13
+ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
14
+ DetokenizeResponse,
15
+ ErrorResponse,
16
+ TokenizeChatRequest,
17
+ TokenizeRequest,
18
+ TokenizeResponse)
19
+ # yapf: enable
20
+ from vllm.entrypoints.openai.serving_engine import OpenAIServing
21
+ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
22
+ from vllm.logger import init_logger
23
+
24
+ logger = init_logger(__name__)
25
+
26
+
27
class OpenAIServingTokenization(OpenAIServing):
    """Serves the tokenize and detokenize utility endpoints."""

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
    ) -> None:
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

        # Fallback chat template used when a request does not supply one.
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

    async def create_tokenize(
        self,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """Tokenize a chat conversation or a plain prompt.

        Returns the token ids, their count, and the model context length.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{self._base_request_id(raw_request)}"

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            # Chat-style requests are rendered through the chat template;
            # plain requests are tokenized directly.
            if isinstance(request, TokenizeChatRequest):
                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    chat_template_kwargs=request.chat_template_kwargs,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                (request_prompts,
                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.prompt,
                     add_special_tokens=request.add_special_tokens,
                 )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        input_ids: List[int] = []
        for i, engine_prompt in enumerate(engine_prompts):
            self._log_inputs(request_id,
                             request_prompts[i],
                             params=None,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            # Silently ignore prompt adapter since it does not affect
            # tokenization (Unlike in Embeddings API where an error is raised)

            input_ids.extend(engine_prompt["prompt_token_ids"])

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Convert a list of token ids back into prompt text."""
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{self._base_request_id(raw_request)}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        tokenizer = await self.engine_client.get_tokenizer(lora_request)

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # Silently ignore prompt adapter since it does not affect tokenization
        # (Unlike in Embeddings API where an error is raised)

        prompt_input = await self._tokenize_prompt_input_async(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)
+ return DetokenizeResponse(prompt=input_text)
.venv/lib/python3.11/site-packages/vllm/entrypoints/utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ import functools
5
+
6
+ from fastapi import Request
7
+
8
+
9
async def listen_for_disconnect(request: Request) -> None:
    """Block until the client disconnects, then return.

    Polls the ASGI receive channel and exits as soon as an
    ``http.disconnect`` event arrives; all other messages are discarded.
    """
    while True:
        event = await request.receive()
        if event["type"] == "http.disconnect":
            return
15
+
16
+
17
def with_cancellation(handler_func):
    """Decorator that allows a route handler to be cancelled by client
    disconnections.

    This does _not_ use request.is_disconnected, which does not work with
    middleware. Instead this follows the pattern from
    starlette.StreamingResponse, which simultaneously awaits on two tasks- one
    to wait for an http disconnect message, and the other to do the work that we
    want done. When the first task finishes, the other is cancelled.

    A core assumption of this method is that the body of the request has already
    been read. This is a safe assumption to make for fastapi handlers that have
    already parsed the body of the request into a pydantic model for us.
    This decorator is unsafe to use elsewhere, as it will consume and throw away
    all incoming messages for the request while it looks for a disconnect
    message.

    In the case where a `StreamingResponse` is returned by the handler, this
    wrapper will stop listening for disconnects and instead the response object
    will start listening for disconnects.
    """

    # Functools.wraps is required for this wrapper to appear to fastapi as a
    # normal route handler, with the correct request type hinting.
    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        # The request is either the second positional arg or `raw_request`
        raw_request = args[1] if len(args) > 1 else kwargs["raw_request"]

        work_task = asyncio.create_task(handler_func(*args, **kwargs))
        disconnect_task = asyncio.create_task(
            listen_for_disconnect(raw_request))

        finished, unfinished = await asyncio.wait(
            [work_task, disconnect_task],
            return_when=asyncio.FIRST_COMPLETED)
        for task in unfinished:
            task.cancel()

        # If the handler lost the race (client disconnected first), its
        # result is discarded and None is returned.
        return work_task.result() if work_task in finished else None

    return wrapper
.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/audio.cpython-311.pyc ADDED
Binary file (4.73 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/base.cpython-311.pyc ADDED
Binary file (19.8 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/hasher.cpython-311.pyc ADDED
Binary file (5.65 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/parse.cpython-311.pyc ADDED
Binary file (22.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/processing.cpython-311.pyc ADDED
Binary file (60.5 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/registry.cpython-311.pyc ADDED
Binary file (22.8 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/video.cpython-311.pyc ADDED
Binary file (10.2 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/multimodal/audio.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import base64
4
+ from io import BytesIO
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+
10
+ from vllm.inputs.registry import InputContext
11
+ from vllm.utils import PlaceholderModule
12
+
13
+ from .base import MediaIO, MultiModalPlugin
14
+ from .inputs import AudioItem, ModalityData, MultiModalKwargs
15
+
16
+ try:
17
+ import librosa
18
+ except ImportError:
19
+ librosa = PlaceholderModule("librosa") # type: ignore[assignment]
20
+
21
+ try:
22
+ import soundfile
23
+ except ImportError:
24
+ soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
25
+
26
+
27
class AudioPlugin(MultiModalPlugin):
    """Plugin for audio data.

    Audio has no generic default input mapper or token-count calculation,
    so both hooks raise until a model registers its own implementations.
    """

    def get_data_key(self) -> str:
        """Return the multimodal data key served by this plugin."""
        return "audio"

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[AudioItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        # Each audio model must register its own mapper.
        raise NotImplementedError("There is no default audio input mapper")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        # Each audio model must register its own token-count calculation.
        raise NotImplementedError(
            "There is no default maximum multimodal tokens")
44
+
45
+
46
def resample_audio(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
) -> npt.NDArray[np.floating]:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` using librosa."""
    resampled = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
    return resampled
53
+
54
+
55
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Load and serialize audio as ``(samples, sample_rate)`` pairs."""

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        # sr=None keeps the file's native sampling rate.
        return librosa.load(BytesIO(data), sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        decoded = base64.b64decode(data)
        return self.load_bytes(decoded)

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        return librosa.load(filepath, sr=None)

    def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
        """Encode ``(samples, sample_rate)`` as a base64 WAV string."""
        audio, sr = media

        with BytesIO() as buffer:
            soundfile.write(buffer, audio, sr, format="WAV")
            payload = buffer.getvalue()

        return base64.b64encode(payload).decode('utf-8')
.venv/lib/python3.11/site-packages/vllm/multimodal/base.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections import defaultdict
5
+ from pathlib import Path
6
+ from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
7
+ Optional, Sequence, Tuple, Type, TypeVar, Union)
8
+
9
+ from torch import nn
10
+
11
+ from vllm.inputs import InputContext
12
+ from vllm.logger import init_logger
13
+ from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
14
+ resolve_mm_processor_kwargs)
15
+
16
+ if TYPE_CHECKING:
17
+ from vllm.config import ModelConfig
18
+ from vllm.sequence import SequenceGroupMetadata
19
+
20
+ from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs,
21
+ PlaceholderRange)
22
+
23
+ logger = init_logger(__name__)
24
+
25
+ MultiModalInputMapper = Callable[[InputContext, ModalityData[object]],
26
+ MultiModalKwargs]
27
+ """
28
+ Return a dictionary to be passed as keyword arguments to
29
+ :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
30
+ and processors in HuggingFace Transformers.
31
+
32
+ If the data is not supported, throw :exc:`TypeError`.
33
+ """
34
+
35
+ MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
36
+ """
37
+ Calculate the maximum number of multimodal tokens input to the language
38
+ model. This does not include tokens that correspond to the input text.
39
+ """
40
+
41
+ _T = TypeVar("_T")
42
+ N = TypeVar("N", bound=Type[nn.Module])
43
+
44
+
45
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).
    """

    def __init__(self) -> None:
        # Per-model-class registries: one for input mappers, one for the
        # max-multimodal-token calculations.
        self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]()
        self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]()

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[Any],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns a class decorator; apply it to the model class.
        """

        def wrapper(model_cls: N) -> N:
            # Re-registration is allowed but warned about; the newest
            # mapper wins.
            if self._input_mappers.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            self._input_mappers[model_cls] = (mapper
                                              or self._default_input_mapper)

            return model_cls

        return wrapper

    def map_input(
        self,
        model_config: "ModelConfig",
        data: ModalityData[Any],
        mm_processor_kwargs: Optional[dict[str, Any]],
    ) -> MultiModalKwargs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            TypeError: If the data type is not supported.
            KeyError: If no input mapper is registered for the model class.
        """

        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)

        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        # In the case of the default mapper, we have to get resource
        # processor through its HuggingFace autoclass; since this goes
        # through **kwargs, we can't inspect it the same way, so we allow
        # drop mm_processor_kwargs based on signature inspection
        # if we're using the default mapper.
        #
        # This should be safe in general due to the sanitation, since the
        # transformers resource should filter unused kwargs anyway.
        uses_default_mapper = mapper == self._default_input_mapper
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            mm_processor_kwargs,
            callable=mapper,
            allow_var_kwargs=uses_default_mapper,
        )
        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
        # Reject zero/negative values early with a descriptive error.
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns a class decorator; apply it to the model class.
        """

        def wrapper(model_cls: N) -> N:
            # Re-registration is allowed but warned about; the newest
            # value/callable wins.
            if self._max_mm_tokens.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            # Fixed integers are validated at registration time; callables
            # are validated later, when invoked in get_max_multimodal_tokens.
            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            self._max_mm_tokens[model_cls] = (
                max_mm_tokens or self._default_max_multimodal_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.model_executor.models import supports_multimodal

        model_cls, _ = get_model_architecture(model_config)

        # Non-multimodal models (or models with nothing registered here)
        # contribute no multimodal tokens.
        if not supports_multimodal(model_cls):
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            return 0

        # A registered callable is invoked with only the processor kwargs
        # its signature accepts.
        if callable(max_mm_tokens):
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens, overrides=model_config.mm_processor_kwargs)
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens
235
+
236
+
237
class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.
    """

    class IndexMap(NamedTuple):
        # Flattened, equal-length source/destination index lists produced
        # by index_map().
        src: list[int]
        dest: list[int]

    src_ranges: list[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: list[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self):
        # Start with an empty mapping; ranges and lengths grow as items
        # are appended via append_items_from_seq_group / extend.
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> Tuple[Optional[MultiModalDataDict], dict[str,
                                                  "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        Examples:

        .. code-block::

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |.................................|

                images      = [A, B]
                src_ranges  = [(0, 4), (4, 8)]
                dest_ranges = [(0, 4), (5, 9)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |  .....                          |

                images      = [A, B]
                src_ranges  = [(2, 4), (4, 6)]
                dest_ranges = [(0, 2), (3, 5)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |     .........                   |

                images      = [B]
                src_ranges  = [(0, 4)]
                dest_ranges = [(0, 4)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |          .......................|

                images      = []
                src_ranges  = []
                dest_ranges = []
        """
        seq_mm_data = seq_group.multi_modal_data
        seq_mm_placeholders = seq_group.multi_modal_placeholders

        # Nothing multimodal in this group: pass the data through untouched.
        if not seq_mm_data or not seq_mm_placeholders:
            return seq_mm_data, {}

        # For merged processor, we directly use mm_kwargs as mm_data
        # NOTE(review): MultiModalKwargs is imported under TYPE_CHECKING at
        # the top of this file — confirm it is in scope at runtime here.
        if isinstance(seq_mm_data, MultiModalKwargs):
            placeholder_maps = dict[str, MultiModalPlaceholderMap]()

            for modality, placeholders in seq_mm_placeholders.items():
                placeholder_map = MultiModalPlaceholderMap()

                if positions:
                    placeholder_map.append_items_from_seq_group(
                        positions,
                        # Dummy, since we don't care about intersecting items
                        [None] * len(placeholders),
                        placeholders,
                    )

                placeholder_maps[modality] = placeholder_map

            return seq_mm_data, placeholder_maps

        # Legacy path: filter each modality's items down to those that
        # intersect `positions`, building a placeholder map per modality.
        mm_data = {**seq_mm_data}
        placeholder_maps = defaultdict[str, MultiModalPlaceholderMap](
            MultiModalPlaceholderMap)

        for modality, placeholders in seq_mm_placeholders.items():
            mm_items = mm_data.pop(modality)
            if not isinstance(mm_items, list):
                mm_items = [mm_items]

            if positions:
                intersecting_items = placeholder_maps[modality] \
                    .append_items_from_seq_group(
                        positions,
                        mm_items,
                        placeholders,
                    )

                # Modalities with no intersecting items are dropped from
                # mm_data entirely (popped above, not re-added).
                if intersecting_items:
                    mm_data[modality] = intersecting_items

        return mm_data, placeholder_maps

    def append_items_from_seq_group(
        self,
        positions: range,
        multi_modal_items: list[_T],
        multi_modal_placeholders: Sequence[PlaceholderRange],
    ) -> list[_T]:
        """
        Adds the multi-modal items that intersect ```positions`` to this
        placeholder map and returns the intersecting items.

        ``multi_modal_items`` and ``multi_modal_placeholders`` are matched
        up pairwise and must have equal length.

        Raises:
            ValueError: If the two input lists differ in length.
        """
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            # The placeholder's token span within the full prompt.
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"],
            )
            # Overlap between the placeholder and the requested window.
            intersection = range(
                max(positions.start, placeholder.start),
                min(positions.stop, placeholder.stop),
            )

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Destination indices are relative to the start of `positions`.
            token_embedding_range = range(
                intersection.start - positions.start,
                intersection.stop - positions.start,
            )

            # Source indices are relative to the placeholder start, offset
            # by the embeddings accumulated so far.
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len,
            )

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # Advance by the FULL placeholder length (not just the
            # intersection): the source tensor contains the whole item.
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.
        """

        # Shift the other map's ranges past this map's current lengths so
        # they index into the concatenated tensors correctly.
        self.src_ranges.extend(
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges)
        self.src_len += other.src_len
        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len

    def index_map(self) -> "IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.

        Raises:
            ValueError: If the flattened source and destination index lists
                differ in length.
        """

        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)
445
+
446
+
447
class MediaIO(ABC, Generic[_T]):
    """Abstract loader interface for a single media type ``_T``."""

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Parse an in-memory byte string into a media item."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Decode a base64 payload into a media item.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Read a media item from a file on disk."""
        raise NotImplementedError
.venv/lib/python3.11/site-packages/vllm/plugins/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.56 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/spec_decode/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (189 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/spec_decode/__pycache__/batch_expansion.cpython-311.pyc ADDED
Binary file (21.7 kB). View file