koichi12 committed on
Commit
0ae5c3e
·
verified ·
1 Parent(s): f0ca319

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/vllm/compilation/backends.py +874 -0
  2. .venv/lib/python3.11/site-packages/vllm/compilation/counter.py +33 -0
  3. .venv/lib/python3.11/site-packages/vllm/compilation/decorators.py +249 -0
  4. .venv/lib/python3.11/site-packages/vllm/compilation/fix_functionalization.py +182 -0
  5. .venv/lib/python3.11/site-packages/vllm/compilation/fusion.py +617 -0
  6. .venv/lib/python3.11/site-packages/vllm/compilation/fx_utils.py +44 -0
  7. .venv/lib/python3.11/site-packages/vllm/compilation/monitor.py +38 -0
  8. .venv/lib/python3.11/site-packages/vllm/compilation/pass_manager.py +79 -0
  9. .venv/lib/python3.11/site-packages/vllm/compilation/wrapper.py +129 -0
  10. .venv/lib/python3.11/site-packages/vllm/engine/__init__.py +0 -0
  11. .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/__init__.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/arg_utils.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/async_llm_engine.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/async_timeout.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/llm_engine.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/metrics.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/metrics_types.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/protocol.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/vllm/engine/arg_utils.py +1360 -0
  20. .venv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py +1198 -0
  21. .venv/lib/python3.11/site-packages/vllm/engine/async_timeout.py +191 -0
  22. .venv/lib/python3.11/site-packages/vllm/engine/llm_engine.py +2025 -0
  23. .venv/lib/python3.11/site-packages/vllm/engine/metrics.py +681 -0
  24. .venv/lib/python3.11/site-packages/vllm/engine/metrics_types.py +102 -0
  25. .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__init__.py +159 -0
  26. .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/__init__.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/client.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/engine.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/client.py +707 -0
  30. .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/engine.py +391 -0
  31. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__init__.py +0 -0
  32. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/__init__.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/interfaces.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/multi_step.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/single_step.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/stop_checker.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/util.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/interfaces.py +74 -0
  39. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/multi_step.py +205 -0
  40. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/single_step.py +136 -0
  41. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/stop_checker.py +130 -0
  42. .venv/lib/python3.11/site-packages/vllm/engine/output_processor/util.py +27 -0
  43. .venv/lib/python3.11/site-packages/vllm/engine/protocol.py +284 -0
  44. .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__init__.py +48 -0
  45. .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/__init__.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/arctic.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/chatglm.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/cohere2.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/dbrx.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/deepseek_vl2.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/vllm/compilation/backends.py ADDED
@@ -0,0 +1,874 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import ast
4
+ import copy
5
+ import dataclasses
6
+ import os
7
+ import pprint
8
+ import time
9
+ from collections import defaultdict
10
+ from contextlib import ExitStack
11
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
12
+ from unittest.mock import patch
13
+
14
+ import torch
15
+ import torch.fx as fx
16
+
17
+ import vllm.envs as envs
18
+ from vllm.config import CompilationConfig, VllmConfig
19
+ from vllm.logger import init_logger
20
+ from vllm.utils import weak_ref_tensors
21
+
22
+ from .counter import compilation_counter
23
+ from .inductor_pass import InductorPass
24
+ from .monitor import end_monitoring_torch_compile
25
+ from .pass_manager import PostGradPassManager
26
+
27
+ logger = init_logger(__name__)
28
+
29
+
30
@dataclasses.dataclass
class InductorArtifact:
    """Record of one Inductor-compiled graph.

    Attributes:
        hash_str: Inductor's cache key for the compiled graph.
        file_path: path of the generated source file on disk
            (may be empty for entries written by older versions).
    """
    hash_str: str = ""
    file_path: str = ""
34
+
35
+
36
class InductorHashCache:
    """
    Maps (runtime_shape, graph_index) -> InductorArtifact.

    Disk format: a Python list of tuples, each tuple being
    (runtime_shape, graph_index, hash_str, file_path); a list of tuples
    is used purely for readability.

    In-memory format: Dict[Optional[int], Dict[int, InductorArtifact]],
    keyed first by runtime_shape and then by graph_index. json is not
    used because json does not support int keys.

    TODO: better off-the-shelf solution to serialize the data?
    """

    def __init__(self, cache_dir: str, disabled: bool = False):
        self.cache: Dict[Optional[int],
                         Dict[int, InductorArtifact]] = defaultdict(dict)
        self.disabled = disabled
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir,
                                            "inductor_hash_cache.py")
        if disabled:
            return
        # Point Inductor and Triton at subdirectories of cache_dir so the
        # whole directory can be copied to another machine and reused.
        inductor_cache = os.path.join(cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache
        if os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                self.deserialize(f.read())

    def deserialize(self, data: str):
        # ast.literal_eval is a safe way to parse Python literals;
        # never use eval() here, it is unsafe.
        for entry in ast.literal_eval(data):
            runtime_shape, graph_index, hash_str = entry[:3]
            # Older versions stored 3-tuples without a file_path.
            # NOTE: after running the new code, the file_path is updated.
            file_path = entry[3] if len(entry) > 3 else ""
            self.cache[runtime_shape][graph_index] = InductorArtifact(
                hash_str=hash_str, file_path=file_path)

    def serialize(self) -> str:
        rows = [(shape, idx, artifact.hash_str, artifact.file_path)
                for shape, per_graph in self.cache.items()
                for idx, artifact in per_graph.items()]
        return pprint.PrettyPrinter(indent=4).pformat(rows)

    def save_to_file(self):
        if self.disabled:
            return
        with open(self.cache_file_path, "w") as f:
            f.write(self.serialize())

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        if self.disabled:
            return False
        runtime_shape, graph_index = key
        return (runtime_shape in self.cache
                and graph_index in self.cache[runtime_shape])

    def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
        if self.disabled:
            raise KeyError("cannot read from disabled cache")
        runtime_shape, graph_index = key
        return self.cache[runtime_shape][graph_index]

    def __setitem__(self, key: Tuple[Optional[int], int],
                    value: InductorArtifact):
        # Writes are allowed even when disabled: nothing reaches disk.
        runtime_shape, graph_index = key
        self.cache[runtime_shape][graph_index] = value
125
+
126
+
127
class AlwaysHitShapeEnv:
    """A dummy shape environment whose cache lookups always succeed.

    Normal `torch.compile` runs Inductor inside a Dynamo compilation
    context, which supplies the dynamic-shape information Inductor's code
    cache needs. vLLM runs Dynamo once but invokes Inductor many times
    (once per concrete shape plus a general shape), outside that context,
    so no real shape environment exists and the cache lookup would fail.

    This stand-in makes every guard evaluation pass so the Inductor code
    cache always hits. The method set below was found by trial and error
    until lookups worked.
    """

    def __init__(self) -> None:
        self.guards: List[Any] = []

    def evaluate_guards_expression(self, *args, **kwargs):
        # Every guard is considered satisfied.
        return True

    def get_pruned_guards(self, *args, **kwargs):
        # No guards survive pruning.
        return []

    def produce_guards_expression(self, *args, **kwargs):
        # No guard expression is ever produced.
        return ""
163
+
164
+
165
def wrap_inductor(graph: fx.GraphModule,
                  example_inputs,
                  additional_inductor_config,
                  compilation_config: CompilationConfig,
                  vllm_backend: "VllmBackend",
                  graph_index: int = 0,
                  num_graphs: int = 1,
                  runtime_shape: Optional[int] = None,
                  use_inductor: bool = True) -> Any:
    """Compile ``graph`` with Inductor, going through vLLM's hash cache.

    Returns a callable with the Dynamo calling convention
    (``f(*args) -> Any``). If ``(runtime_shape, graph_index)`` is already
    in ``vllm_backend.inductor_hash_cache``, the compiled graph is looked
    up directly from Inductor's code cache instead of recompiling.

    Args:
        graph: the FX graph to compile.
        example_inputs: example inputs handed to Inductor.
        additional_inductor_config: extra Inductor config entries to apply.
        compilation_config: vLLM compilation config; compile time is
            accumulated into ``compilation_config.compilation_time``.
        vllm_backend: owner of the inductor hash cache.
        graph_index: index of this piecewise graph; index 0 starts the
            timer and emits info-level logs.
        num_graphs: total number of piecewise graphs; the last index
            stops the timer.
        runtime_shape: a concrete batch size, or None for general shape.
        use_inductor: if False, return ``graph`` unchanged.
    """
    if graph_index == 0:
        # before compiling the first graph, record the start time
        global compilation_start_time
        compilation_start_time = time.time()

    if not use_inductor:
        return graph

    compilation_counter.num_inductor_compilations += 1

    from torch._inductor import config
    current_config = config.get_config_copy()
    from torch._inductor.compile_fx import compile_fx

    if additional_inductor_config is not None:
        current_config.update(additional_inductor_config)

    if isinstance(runtime_shape, int):
        # for a specific batchsize, tuning triton kernel parameters
        # can be beneficial
        current_config["max_autotune"] = True
        current_config["coordinate_descent_tuning"] = True

    # inductor can inplace modify the graph, so we need to copy it
    # see https://github.com/pytorch/pytorch/issues/138980
    graph = copy.deepcopy(graph)

    cache_data = vllm_backend.inductor_hash_cache
    if (runtime_shape, graph_index) in cache_data:
        # we compiled this graph before
        # so we can directly lookup the compiled graph via hash
        inductor_artifact = cache_data[(runtime_shape, graph_index)]
        hash_str = inductor_artifact.hash_str
        if graph_index == 0:
            # adds some info logging for the first graph
            logger.info(
                "Directly lookup the graph for shape %s from the cache",
                str(runtime_shape))  # noqa
        logger.debug(
            "directly lookup the %s-th graph for shape %s via hash %s",
            graph_index, str(runtime_shape), hash_str)
        from torch._inductor.codecache import FxGraphCache
        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
            inductor_compiled_graph = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, False)
            # BUGFIX: the original message was "...Please remove" + "the
            # cache file..." with no separating space ("removethe").
            assert inductor_compiled_graph is not None, (
                "Inductor cache lookup failed. Please remove "
                f"the cache file {cache_data.cache_file_path} and try again."  # noqa
            )
            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]
    else:
        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        inductor_artifact = InductorArtifact()
        from torch._inductor.codecache import (FxGraphCache,
                                               compiled_fx_graph_hash)
        original_load = FxGraphCache.load

        def hijack_load(*args, **kwargs):
            # capture the compiled graph's source file path
            inductor_compiled_graph = original_load(*args, **kwargs)
            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
            return inductor_compiled_graph

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            # capture the hash of the compiled graph
            out = compiled_fx_graph_hash(*args, **kwargs)
            inductor_artifact.hash_str = out[0]
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            if not cache_data.disabled:
                # compilation cache is enabled, patch several functions

                # hijack to get the compiled graph itself
                stack.enter_context(
                    patch("torch._inductor.codecache.FxGraphCache.load",
                          hijack_load))

                # for hijacking the hash of the compiled graph
                stack.enter_context(
                    patch("torch._inductor.codecache.compiled_fx_graph_hash",
                          hijack_compiled_fx_graph_hash))

                # for providing a dummy shape environment
                stack.enter_context(
                    patch(
                        "torch._inductor.codecache.FxGraphCache._get_shape_env",
                        _get_shape_env))

                # for forcing the graph to be cached
                stack.enter_context(
                    patch(
                        "torch._inductor.codecache.FxGraphCache._check_can_cache",
                        _check_can_cache))

            compiled_graph = compile_fx(graph,
                                        example_inputs,
                                        config_patches=current_config)
        # store the inductor_artifact in the cache
        cache_data[(runtime_shape, graph_index)] = inductor_artifact
        if graph_index == 0:
            # adds some info logging for the first graph
            logger.info("Cache the graph of shape %s for later use",
                        str(runtime_shape))
        logger.debug(
            "store the %s-th graph for shape %s via hash %s from file %s",
            graph_index, str(runtime_shape), inductor_artifact.hash_str,
            inductor_artifact.file_path)
    # after compiling the last graph, record the end time
    if graph_index == num_graphs - 1:
        now = time.time()
        elapsed = now - compilation_start_time
        compilation_config.compilation_time += elapsed
        if runtime_shape is None:
            logger.info("Compiling a graph for general shape takes %.2f s",
                        elapsed)
        else:
            logger.info("Compiling a graph for shape %s takes %.2f s",
                        runtime_shape, elapsed)

    return compiled_graph
330
+
331
+
332
@dataclasses.dataclass
class SplitItem:
    """One submodule produced by splitting an FX graph.

    Attributes:
        submod_name: attribute name of the submodule on the parent module.
        graph_id: integer id parsed from the submodule name; defines
            execution order.
        is_splitting_graph: True if this submodule holds a splitting op.
        graph: the submodule itself.
    """
    submod_name: str
    graph_id: int
    is_splitting_graph: bool
    graph: fx.GraphModule
338
+
339
+
340
def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
    """Split ``graph`` at every call to one of ``ops``.

    Each splitting op is isolated into its own single-op submodule; the
    nodes between them form the piecewise submodules. Returns the parent
    module that stitches the pieces together, plus the SplitItems sorted
    by integer graph id.
    """
    # Assign every node a subgraph id; each splitting op gets a fresh,
    # dedicated id of its own.
    current_id = 0
    assignment = {}
    splitting_ids = []
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in ops:
            current_id += 1
            assignment[node] = current_id
            splitting_ids.append(current_id)
            current_id += 1
        else:
            assignment[node] = current_id

    # `keep_original_order` is important! Otherwise pytorch might reorder
    # the nodes, changing the semantics of graphs that contain mutation.
    split_gm = torch.fx.passes.split_module.split_module(
        graph,
        None,
        lambda node: assignment[node],
        keep_original_order=True)

    items = []
    for name, module in split_gm.named_modules():
        if "." in name or name == "":
            # skip recursive children and the root module itself
            continue
        graph_id = int(name.replace("submod_", ""))
        items.append(
            SplitItem(name, graph_id, graph_id in splitting_ids, module))

    # sort by integer graph_id rather than by string name
    items.sort(key=lambda item: item.graph_id)

    return split_gm, items
386
+
387
+
388
# A single CUDA graph memory pool shared by every backend instance;
# created lazily on first VllmBackend construction.
global_graph_pool = None

# Wall-clock timestamp of when the current compilation round started
# (set by wrap_inductor when it compiles graph index 0).
compilation_start_time = 0.0
392
+
393
+
394
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.

    Runs the module with fake tensors and, for every submodule listed in
    `compile_submod_names`, compiles it for the general (symbolic) shape
    and installs a `PiecewiseBackend` in its place.

    NOTE: the order in `compile_submod_names` matters, because it
    determines the order of the compiled piecewise graphs. The first
    graph handles logging, and the last graph has some special cudagraph
    output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str], vllm_config: VllmConfig,
                 graph_pool, vllm_backend: "VllmBackend"):
        super().__init__(module)
        from torch._guards import detect_fake_mode
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend

    def run(self, *args):
        # Swap real tensors for fake ones before interpreting the graph.
        fake_args = [
            self.fake_mode.from_tensor(a) if isinstance(a, torch.Tensor)
            else a for a in args
        ]
        with self.fake_mode:
            return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument,
                                ...], kwargs: Dict[str, Any]) -> Any:
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of the symbolic (batch-size) arguments
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            global compilation_start_time
            compiled_graph_for_general_shape = wrap_inductor(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                self.vllm_backend,
                graph_index=index,
                num_graphs=len(self.compile_submod_names),
                runtime_shape=None,
                use_inductor=self.compilation_config.use_inductor)

            # Replace the submodule with a PiecewiseBackend that manages
            # shape-specific compilation and cudagraph capture at runtime.
            self.module.__dict__[target] = PiecewiseBackend(
                submod, self.vllm_config, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape, self.vllm_backend)

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output
458
+
459
+
460
class VllmBackend:
    """The compilation backend for `torch.compile` with VLLM.
    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: List[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: List[int]
    input_buffers: List[torch.Tensor]
    inductor_hash_cache: InductorHashCache

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = torch.cuda.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        config = self.compilation_config
        self.post_grad_pass_manager.configure(config.pass_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            # Config should automatically wrap all inductor passes
            assert isinstance(inductor_config[PASS_KEY], InductorPass)
            self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # No cache dir was provided: derive one from the factors that
            # affect compilation, so identical setups share the same dir
            # and can reuse the compiled graph.

            # 1. factors from the vllm_config (mainly summarizes how the
            # model is created)
            config_hash = vllm_config.compute_hash()

            # 2. factors from the code files traced by Dynamo (mainly
            # summarizes how the model is used in the forward pass)
            forward_code_files = list(
                sorted(self.compilation_config.traced_files))
            self.compilation_config.traced_files.clear()
            logger.debug(
                "Traced files (to be considered for compilation cache):\n%s",
                "\n".join(forward_code_files))
            hash_content = []
            for filepath in forward_code_files:
                hash_content.append(filepath)
                with open(filepath) as f:
                    hash_content.append(f.read())
            import hashlib
            code_hash = hashlib.md5(
                "\n".join(hash_content).encode()).hexdigest()

            # combine the two hashes to generate the cache dir
            hash_key = hashlib.md5(
                f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        local_cache_dir = os.path.join(
            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
        self.compilation_config.local_cache_dir = local_cache_dir

        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
            local_cache_dir, disabled=disabled)
        if disabled:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info("Using cache directory: %s for vLLM's torch.compile",
                        local_cache_dir)

        # When dynamo calls the backend, the bytecode transform and
        # analysis are already done.
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time
        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        # We control the compilation process; each instance can only be
        # called once.
        assert not self._called, "VllmBackend can only be called once"

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_config.splitting_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # Propagate the split graph to the piecewise backend, compiling
        # submodules with symbolic shapes.
        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
                                    self.vllm_config, self.graph_pool,
                                    self).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
            # use `print_readable` because it can include submodules
            src = "from __future__ import annotations\nimport torch\n" + \
                self.split_gm.print_readable(print_output=False)
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

        logger.debug("Computation graph saved to %s", graph_path)

        self._called = True

        if not self.compilation_config.use_cudagraph or \
                not self.compilation_config.cudagraph_copy_inputs:
            return self.split_gm

        # If we reach here, we need to copy input buffers for cudagraph.
        from torch._guards import detect_fake_mode
        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # Indices of tensors with symbolic shapes (batch size): weights
        # and static buffers have concrete shapes, so symbolic shapes
        # only occur on input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic
        self.sym_tensor_indices = [
            i for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
                any(is_symbolic(d) for d in x.size())
        ]

        # Compiler-managed cudagraph input buffers. We assume the first
        # run with symbolic shapes has the maximum size among all tensors.
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call
679
+
680
+
681
@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-shape bookkeeping used by PiecewiseBackend.

    One entry exists for every concrete runtime shape (batch size) that must
    be either compiled for specifically and/or captured as a cudagraph. The
    entry records which of those two steps are required, which have already
    happened, and the artifacts they produced.
    """

    # The concrete batch size this entry tracks.
    runtime_shape: int
    # True when this shape appears in compilation_config.compile_sizes.
    need_to_compile: bool
    # True when this shape appears in compilation_config.cudagraph_capture_sizes.
    use_cudagraph: bool

    # Set once the shape-specialized inductor compilation has been done.
    compiled: bool = False
    # Callable actually executed for this shape; starts as None and is filled
    # in lazily (general-shape graph first, then the specialized one).
    runnable: Callable = dataclasses.field(default=None)  # type: ignore
    # Number of warmup invocations performed before cudagraph capture.
    num_finished_warmup: int = 0
    # The captured cudagraph (None until capture happens).
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    # Weak-ref'd output tensors returned on every replay.
    output: Optional[Any] = None

    # Input data_ptr()s recorded at capture time; compared against the
    # replay-time addresses when debugging cudagraphs.
    input_addresses: Optional[List[int]] = None
696
+
697
+
698
class PiecewiseBackend:
    """Runtime backend for one piece of a piecewise-compiled model.

    Handles (a) shape-specialized inductor compilation for the sizes in
    `compilation_config.compile_sizes` and (b) cudagraph capture/replay for
    the sizes in `compilation_config.cudagraph_capture_sizes`. Both are done
    lazily, keyed off the runtime batch size observed on each call.
    """

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        # First/last piece get special treatment: logging and gc are done
        # only for the first piece; cache-saving and output weak-ref'ing
        # only for the last.
        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        self.cudagraph_capture_sizes: Set[int] = set(
            self.compilation_config.cudagraph_capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.cudagraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        """Persist the inductor hash cache and stop the compile monitor once
        the last piece has no sizes left to compile."""
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.inductor_hash_cache.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Dispatch one forward call of this piece.

        Order of events per shape: general-shape graph -> (optional)
        shape-specialized compile -> (optional) warmups -> cudagraph capture
        -> cudagraph replay on subsequent calls.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        # NOTE(review): assumes the first symbolic-shape argument carries the
        # runtime batch size — relies on how the caller builds
        # sym_shape_indices; confirm against VllmBackend.
        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = wrap_inductor(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                self.vllm_backend,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape,
                use_inductor=self.compilation_config.use_inductor)

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a cudagraph for shape %s",
                             runtime_shape)

            # Record capture-time input addresses so replay-time mismatches
            # can be detected in debugging mode below.
            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            # NOTE: attribute name intentionally matches the misspelled
            # counter field in vllm.compilation.counter.
            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output
.venv/lib/python3.11/site-packages/vllm/compilation/counter.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import copy
4
+ import dataclasses
5
+ from contextlib import contextmanager
6
+
7
+
8
@dataclasses.dataclass
class CompilationCounter:
    """Process-wide counters for torch.compile-related work.

    Other compilation modules increment these fields directly; tests use
    `expect` to assert that a code region changes the counters by exact
    deltas.
    """
    num_models_seen: int = 0
    num_graphs_seen: int = 0
    # including the splitting ops
    num_piecewise_graphs_seen: int = 0
    # not including the splitting ops
    num_piecewise_capturable_graphs_seen: int = 0
    num_inductor_compilations: int = 0
    # NOTE: the field name keeps the historical misspelling ("caputured")
    # because other modules read/write the counter under this exact name.
    # Use the correctly-spelled `num_cudagraph_captured` property for new
    # code; it aliases this field.
    num_cudagraph_caputured: int = 0

    @property
    def num_cudagraph_captured(self) -> int:
        """Correctly-spelled alias for `num_cudagraph_caputured`."""
        return self.num_cudagraph_caputured

    @num_cudagraph_captured.setter
    def num_cudagraph_captured(self, value: int) -> None:
        self.num_cudagraph_caputured = value

    def clone(self) -> "CompilationCounter":
        """Return a deep-copied snapshot of the current counter values."""
        return copy.deepcopy(self)

    @contextmanager
    def expect(self, **kwargs):
        """Assert that each named counter changes by exactly the given delta
        while the `with` body runs.

        :param kwargs: mapping of counter attribute name -> expected delta.
        :raises AssertionError: if any observed delta differs.
        """
        old = self.clone()
        yield
        for k, v in kwargs.items():
            assert getattr(self, k) - getattr(old, k) == v, (
                f"{k} not as expected, before it is {getattr(old, k)}"
                f", after it is {getattr(self, k)}, "
                f"expected diff is {v}")


# Module-level singleton shared by the whole process.
compilation_counter = CompilationCounter()
.venv/lib/python3.11/site-packages/vllm/compilation/decorators.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import inspect
4
+ from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
5
+ from unittest.mock import patch
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
10
+
11
+ from vllm.compilation.counter import compilation_counter
12
+ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
13
+ from vllm.config import CompilationLevel, VllmConfig
14
+ from vllm.logger import init_logger
15
+ from vllm.sequence import IntermediateTensors
16
+ from vllm.utils import supports_dynamo
17
+
18
+ from .monitor import start_monitoring_torch_compile
19
+
20
+ logger = init_logger(__name__)
21
+
22
_T = TypeVar("_T", bound=type[nn.Module])


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
) -> Callable[[_T], _T]:
    ...


@overload
def support_torch_compile(cls: _T) -> _T:
    ...


def support_torch_compile(
    cls: Optional[_T] = None,
    *,
    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
    """
    Class decorator that makes a module's ``forward`` run under
    ``torch.compile``.

    It can be applied bare (``@support_torch_compile``) or with an explicit
    mapping from argument names to their dynamic dimensions
    (``@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})``). A
    dimension may be a single (possibly negative) int or a list of ints.

    When ``dynamic_arg_dims`` is omitted, it is inferred from the type
    annotations of ``forward``: arguments annotated ``torch.Tensor``,
    ``Optional[torch.Tensor]``, ``IntermediateTensors`` or
    ``Optional[IntermediateTensors]`` get their first dimension marked
    dynamic. For ``IntermediateTensors`` values, every contained tensor is
    marked.

    NOTE: an argument passed as ``None`` must stay ``None`` for the lifetime
    of the model, otherwise a single captured computation graph cannot cover
    all calls.
    """

    def _decorate(target_cls: _T) -> _T:
        # Validate the class and resolve the dynamic-dim mapping, then
        # delegate the actual wrapping to `_support_torch_compile`.
        if not hasattr(target_cls, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(target_cls.forward)

        resolved_dims = dynamic_arg_dims
        if resolved_dims is None:
            # Infer the dynamic dims from the forward() annotations.
            dynamic_annotations = [
                torch.Tensor, Optional[torch.Tensor], IntermediateTensors,
                Optional[IntermediateTensors]
            ]
            resolved_dims = {
                name: 0
                for name, param in sig.parameters.items()
                if param.annotation in dynamic_annotations
            }
            logger.debug(("Inferred dynamic dimensions for "
                          "forward method of %s: %s"), target_cls,
                         list(resolved_dims.keys()))

        if len(resolved_dims) == 0:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{target_cls}. Please provide dynamic_arg_dims explicitly.")

        for name in resolved_dims:
            if name not in sig.parameters:
                raise ValueError(
                    f"Argument {name} not found in the forward method of "
                    f"{target_cls}")
        return _support_torch_compile(target_cls, resolved_dims)

    # Bare usage: `cls` is the decorated class itself.
    if cls is not None:
        assert isinstance(cls, type)
        return _decorate(cls)

    # Parameterized usage: hand back the decorator to be applied later.
    return _decorate
130
+
131
+
132
def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: Dict[str, Union[int, List[int]]],
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Implementation: injects TorchCompileWrapperWithCustomDispatcher into the
    class's bases, wraps __init__ to decide per-config whether to compile at
    all, and replaces __call__ so the first invocation triggers Dynamo
    tracing (with dynamic dims marked) while later calls dispatch straight
    to the already-compiled code.

    :param cls: the class whose `forward` should be compiled
    :param dynamic_arg_dims: mapping of forward-argument name to the
        dimension(s) to mark dynamic (ints, possibly negative)
    :return: the same class object, modified in place
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
        # Run the original constructor first, then decide whether this
        # instance should be compiled at all.
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.vllm_config = vllm_config
        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = \
            vllm_config.compilation_config.level in [
                CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
            ] or not supports_dynamo()
        if self.do_not_compile:
            return
        compilation_counter.num_models_seen += 1
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_level=vllm_config.compilation_config.level)

    cls.__init__ = __init__

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is not None:
                    dims = [dims] if isinstance(dims, int) else dims
                    if isinstance(arg, torch.Tensor):
                        # In case dims is specified with negative indexing
                        dims = [
                            arg.ndim + dim if dim < 0 else dim for dim in dims
                        ]
                        torch._dynamo.mark_dynamic(arg, dims)
                    elif isinstance(arg, IntermediateTensors):
                        for tensor in arg.tensors.values():
                            # In case dims is specified with negative indexing
                            dims = [
                                tensor.ndim + dim if dim < 0 else dim
                                for dim in dims
                            ]
                            torch._dynamo.mark_dynamic(tensor, dims)
                    else:
                        raise ValueError(
                            "Unsupported dynamic dimensions"
                            f" {dims} for argument {k} with type {type(arg)}.")
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s",
                         self.original_code_object)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuse the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(
                self.original_code_object)

            # collect all relevant files traced by Dynamo,
            # so that the compilation cache can trigger re-compilation
            # properly when any of these files change.

            # 1. the file containing the top-level forward function
            self.vllm_config.compilation_config.traced_files.add(
                self.original_code_object.co_filename)

            # 2. every time Dynamo sees a function call, it will inline
            # the function by calling InliningInstructionTranslator.inline_call
            # we hijack this function to know all the functions called
            # during Dynamo tracing, and their corresponding files
            inline_call = InliningInstructionTranslator.inline_call

            def patched_inline_call(parent, func, args, kwargs):
                code = func.get_code()
                self.vllm_config.compilation_config.traced_files.add(
                    code.co_filename)
                return inline_call(parent, func, args, kwargs)

            with patch.object(InliningInstructionTranslator, 'inline_call',
                              patched_inline_call):
                output = self.compiled_callable(*args, **kwargs)
            return output

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__
    return cls
.venv/lib/python3.11/site-packages/vllm/compilation/fix_functionalization.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import operator
4
+ from typing import Dict, Iterable, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from torch._higher_order_ops.auto_functionalize import auto_functionalized
8
+
9
+ from vllm.logger import init_logger
10
+
11
+ from .fx_utils import is_func
12
+ from .vllm_inductor_pass import VllmInductorPass
13
+
14
+ logger = init_logger(__name__)
15
+
16
+
17
class FixFunctionalizationPass(VllmInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.

    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

    def __call__(self, graph: torch.fx.Graph):
        """Walk the graph, replacing known auto_functionalized custom ops
        with direct in-place calls, then erase the staged nodes."""
        self.begin()
        self.dump_graph(graph, "before_fix_functionalization")

        self.nodes_to_remove: List[torch.fx.Node] = []
        count = 0
        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue  # Avoid deep if-elif nesting

            kwargs = node.kwargs
            at_target = node.args[0]

            if at_target == torch.ops._C.rotary_embedding.default:
                query = kwargs['query']
                # NOTE(review): assumes query is a reshaped slice of a matmul
                # output (two levels of args indirection) — confirm against
                # the pattern this pass targets.
                mm_node = query.args[0].args[0]

                # rotary_embedding is a special case: the two mutating inputs
                # are query and key, which are slices of mm_node.
                # While functionalized, results at[1] and at[2] are scattered
                # back into mm_node. After de-functionalization, we can just
                # use mm_node directly.
                for idx, user in self.getitem_users(node).items():
                    for user_of_getitem in user.users:
                        if is_func(user_of_getitem,
                                   torch.ops.aten.slice_scatter.default):
                            user_of_getitem.replace_all_uses_with(mm_node)
                            self._remove(user_of_getitem)
                    self._remove(user)

                self.insert_defunctionalized(graph, node)
                self._remove(node)

            # rms_norm replacements avoid the most copies for LLaMa.
            elif at_target == torch.ops._C.fused_add_rms_norm.default:
                mutated_args = {1: 'input', 2: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
                mutated_args = {1: 'result', 2: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default:  # noqa: E501
                mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target in [
                    torch.ops._C.rms_norm.default,
                    torch.ops._C.rms_norm_static_fp8_quant.default
            ]:
                mutated_args = {1: 'result'}
                self.defunctionalize(graph, node, mutated_args)

            elif at_target == torch.ops._C.silu_and_mul.default:
                mutated_args = {1: 'out'}
                # Because we have an 'out', need to specify args directly
                self.defunctionalize(graph,
                                     node,
                                     mutated_args,
                                     args=('out', 'input'))
            else:
                continue  # skip the count

            count += 1

        self.dump_graph(graph, "before_fix_functionalization_cleanup")

        # Remove the nodes all at once
        count_removed = len(self.nodes_to_remove)
        for node in self.nodes_to_remove:
            graph.erase_node(node)

        logger.debug("De-functionalized %s nodes, removed %s nodes", count,
                     count_removed)
        self.dump_graph(graph, "after_fix_functionalization")
        self.end_and_log()

    def _remove(self, node_or_nodes: Union[torch.fx.Node,
                                           Iterable[torch.fx.Node]]):
        """
        Stage a node (or nodes) for removal at the end of the pass.
        """
        if isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.append(node_or_nodes)
        else:
            self.nodes_to_remove.extend(node_or_nodes)

    def defunctionalize(self,
                        graph: torch.fx.Graph,
                        node: torch.fx.Node,
                        mutated_args: Dict[int, Union[torch.fx.Node, str]],
                        args: Optional[Tuple[Union[torch.fx.Node, str],
                                             ...]] = None):
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.
        """
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(self, node: torch.fx.Node,
                                        mutated_args: Dict[int,
                                                           Union[torch.fx.Node,
                                                                 str]]):
        """
        Replace all getitem users of the auto-functionalized node with the
        mutated arguments.
        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
        If the value of an arg is a string, `node.kwargs[arg]` is used.
        """
        for idx, user in self.getitem_users(node).items():
            arg = mutated_args[idx]
            arg = node.kwargs[arg] if isinstance(arg, str) else arg
            user.replace_all_uses_with(arg)
            self._remove(user)

    def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.
        """
        users = {}
        for user in node.users:
            if is_func(user, operator.getitem):
                idx = user.args[1]
                users[idx] = user
        return users

    def insert_defunctionalized(self,
                                graph: torch.fx.Graph,
                                node: torch.fx.Node,
                                args: Optional[Tuple[Union[torch.fx.Node, str],
                                                     ...]] = None):
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
        If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(node, auto_functionalized), \
            f"node must be auto-functionalized, is {node} instead"

        # Create a new call to the original function
        with graph.inserting_before(node):
            function = node.args[0]
            if args is None:
                graph.call_function(function, kwargs=node.kwargs)
            else:
                # Args passed as strings refer to items in node.kwargs
                args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg
                             for arg in args)
                graph.call_function(function, args=args)
.venv/lib/python3.11/site-packages/vllm/compilation/fusion.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
4
+
5
+ import torch
6
+ import torch._inductor.pattern_matcher as pm
7
+ # TODO(luka) use vllm.utils once #10836 landed
8
+ from compressed_tensors.quantization import FP8_DTYPE
9
+ from torch import fx
10
+ from torch._higher_order_ops.auto_functionalize import auto_functionalized
11
+ from torch._inductor.pattern_matcher import PatternMatcherPass
12
+ from torch._ops import OpOverload
13
+
14
+ from vllm.config import CompilationConfig
15
+ from vllm.logger import init_logger
16
+
17
+ from .fx_utils import find_getitem_maybe
18
+ from .multi_output_match import MultiOutputMatch
19
+ from .vllm_inductor_pass import VllmInductorPass
20
+
21
+ logger = init_logger(__name__)
22
+
23
+
24
def empty_bf16(*args, **kwargs):
    """Allocate an uninitialized bfloat16 tensor on the CUDA device.

    Thin wrapper around ``torch.empty`` used when tracing fusion patterns.
    """
    return torch.empty(*args, dtype=torch.bfloat16, device="cuda", **kwargs)
26
+
27
+
28
def empty_fp32(*args, **kwargs):
    """Allocate an uninitialized float32 tensor on the CUDA device.

    Thin wrapper around ``torch.empty`` used when tracing fusion patterns.
    """
    return torch.empty(*args, dtype=torch.float32, device="cuda", **kwargs)
30
+
31
+
32
# Unfused RMSNorm custom ops that the fusion patterns below match against.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
34
+
35
+
36
class QuantKey(NamedTuple):
    """
    Named tuple for identifying the type of quantization.
    dtype: quantized data type
    static: static quantization if True, dynamic if False
    per_tensor: per-tensor quantization if True, per-token if False
    symmetric: symmetric if True, asymmetric if False
    """
    dtype: torch.dtype
    static: bool
    per_tensor: bool = True
    symmetric: bool = True

    def __str__(self):
        # Spell out each flag, then assemble the display string.
        schedule = "static" if self.static else "dynamic"
        granularity = "per_tensor" if self.per_tensor else "per_token"
        symmetry = "symmetric" if self.symmetric else "asymmetric"
        dtype_abbr = fx.graph.dtype_abbrs[self.dtype]
        return f"QuantKey({schedule},{dtype_abbr},{granularity},{symmetry})"
54
+
55
+
56
# Pre-built keys for the FP8 quantization schemes recognized by this pass.
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)

# Maps each quantization scheme to the (in-place) custom quant op that the
# unfused patterns call.
QUANT_OPS: Dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa
    kFp8DynamicTensorSym:
    torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa
    kFp8DynamicTokenSym:
    torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa
}
67
+
68
+
69
class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple for identifying the type of RMSNorm + quant fusion.
    quant: type of quantization
    fused_add: does the op also perform the residual add
    """
    quant: QuantKey
    fused_add: bool

    def __str__(self):
        residual_suffix = "" if self.fused_add else "out"
        return f"FusedQuantKey({self.quant}, with{residual_suffix} residual)"
81
+
82
+
83
# Maps each fusion key to the fused custom op that replaces the matched
# RMSNorm + quant pair. Note both the fused-add and plain dynamic per-token
# keys map to the same overload — presumably that op takes an optional
# residual argument covering both cases; verify against the _C op schema.
FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(kFp8StaticTensorSym, False):
    torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa
    FusedRMSQuantKey(kFp8StaticTensorSym, True):
    torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa
    FusedRMSQuantKey(kFp8DynamicTokenSym, False):
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa
    FusedRMSQuantKey(kFp8DynamicTokenSym, True):
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa
}
93
+
94
+
95
class QuantMultiOutputMatch(MultiOutputMatch):
    """
    Base class for multi-output pattern matches that replace an unfused
    op + in-place quant op pair with a single fused in-place op.
    Subclasses call insert_fused_node to splice the fused node into the
    graph and rebind the users of the old getitem results.
    """

    def __init__(self, match: pm.Match, quant_op, fused_op):
        super().__init__(match)
        assert isinstance(quant_op, OpOverload)
        assert isinstance(fused_op, OpOverload)
        self.QUANT_OP = quant_op  # in-place quant op
        self.FUSED_OP = fused_op  # in-place fused quant op

    def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
                                                                      int]],
                          **kwargs):
        """
        This utility function inserts an auto-functionalized node for FUSED_OP.
        It also correctly sets its meta value and rebinds the users of the
        unfused nodes to use the fused node instead.

        :param fused_return_mapping: A dictionary, mapping from getitem indices
        of the fused node result to a tuple of the old node and a getitem index.
        :param kwargs: kwargs that get directly forwarded to the auto_fn node

        Example:
        If we want to replace this graph:
        _, x1, x2 = auto_fn(op1)
        _, y1, y2 = auto_fn(op2)

        with
        _, x1, y2, x2 = auto_fn(FUSED_OP)

        we would call:
        insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)}

        Note that the 0th element is None for auto-functionalized in-place ops.
        Hence, others appear 1-indexed.
        """
        fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
        indices = fused_return_mapping.keys()
        getitem_nodes = self.insert_getitems(fused_node, indices)

        # Prepare the meta value, use a list so it's mutable
        meta_val = [None] * (max(indices) + 1)

        # Iterate through elements of the tuple produced by fused_node
        for idx, getitem_node in zip(indices, getitem_nodes):
            old_node, old_idx = fused_return_mapping[idx]

            # If the old value was never used, the old_getitem might not exist
            old_getitem = find_getitem_maybe(old_node, old_idx)
            if old_getitem is not None:
                # Rebind the users of match getitem nodes to use the new nodes.
                # The old nodes will be removed by DCE at the end of the pass.
                old_getitem.replace_all_uses_with(getitem_node)
                getitem_node.meta["val"] = old_getitem.meta["val"]

            # Extract the appropriate meta value
            # It is present even if the getitem node does not exist
            meta_val[idx] = old_node.meta["val"][old_idx]

        # Fix the meta value on the new fused node
        fused_node.meta["val"] = tuple(meta_val)
155
+
156
+
157
+ class RMSNormQuantPattern:
158
+
159
+ def __init__(self, epsilon: float, key: FusedRMSQuantKey):
160
+ self.epsilon = epsilon
161
+ self.quant_dtype = key.quant.dtype
162
+
163
+ assert key.quant in QUANT_OPS, \
164
+ f"unsupported quantization scheme {key.quant}"
165
+ self.QUANT_OP = QUANT_OPS[key.quant]
166
+
167
+ assert key in FUSED_OPS, \
168
+ f"unsupported fused rmsnorm+quant op for {key}"
169
+ self.FUSED_OP = FUSED_OPS[key]
170
+
171
+
172
+ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
173
+
174
+ def __init__(self,
175
+ epsilon: float,
176
+ quant_dtype: torch.dtype,
177
+ symmetric=True):
178
+ fused_key = FusedRMSQuantKey(fused_add=False,
179
+ quant=QuantKey(dtype=quant_dtype,
180
+ static=True,
181
+ per_tensor=True,
182
+ symmetric=symmetric))
183
+ super().__init__(epsilon, fused_key)
184
+
185
+ def register(self, pm_pass: PatternMatcherPass):
186
+ # Cannot use methods, as the self argument affects tracing
187
+ def pattern(result: torch.Tensor, result_rms: torch.Tensor,
188
+ input: torch.Tensor, weight: torch.Tensor,
189
+ scale: torch.Tensor):
190
+ at1 = auto_functionalized(RMS_OP,
191
+ result=result_rms,
192
+ input=input,
193
+ weight=weight,
194
+ epsilon=self.epsilon)
195
+ at2 = auto_functionalized(self.QUANT_OP,
196
+ result=result,
197
+ input=at1[1],
198
+ scale=scale)
199
+
200
+ # result
201
+ return at2[1]
202
+
203
+ def replacement(result: torch.Tensor, result_rms: torch.Tensor,
204
+ input: torch.Tensor, weight: torch.Tensor,
205
+ scale: torch.Tensor):
206
+ at = auto_functionalized(self.FUSED_OP,
207
+ result=result,
208
+ input=input,
209
+ weight=weight,
210
+ scale=scale,
211
+ epsilon=self.epsilon)
212
+
213
+ # result
214
+ return at[1]
215
+
216
+ inputs = [
217
+ torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
218
+ empty_bf16(5, 4), # result_rms
219
+ empty_bf16(5, 4), # input
220
+ empty_bf16(1, 5), # weight
221
+ empty_fp32(1, 1) # scale
222
+ ]
223
+
224
+ pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
225
+ pm_pass)
226
+
227
+
228
+ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
229
+
230
+ def __init__(self,
231
+ epsilon: float,
232
+ quant_dtype: torch.dtype,
233
+ symmetric=True):
234
+ key = FusedRMSQuantKey(fused_add=True,
235
+ quant=QuantKey(dtype=quant_dtype,
236
+ static=True,
237
+ per_tensor=True,
238
+ symmetric=symmetric))
239
+ super().__init__(epsilon, key)
240
+
241
+ def register(self, pm_pass: PatternMatcherPass,
242
+ record_match: Callable[[MultiOutputMatch], bool]):
243
+
244
+ def pattern(result: torch.Tensor, input: torch.Tensor,
245
+ residual: torch.Tensor, weight: torch.Tensor,
246
+ scale: torch.Tensor):
247
+ at = auto_functionalized(RMS_ADD_OP,
248
+ input=input,
249
+ residual=residual,
250
+ weight=weight,
251
+ epsilon=self.epsilon)
252
+ at1 = auto_functionalized(self.QUANT_OP,
253
+ result=result,
254
+ input=at[1],
255
+ scale=scale)
256
+
257
+ # result, residual
258
+ return at1[1], at[2]
259
+
260
+ def replacement(result: torch.Tensor, input: torch.Tensor,
261
+ residual: torch.Tensor, weight: torch.Tensor,
262
+ scale: torch.Tensor):
263
+ at = auto_functionalized(self.FUSED_OP,
264
+ result=result,
265
+ input=input,
266
+ residual=residual,
267
+ weight=weight,
268
+ scale=scale,
269
+ epsilon=self.epsilon)
270
+
271
+ # result, residual
272
+ return at[1], at[2]
273
+
274
+ inputs = [
275
+ torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
276
+ empty_bf16(5, 4), # input
277
+ empty_bf16(5, 4), # residual
278
+ empty_bf16(1, 5), # weight
279
+ empty_fp32(1, 1) # scale
280
+ ]
281
+
282
+ pm.register_replacement(
283
+ pattern,
284
+ replacement,
285
+ inputs,
286
+ pm.fwd_only,
287
+ pm_pass,
288
+ extra_check=lambda m: record_match(
289
+ self.Match(m, self.QUANT_OP, self.FUSED_OP)))
290
+
291
+ class Match(QuantMultiOutputMatch):
292
+
293
+ def process(self):
294
+ # Find the nodes in the match that we need to rebind
295
+ rms_node = self.find_auto_fn(RMS_ADD_OP)
296
+ quant_node = self.find_auto_fn(self.QUANT_OP)
297
+
298
+ assert len(rms_node.users) == 2
299
+ assert len(quant_node.users) == 1
300
+
301
+ # First, insert a new auto_functionalized node for the fused op,
302
+ # as well as getitem nodes to extract the result and residual.
303
+ # The auto_fn node returns a tuple of (None, result, residual).
304
+ #
305
+ # The resulting graph looks like this:
306
+ # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
307
+ # result_node_new = at[1]
308
+ # residual_node_new = at[2]
309
+ with self.inserting_after_match():
310
+ # Missing epsilon, scalars cannot be inputs to the pattern
311
+ kwargs = self.match.kwargs.copy()
312
+
313
+ # 0 is always None
314
+ fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
315
+ self.insert_fused_node(fused_return_mapping,
316
+ epsilon=rms_node.kwargs["epsilon"],
317
+ **kwargs)
318
+
319
+
320
+ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
321
+
322
+ def __init__(self,
323
+ epsilon: float,
324
+ quant_dtype: torch.dtype,
325
+ per_tensor: bool,
326
+ symmetric=True):
327
+ key = FusedRMSQuantKey(fused_add=False,
328
+ quant=QuantKey(dtype=quant_dtype,
329
+ static=False,
330
+ per_tensor=per_tensor,
331
+ symmetric=symmetric))
332
+ super().__init__(epsilon, key)
333
+
334
+ def register(self, pm_pass: PatternMatcherPass,
335
+ record_match: Callable[[MultiOutputMatch], bool]):
336
+
337
+ def pattern(result: torch.Tensor, result_rms: torch.Tensor,
338
+ input: torch.Tensor, weight: torch.Tensor,
339
+ scale: torch.Tensor):
340
+ at1 = auto_functionalized(RMS_OP,
341
+ result=result_rms,
342
+ input=input,
343
+ weight=weight,
344
+ epsilon=self.epsilon)
345
+ at2 = auto_functionalized(self.QUANT_OP,
346
+ result=result,
347
+ input=at1[1],
348
+ scale=scale,
349
+ scale_ub=None)
350
+
351
+ # result, scale
352
+ return at2[1], at2[2]
353
+
354
+ def replacement(result: torch.Tensor, result_rms: torch.Tensor,
355
+ input: torch.Tensor, weight: torch.Tensor,
356
+ scale: torch.Tensor):
357
+ at = auto_functionalized(self.FUSED_OP,
358
+ result=result,
359
+ input=input,
360
+ weight=weight,
361
+ scale=scale,
362
+ epsilon=self.epsilon,
363
+ scale_ub=None,
364
+ residual=None)
365
+
366
+ # result, scale
367
+ return at[1], at[2]
368
+
369
+ inputs = [
370
+ torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
371
+ empty_bf16(5, 4), # result_rms
372
+ empty_bf16(5, 4), # input
373
+ empty_bf16(1, 5), # weight
374
+ empty_fp32(1, 1) # scale
375
+ ]
376
+
377
+ pm.register_replacement(
378
+ pattern,
379
+ replacement,
380
+ inputs,
381
+ pm.fwd_only,
382
+ pm_pass,
383
+ extra_check=lambda m: record_match(
384
+ self.Match(m, self.QUANT_OP, self.FUSED_OP)))
385
+
386
+ class Match(QuantMultiOutputMatch):
387
+
388
+ def process(self):
389
+ # Find the nodes in the match that we need to rebind
390
+ rms_node = self.find_auto_fn(RMS_OP)
391
+ quant_node = self.find_auto_fn(self.QUANT_OP)
392
+
393
+ assert len(rms_node.users) == 1
394
+ assert len(quant_node.users) == 2
395
+
396
+ # First, insert a new auto_functionalized node for the fused op,
397
+ # as well as getitem nodes to extract the result and scale.
398
+ # The auto_fn node returns a tuple of (None, result, scale).
399
+ #
400
+ # The resulting graph looks like this:
401
+ # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
402
+ # result_node_new = at[1]
403
+ # scale_node_new = at[2]
404
+ with self.inserting_after_match():
405
+ # Missing epsilon, scalars cannot be inputs to the pattern
406
+ kwargs = self.match.kwargs.copy()
407
+ del kwargs["result_rms"] # not used in the fused op
408
+
409
+ fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)}
410
+ self.insert_fused_node(
411
+ fused_return_mapping,
412
+ epsilon=rms_node.kwargs["epsilon"],
413
+ scale_ub=None, # not used but required
414
+ residual=None, # not used but required
415
+ **kwargs)
416
+
417
+
418
+ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
419
+
420
+ def __init__(self,
421
+ epsilon: float,
422
+ quant_dtype: torch.dtype,
423
+ per_tensor: bool = True,
424
+ symmetric=True):
425
+ key = FusedRMSQuantKey(fused_add=True,
426
+ quant=QuantKey(dtype=quant_dtype,
427
+ static=False,
428
+ per_tensor=per_tensor,
429
+ symmetric=symmetric))
430
+ super().__init__(epsilon, key)
431
+
432
+ def register(self, pm_pass: PatternMatcherPass,
433
+ record_match: Callable[[MultiOutputMatch], bool]):
434
+
435
+ def pattern(result: torch.Tensor, input: torch.Tensor,
436
+ residual: torch.Tensor, weight: torch.Tensor,
437
+ scale: torch.Tensor):
438
+ at = auto_functionalized(RMS_ADD_OP,
439
+ input=input,
440
+ residual=residual,
441
+ weight=weight,
442
+ epsilon=self.epsilon)
443
+ at1 = auto_functionalized(self.QUANT_OP,
444
+ result=result,
445
+ input=at[1],
446
+ scale=scale,
447
+ scale_ub=None)
448
+
449
+ # result, residual, scale
450
+ return at1[1], at[2], at1[2]
451
+
452
+ def replacement(result: torch.Tensor, input: torch.Tensor,
453
+ residual: torch.Tensor, weight: torch.Tensor,
454
+ scale: torch.Tensor):
455
+ at = auto_functionalized(self.FUSED_OP,
456
+ result=result,
457
+ input=input,
458
+ weight=weight,
459
+ scale=scale,
460
+ epsilon=self.epsilon,
461
+ scale_ub=None,
462
+ residual=residual)
463
+
464
+ # result, residual, scale
465
+ return at[1], at[3], at[2]
466
+
467
+ inputs = [
468
+ torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
469
+ empty_bf16(5, 4), # input
470
+ empty_bf16(5, 4), # residual
471
+ empty_bf16(1, 5), # weight
472
+ empty_fp32(1, 1) # scale
473
+ ]
474
+
475
+ pm.register_replacement(
476
+ pattern,
477
+ replacement,
478
+ inputs,
479
+ pm.fwd_only,
480
+ pm_pass,
481
+ extra_check=lambda m: record_match(
482
+ self.Match(m, self.QUANT_OP, self.FUSED_OP)))
483
+
484
+ class Match(QuantMultiOutputMatch):
485
+
486
+ def process(self):
487
+ # Find the nodes in the match that we need to rebind
488
+ rms_node = self.find_auto_fn(RMS_ADD_OP)
489
+ quant_node = self.find_auto_fn(self.QUANT_OP)
490
+
491
+ assert len(rms_node.users) == 2
492
+ assert len(quant_node.users) == 2
493
+
494
+ # First, insert a new auto_functionalized node for the fused op,
495
+ # as well as getitem nodes to extract result, scale, and residual.
496
+ # The auto_fn node returns a tuple (None, result, scale, residual).
497
+ #
498
+ # The resulting graph looks like this:
499
+ # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
500
+ # result_node_new = at[1]
501
+ # scale_node_new = at[2]
502
+ # residual_node_new = at[3]
503
+ with self.inserting_after_match():
504
+ # Missing epsilon, scalars cannot be inputs to the pattern
505
+ kwargs = self.match.kwargs.copy()
506
+
507
+ fused_return_mapping = {
508
+ 1: (quant_node, 1), # result
509
+ 2: (quant_node, 2), # scale
510
+ 3: (rms_node, 2), # residual
511
+ }
512
+ self.insert_fused_node(
513
+ fused_return_mapping,
514
+ epsilon=rms_node.kwargs["epsilon"],
515
+ scale_ub=None, # not used but required
516
+ **kwargs)
517
+
518
+
519
+ class FusionPass(VllmInductorPass):
520
+ """
521
+ This pass fuses a pre-defined set of custom ops into fused ops.
522
+ It uses the torch pattern matcher to find the patterns and replace them.
523
+ It also manually processes multi-output matches, as those are broken in
524
+ the torch pattern matcher.
525
+
526
+ Because patterns can only be registered once, the pass is a singleton.
527
+ This will be addressed in a future version of PyTorch:
528
+ https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
529
+ """
530
+
531
+ _instance: 'Optional[FusionPass]' = None
532
+
533
+ @classmethod
534
+ def instance(cls, config: CompilationConfig.PassConfig):
535
+ """
536
+ Get the singleton instance of the FusionPass.
537
+ If the instance exists, the config is updated but
538
+ initialization is not repeated.
539
+ """
540
+ if cls._instance is None:
541
+ cls._instance = FusionPass(config)
542
+ else:
543
+ cls._instance.config = config
544
+ return cls._instance
545
+
546
+ def __init__(self, config: CompilationConfig.PassConfig):
547
+ assert self.__class__._instance is None, \
548
+ "FusionPass singleton instance already exists"
549
+ super().__init__(config)
550
+
551
+ self.matches: List[MultiOutputMatch] = []
552
+ self.patterns: PatternMatcherPass = PatternMatcherPass(
553
+ pass_name="fusion_pass")
554
+
555
+ for epsilon in [1e-5, 1e-6]:
556
+ # Fuse rms_norm + static fp8 quant
557
+ RMSNormStaticQuantPattern(epsilon,
558
+ FP8_DTYPE).register(self.patterns)
559
+
560
+ # Matches for patterns below have 2 or more outputs,
561
+ # so we need to process them manually (see process_matches)
562
+
563
+ # Fuse rms_norm + static fp8 quant
564
+ FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
565
+ self.patterns, self.record_match)
566
+
567
+ # Fuse rms_norm + dynamic per-token fp8 quant
568
+ RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
569
+ per_tensor=False).register(
570
+ self.patterns, self.record_match)
571
+
572
+ # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
573
+ FusedAddRMSNormDynamicQuantPattern(epsilon,
574
+ FP8_DTYPE,
575
+ per_tensor=False).register(
576
+ self.patterns,
577
+ self.record_match)
578
+
579
+ # WARNING: This is a hack to clear the pattern matcher cache
580
+ # and allow multiple values of epsilon.
581
+ torch._inductor.pattern_matcher._seen_patterns.clear()
582
+
583
+ def record_match(self, match: MultiOutputMatch) -> bool:
584
+ # Hijack the extra_check to record the match and
585
+ # save it for post-processing.
586
+ self.matches.append(match)
587
+
588
+ # Return False to prevent automatic replacement.
589
+ return False
590
+
591
+ def process_matches(self, graph: fx.Graph):
592
+ """
593
+ Manually process multi-output matches and replace them with fused nodes.
594
+ See MultiOutputMatch for more details.
595
+ """
596
+ for match in self.matches:
597
+ match.process()
598
+
599
+ # Finally, remove matched nodes
600
+ graph.eliminate_dead_code()
601
+ assert all(node not in graph.nodes for match in self.matches
602
+ for node in match.match.nodes)
603
+
604
+ def __call__(self, graph: fx.Graph):
605
+ self.begin()
606
+ self.dump_graph(graph, "before_fusion")
607
+
608
+ count = self.patterns.apply(graph)
609
+ logger.debug("Replaced %s patterns", count)
610
+ self.dump_graph(graph, "after_pattern_match")
611
+
612
+ # Manually process multi-output matches (and run DCE)
613
+ self.process_matches(graph)
614
+ logger.debug("Post-processed %s matches", len(self.matches))
615
+ self.dump_graph(graph, "after_fusion")
616
+ self.matches.clear()
617
+ self.end_and_log()
.venv/lib/python3.11/site-packages/vllm/compilation/fx_utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import operator
4
+ from typing import Iterable, Optional
5
+
6
+ from torch import fx
7
+ from torch._higher_order_ops.auto_functionalize import auto_functionalized
8
+ from torch._ops import OpOverload
9
+
10
+
11
+ def is_func(node: fx.Node, target) -> bool:
12
+ return node.op == "call_function" and node.target == target
13
+
14
+
15
+ # Returns the first auto_functionalized node with the given op (if it exists)
16
+ def find_auto_fn_maybe(nodes: Iterable[fx.Node],
17
+ op: OpOverload) -> Optional[fx.Node]:
18
+ for node in nodes:
19
+ if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
20
+ return node
21
+ return None
22
+
23
+
24
+ # Returns the first auto_functionalized node with the given op
25
+ def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
26
+ node = find_auto_fn_maybe(nodes, op)
27
+ assert node is not None, f"Could not find {op} in nodes {nodes}"
28
+ return node
29
+
30
+
31
+ # Returns the getitem node that extracts the idx-th element from node
32
+ # (if it exists)
33
+ def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
34
+ for user in node.users:
35
+ if is_func(user, operator.getitem) and user.args[1] == idx:
36
+ return user
37
+ return None
38
+
39
+
40
+ # Returns the getitem node that extracts the idx-th element from node
41
+ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
42
+ ret = find_getitem_maybe(node, idx)
43
+ assert ret is not None, f"Could not find getitem {idx} in node {node}"
44
+ return ret
.venv/lib/python3.11/site-packages/vllm/compilation/monitor.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import os
4
+ import time
5
+
6
+ from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
7
+ from vllm.logger import init_logger
8
+
9
+ logger = init_logger(__name__)
10
+
11
+ context_manager = None
12
+ torch_compile_start_time: float = 0.0
13
+
14
+
15
+ def start_monitoring_torch_compile(vllm_config: VllmConfig):
16
+ global torch_compile_start_time
17
+ torch_compile_start_time = time.time()
18
+
19
+ compilation_config: CompilationConfig = vllm_config.compilation_config
20
+ if compilation_config.level == CompilationLevel.PIECEWISE and \
21
+ compilation_config.debug_dump_path:
22
+ import depyf
23
+ path = os.path.join(compilation_config.debug_dump_path,
24
+ f"rank_{vllm_config.parallel_config.rank}")
25
+ global context_manager
26
+ context_manager = depyf.prepare_debug(path)
27
+ context_manager.__enter__()
28
+
29
+
30
+ def end_monitoring_torch_compile(vllm_config: VllmConfig):
31
+ compilation_config: CompilationConfig = vllm_config.compilation_config
32
+ if compilation_config.level == CompilationLevel.PIECEWISE:
33
+ logger.info("torch.compile takes %.2f s in total",
34
+ compilation_config.compilation_time)
35
+ global context_manager
36
+ if context_manager is not None:
37
+ context_manager.__exit__(None, None, None)
38
+ context_manager = None
.venv/lib/python3.11/site-packages/vllm/compilation/pass_manager.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Any, Dict, List
4
+
5
+ from torch import fx as fx
6
+
7
+ from vllm.config import CompilationConfig
8
+ from vllm.logger import init_logger
9
+
10
+ from .fix_functionalization import FixFunctionalizationPass
11
+ from .fusion import FusionPass
12
+ from .inductor_pass import InductorPass
13
+ from .reshapes import RedundantReshapesPass
14
+
15
+ logger = init_logger(__name__)
16
+
17
+
18
+ class PostGradPassManager:
19
+ """
20
+ The pass manager for post-grad passes.
21
+ It handles configuration, adding custom passes, and running passes.
22
+ It also supports pickling, which is used by the Inductor code cache.
23
+ TODO(torch==2.6), use CustomGraphPass
24
+ (torch._inductor.custom_graph_pass.CustomGraphPass)
25
+
26
+ The order of the post-grad post-passes is:
27
+ 1. passes (constructor parameter)
28
+ 2. default passes (RedundantReshapesPass, FusionPass)
29
+ 3. config["post_grad_custom_post_pass"] (if it exists)
30
+ 4. fix_functionalization
31
+ This way, all passes operate on a functionalized graph.
32
+ """
33
+
34
+ def __init__(self):
35
+ self.passes: List[InductorPass] = []
36
+
37
+ def __call__(self, graph: fx.Graph):
38
+ for pass_ in self.passes:
39
+ pass_(graph)
40
+
41
+ # always run fix_functionalization last
42
+ self.fix_functionalization(graph)
43
+
44
+ def configure(self, pass_config: CompilationConfig.PassConfig):
45
+ self.pass_config = pass_config
46
+ if pass_config.enable_reshape:
47
+ self.passes += [RedundantReshapesPass(pass_config)]
48
+
49
+ if pass_config.enable_fusion:
50
+ self.passes += [FusionPass.instance(pass_config)]
51
+
52
+ self.fix_functionalization = FixFunctionalizationPass(pass_config)
53
+
54
+ def add(self, pass_: InductorPass):
55
+ assert isinstance(pass_, InductorPass)
56
+ self.passes.append(pass_)
57
+
58
+ def __getstate__(self) -> Dict[str, List[Any]]:
59
+ """
60
+ Custom pickling for the pass manager, as some passes cannot be pickled.
61
+ Pickling occurs because the pass manager is set as the value of
62
+ `config["post_grad_custom_post_pass"]` in the Inductor config.
63
+ The config is pickled to act as a key in the Inductor code cache.
64
+ Any other passes in the config are pickled as well.
65
+
66
+ TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
67
+ """
68
+ state = {"pass_config": self.pass_config.uuid(), "passes": []}
69
+ for pass_ in self.passes:
70
+ state["passes"].append(pass_.uuid())
71
+ state["passes"].append(self.fix_functionalization.uuid())
72
+ return state
73
+
74
+ def __setstate__(self, state):
75
+ """
76
+ Do not allow unpickling of the pass manager.
77
+ If this is needed in the future, it should properly pickle the passes.
78
+ """
79
+ raise ValueError("Cannot unpickle PostGradPassManager")
.venv/lib/python3.11/site-packages/vllm/compilation/wrapper.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import os
4
+ import sys
5
+ from abc import abstractmethod
6
+ from contextlib import contextmanager
7
+ from types import CodeType
8
+ from typing import Callable, List, Optional
9
+
10
+ import torch
11
+
12
+ import vllm.envs as envs
13
+ from vllm.config import CompilationLevel, get_current_vllm_config
14
+ from vllm.logger import init_logger
15
+
16
+ logger = init_logger(__name__)
17
+
18
+
19
+ class TorchCompileWrapperWithCustomDispatcher:
20
+ """
21
+ A wrapper class for torch.compile, with a custom dispatch logic.
22
+ Subclasses should:
23
+ 1. Implement the forward method
24
+ 2. Implement the dispatch logic in the __call__ method
25
+ It can use `self.compiled_codes` to access the compiled bytecode,
26
+ and `with self.dispatch_to_code(index):` to dispatch to
27
+ the compiled code.
28
+ 3. Implement the `__init__` method to determine how to call
29
+ `torch.compile` over the forward method.
30
+ """
31
+
32
+ def __init__(self,
33
+ compiled_callable: Optional[Callable] = None,
34
+ compilation_level: int = 0):
35
+
36
+ vllm_config = get_current_vllm_config()
37
+ self.vllm_config = vllm_config
38
+ if compiled_callable is None:
39
+ # default compilation settings
40
+ # compiling the forward method
41
+
42
+ backend = vllm_config.compilation_config.init_backend(vllm_config)
43
+
44
+ compiled_callable = torch.compile(
45
+ self.forward,
46
+ fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
47
+ backend=backend)
48
+
49
+ self.compiled_callable = compiled_callable
50
+ self.original_code_object = self.__class__.forward.__code__
51
+ self.compiled_codes: List[CodeType] = []
52
+ torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
53
+
54
+ # read the env var to determine whether to use the custom dispatcher
55
+ # subclasses can use this to switch between the custom dispatcher
56
+ # and the default Dynamo guard mechanism.
57
+ self.use_custom_dispatcher: bool = \
58
+ compilation_level >= CompilationLevel.DYNAMO_ONCE
59
+
60
+ def __call__(self, *args, **kwargs):
61
+ """Implement the dispatch logic here, beyond the torch.compile level.
62
+ NOTE: this function can have additional arguments beyond the forward
63
+ method, for directly dispatching to the compiled code.
64
+ """
65
+ return self.compiled_callable(*args, **kwargs)
66
+
67
+ @abstractmethod
68
+ def forward(self, *args, **kwargs):
69
+ ...
70
+
71
+ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
72
+ """Hook to save the compiled bytecode for direct execution."""
73
+ if old_code is not self.original_code_object:
74
+ return
75
+ # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
76
+ frame = sys._getframe()
77
+ while frame and frame.f_back:
78
+ frame = frame.f_back
79
+ code_name = frame.f_code.co_name
80
+ file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
81
+ if code_name == "_compile" and file_name == "convert_frame.py":
82
+ break
83
+ frame = frame.f_locals["frame"]
84
+ assert frame.f_code == old_code
85
+
86
+ if frame.f_locals["self"] is not self:
87
+ return
88
+
89
+ self.compiled_codes.append(new_code)
90
+ local_cache_dir = self.vllm_config.compilation_config.local_cache_dir
91
+ if isinstance(local_cache_dir, str):
92
+ decompiled_file = os.path.join(local_cache_dir,
93
+ "transformed_code.py")
94
+ if not os.path.exists(decompiled_file):
95
+ try:
96
+ # usually the decompilation will succeed for most models,
97
+ # as we guarantee a full-graph compilation in Dynamo.
98
+ # but there's no 100% guarantee, since decompliation is
99
+ # not a reversible process.
100
+ import depyf
101
+ src = depyf.decompile(new_code)
102
+ with open(decompiled_file, "w") as f:
103
+ f.write(src)
104
+
105
+ logger.debug("Dynamo transformed code saved to %s",
106
+ decompiled_file)
107
+ except Exception:
108
+ pass
109
+
110
+ if self.vllm_config.compilation_config.use_cudagraph and \
111
+ "update" in new_code.co_names:
112
+ import depyf
113
+ src = depyf.decompile(new_code)
114
+ msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa
115
+ raise RuntimeError(msg)
116
+
117
+ @contextmanager
118
+ def dispatch_to_code(self, index: int):
119
+ """Context manager to dispatch to the compiled code.
120
+ Why does this work? Because Dynamo guarantees that the compiled
121
+ bytecode has exactly the same arguments, cell variables, and free
122
+ variables as the original code. Therefore we can directly switch
123
+ the code object in the function and call it.
124
+
125
+ See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
126
+ """ # noqa
127
+ self.__class__.forward.__code__ = self.compiled_codes[index]
128
+ yield
129
+ self.__class__.forward.__code__ = self.original_code_object
.venv/lib/python3.11/site-packages/vllm/engine/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (184 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/arg_utils.cpython-311.pyc ADDED
Binary file (59 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/async_llm_engine.cpython-311.pyc ADDED
Binary file (56 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/async_timeout.cpython-311.pyc ADDED
Binary file (8.92 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/llm_engine.cpython-311.pyc ADDED
Binary file (84.9 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (34.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/metrics_types.cpython-311.pyc ADDED
Binary file (5.35 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/protocol.cpython-311.pyc ADDED
Binary file (12.9 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/arg_utils.py ADDED
@@ -0,0 +1,1360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import argparse
4
+ import dataclasses
5
+ import json
6
+ from dataclasses import dataclass
7
+ from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
8
+ Tuple, Type, Union, cast, get_args)
9
+
10
+ import torch
11
+
12
+ import vllm.envs as envs
13
+ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
14
+ DecodingConfig, DeviceConfig, HfOverrides,
15
+ KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
16
+ ModelConfig, ModelImpl, ObservabilityConfig,
17
+ ParallelConfig, PoolerConfig, PromptAdapterConfig,
18
+ SchedulerConfig, SpeculativeConfig, TaskOption,
19
+ TokenizerPoolConfig, VllmConfig)
20
+ from vllm.executor.executor_base import ExecutorBase
21
+ from vllm.logger import init_logger
22
+ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
23
+ from vllm.transformers_utils.utils import check_gguf_file
24
+ from vllm.usage.usage_lib import UsageContext
25
+ from vllm.utils import FlexibleArgumentParser, StoreBoolean
26
+
27
+ if TYPE_CHECKING:
28
+ from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
29
+
30
+ # Module-level logger; init_logger is vLLM's per-module logging factory.
+ logger = init_logger(__name__)
31
+
32
# Module scopes accepted by the --collect-detailed-traces option.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# Legal values for the --device command-line option.
DEVICE_OPTIONS = [
    "auto", "cuda", "neuron", "cpu", "openvino", "tpu", "xpu", "hpu"
]
44
+
45
+
46
def nullable_str(val: str):
    """Map empty strings and the literal string "None" to None.

    Used as an argparse ``type=`` converter so that CLI users can
    explicitly pass "None" to unset an option.
    """
    return None if not val or val == "None" else val
50
+
51
+
52
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dict.

    Keys and values are lower-cased and stripped before parsing; values
    must be integers. Repeating a key with the same value is tolerated,
    while a conflicting value raises ``argparse.ArgumentTypeError``.

    Args:
        val: String value to be parsed.

    Returns:
        Dictionary with parsed values, or None for an empty string.
    """
    # An empty string means "no limits specified".
    if not val:
        return None

    result: Dict[str, int] = {}
    for entry in val.split(","):
        pieces = [piece.lower().strip() for piece in entry.split("=")]
        if len(pieces) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, raw_value = pieces

        try:
            number = int(raw_value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={raw_value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Reject a key that was already given a *different* value.
        previous = result.get(key)
        if previous is not None and previous != number:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        result[key] = number

    return result
85
+
86
+
87
+ @dataclass
88
+ class EngineArgs:
89
+ """Arguments for vLLM engine."""
90
+ model: str = 'facebook/opt-125m'
91
+ served_model_name: Optional[Union[str, List[str]]] = None
92
+ tokenizer: Optional[str] = None
93
+ task: TaskOption = "auto"
94
+ skip_tokenizer_init: bool = False
95
+ tokenizer_mode: str = 'auto'
96
+ trust_remote_code: bool = False
97
+ allowed_local_media_path: str = ""
98
+ download_dir: Optional[str] = None
99
+ load_format: str = 'auto'
100
+ config_format: ConfigFormat = ConfigFormat.AUTO
101
+ dtype: str = 'auto'
102
+ kv_cache_dtype: str = 'auto'
103
+ seed: int = 0
104
+ max_model_len: Optional[int] = None
105
+ # Note: Specifying a custom executor backend by passing a class
106
+ # is intended for expert use only. The API may change without
107
+ # notice.
108
+ distributed_executor_backend: Optional[Union[str,
109
+ Type[ExecutorBase]]] = None
110
+ # number of P/D disaggregation (or other disaggregation) workers
111
+ pipeline_parallel_size: int = 1
112
+ tensor_parallel_size: int = 1
113
+ max_parallel_loading_workers: Optional[int] = None
114
+ block_size: Optional[int] = None
115
+ enable_prefix_caching: Optional[bool] = None
116
+ disable_sliding_window: bool = False
117
+ use_v2_block_manager: bool = True
118
+ swap_space: float = 4 # GiB
119
+ cpu_offload_gb: float = 0 # GiB
120
+ gpu_memory_utilization: float = 0.90
121
+ max_num_batched_tokens: Optional[int] = None
122
+ max_num_seqs: Optional[int] = None
123
+ max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
124
+ disable_log_stats: bool = False
125
+ revision: Optional[str] = None
126
+ code_revision: Optional[str] = None
127
+ rope_scaling: Optional[Dict[str, Any]] = None
128
+ rope_theta: Optional[float] = None
129
+ hf_overrides: Optional[HfOverrides] = None
130
+ tokenizer_revision: Optional[str] = None
131
+ quantization: Optional[str] = None
132
+ enforce_eager: Optional[bool] = None
133
+ max_seq_len_to_capture: int = 8192
134
+ disable_custom_all_reduce: bool = False
135
+ tokenizer_pool_size: int = 0
136
+ # Note: Specifying a tokenizer pool by passing a class
137
+ # is intended for expert use only. The API may change without
138
+ # notice.
139
+ tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
140
+ tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
141
+ limit_mm_per_prompt: Optional[Mapping[str, int]] = None
142
+ mm_processor_kwargs: Optional[Dict[str, Any]] = None
143
+ disable_mm_preprocessor_cache: bool = False
144
+ enable_lora: bool = False
145
+ enable_lora_bias: bool = False
146
+ max_loras: int = 1
147
+ max_lora_rank: int = 16
148
+ enable_prompt_adapter: bool = False
149
+ max_prompt_adapters: int = 1
150
+ max_prompt_adapter_token: int = 0
151
+ fully_sharded_loras: bool = False
152
+ lora_extra_vocab_size: int = 256
153
+ long_lora_scaling_factors: Optional[Tuple[float]] = None
154
+ lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
155
+ max_cpu_loras: Optional[int] = None
156
+ device: str = 'auto'
157
+ num_scheduler_steps: int = 1
158
+ multi_step_stream_outputs: bool = True
159
+ ray_workers_use_nsight: bool = False
160
+ num_gpu_blocks_override: Optional[int] = None
161
+ num_lookahead_slots: int = 0
162
+ model_loader_extra_config: Optional[dict] = None
163
+ ignore_patterns: Optional[Union[str, List[str]]] = None
164
+ preemption_mode: Optional[str] = None
165
+
166
+ scheduler_delay_factor: float = 0.0
167
+ enable_chunked_prefill: Optional[bool] = None
168
+
169
+ guided_decoding_backend: str = 'xgrammar'
170
+ logits_processor_pattern: Optional[str] = None
171
+ # Speculative decoding configuration.
172
+ speculative_model: Optional[str] = None
173
+ speculative_model_quantization: Optional[str] = None
174
+ speculative_draft_tensor_parallel_size: Optional[int] = None
175
+ num_speculative_tokens: Optional[int] = None
176
+ speculative_disable_mqa_scorer: Optional[bool] = False
177
+ speculative_max_model_len: Optional[int] = None
178
+ speculative_disable_by_batch_size: Optional[int] = None
179
+ ngram_prompt_lookup_max: Optional[int] = None
180
+ ngram_prompt_lookup_min: Optional[int] = None
181
+ spec_decoding_acceptance_method: str = 'rejection_sampler'
182
+ typical_acceptance_sampler_posterior_threshold: Optional[float] = None
183
+ typical_acceptance_sampler_posterior_alpha: Optional[float] = None
184
+ qlora_adapter_name_or_path: Optional[str] = None
185
+ disable_logprobs_during_spec_decoding: Optional[bool] = None
186
+
187
+ otlp_traces_endpoint: Optional[str] = None
188
+ collect_detailed_traces: Optional[str] = None
189
+ disable_async_output_proc: bool = False
190
+ scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
191
+
192
+ override_neuron_config: Optional[Dict[str, Any]] = None
193
+ override_pooler_config: Optional[PoolerConfig] = None
194
+ compilation_config: Optional[CompilationConfig] = None
195
+ worker_cls: str = "auto"
196
+
197
+ kv_transfer_config: Optional[KVTransferConfig] = None
198
+
199
+ generation_config: Optional[str] = None
200
+ override_generation_config: Optional[Dict[str, Any]] = None
201
+ enable_sleep_mode: bool = False
202
+ model_impl: str = "auto"
203
+
204
+ calculate_kv_scales: Optional[bool] = None
205
+
206
def __post_init__(self):
    """Normalize defaults that depend on other fields or the environment."""
    # Fall back to the model path when no tokenizer was given.
    if not self.tokenizer:
        self.tokenizer = self.model

    # Prefix caching defaults to on only under the V1 engine, unless the
    # user set it explicitly.
    if self.enable_prefix_caching is None:
        self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

    # The V1 engine uses a larger default max_num_seqs.
    if self.max_num_seqs is None:
        self.max_num_seqs = 1024 if envs.VLLM_USE_V1 else 256

    # Allow `EngineArgs(compilation_config={...})` (or an int level)
    # without requiring the caller to build a CompilationConfig.
    if isinstance(self.compilation_config, (int, dict)):
        self.compilation_config = CompilationConfig.from_cli(
            str(self.compilation_config))

    # Load any externally registered vLLM plugins.
    from vllm.plugins import load_general_plugins
    load_general_plugins()
229
+
230
+ @staticmethod
231
+ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
232
+ """Shared CLI arguments for vLLM engine."""
233
+
234
+ # Model arguments
235
+ parser.add_argument(
236
+ '--model',
237
+ type=str,
238
+ default=EngineArgs.model,
239
+ help='Name or path of the huggingface model to use.')
240
+ parser.add_argument(
241
+ '--task',
242
+ default=EngineArgs.task,
243
+ choices=get_args(TaskOption),
244
+ help='The task to use the model for. Each vLLM instance only '
245
+ 'supports one task, even if the same model can be used for '
246
+ 'multiple tasks. When the model only supports one task, ``"auto"`` '
247
+ 'can be used to select it; otherwise, you must specify explicitly '
248
+ 'which task to use.')
249
+ parser.add_argument(
250
+ '--tokenizer',
251
+ type=nullable_str,
252
+ default=EngineArgs.tokenizer,
253
+ help='Name or path of the huggingface tokenizer to use. '
254
+ 'If unspecified, model name or path will be used.')
255
+ parser.add_argument(
256
+ '--skip-tokenizer-init',
257
+ action='store_true',
258
+ help='Skip initialization of tokenizer and detokenizer.')
259
+ parser.add_argument(
260
+ '--revision',
261
+ type=nullable_str,
262
+ default=None,
263
+ help='The specific model version to use. It can be a branch '
264
+ 'name, a tag name, or a commit id. If unspecified, will use '
265
+ 'the default version.')
266
+ parser.add_argument(
267
+ '--code-revision',
268
+ type=nullable_str,
269
+ default=None,
270
+ help='The specific revision to use for the model code on '
271
+ 'Hugging Face Hub. It can be a branch name, a tag name, or a '
272
+ 'commit id. If unspecified, will use the default version.')
273
+ parser.add_argument(
274
+ '--tokenizer-revision',
275
+ type=nullable_str,
276
+ default=None,
277
+ help='Revision of the huggingface tokenizer to use. '
278
+ 'It can be a branch name, a tag name, or a commit id. '
279
+ 'If unspecified, will use the default version.')
280
+ parser.add_argument(
281
+ '--tokenizer-mode',
282
+ type=str,
283
+ default=EngineArgs.tokenizer_mode,
284
+ choices=['auto', 'slow', 'mistral'],
285
+ help='The tokenizer mode.\n\n* "auto" will use the '
286
+ 'fast tokenizer if available.\n* "slow" will '
287
+ 'always use the slow tokenizer. \n* '
288
+ '"mistral" will always use the `mistral_common` tokenizer.')
289
+ parser.add_argument('--trust-remote-code',
290
+ action='store_true',
291
+ help='Trust remote code from huggingface.')
292
+ parser.add_argument(
293
+ '--allowed-local-media-path',
294
+ type=str,
295
+ help="Allowing API requests to read local images or videos "
296
+ "from directories specified by the server file system. "
297
+ "This is a security risk. "
298
+ "Should only be enabled in trusted environments.")
299
+ parser.add_argument('--download-dir',
300
+ type=nullable_str,
301
+ default=EngineArgs.download_dir,
302
+ help='Directory to download and load the weights, '
303
+ 'default to the default cache dir of '
304
+ 'huggingface.')
305
+ parser.add_argument(
306
+ '--load-format',
307
+ type=str,
308
+ default=EngineArgs.load_format,
309
+ choices=[f.value for f in LoadFormat],
310
+ help='The format of the model weights to load.\n\n'
311
+ '* "auto" will try to load the weights in the safetensors format '
312
+ 'and fall back to the pytorch bin format if safetensors format '
313
+ 'is not available.\n'
314
+ '* "pt" will load the weights in the pytorch bin format.\n'
315
+ '* "safetensors" will load the weights in the safetensors format.\n'
316
+ '* "npcache" will load the weights in pytorch format and store '
317
+ 'a numpy cache to speed up the loading.\n'
318
+ '* "dummy" will initialize the weights with random values, '
319
+ 'which is mainly for profiling.\n'
320
+ '* "tensorizer" will load the weights using tensorizer from '
321
+ 'CoreWeave. See the Tensorize vLLM Model script in the Examples '
322
+ 'section for more information.\n'
323
+ '* "runai_streamer" will load the Safetensors weights using Run:ai'
324
+ 'Model Streamer \n'
325
+ '* "bitsandbytes" will load the weights using bitsandbytes '
326
+ 'quantization.\n')
327
+ parser.add_argument(
328
+ '--config-format',
329
+ default=EngineArgs.config_format,
330
+ choices=[f.value for f in ConfigFormat],
331
+ help='The format of the model config to load.\n\n'
332
+ '* "auto" will try to load the config in hf format '
333
+ 'if available else it will try to load in mistral format ')
334
+ parser.add_argument(
335
+ '--dtype',
336
+ type=str,
337
+ default=EngineArgs.dtype,
338
+ choices=[
339
+ 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
340
+ ],
341
+ help='Data type for model weights and activations.\n\n'
342
+ '* "auto" will use FP16 precision for FP32 and FP16 models, and '
343
+ 'BF16 precision for BF16 models.\n'
344
+ '* "half" for FP16. Recommended for AWQ quantization.\n'
345
+ '* "float16" is the same as "half".\n'
346
+ '* "bfloat16" for a balance between precision and range.\n'
347
+ '* "float" is shorthand for FP32 precision.\n'
348
+ '* "float32" for FP32 precision.')
349
+ parser.add_argument(
350
+ '--kv-cache-dtype',
351
+ type=str,
352
+ choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
353
+ default=EngineArgs.kv_cache_dtype,
354
+ help='Data type for kv cache storage. If "auto", will use model '
355
+ 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
356
+ 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
357
+ parser.add_argument('--max-model-len',
358
+ type=int,
359
+ default=EngineArgs.max_model_len,
360
+ help='Model context length. If unspecified, will '
361
+ 'be automatically derived from the model config.')
362
+ parser.add_argument(
363
+ '--guided-decoding-backend',
364
+ type=str,
365
+ default='xgrammar',
366
+ choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
367
+ help='Which engine will be used for guided decoding'
368
+ ' (JSON schema / regex etc) by default. Currently support '
369
+ 'https://github.com/outlines-dev/outlines, '
370
+ 'https://github.com/mlc-ai/xgrammar, and '
371
+ 'https://github.com/noamgat/lm-format-enforcer.'
372
+ ' Can be overridden per request via guided_decoding_backend'
373
+ ' parameter.')
374
+ parser.add_argument(
375
+ '--logits-processor-pattern',
376
+ type=nullable_str,
377
+ default=None,
378
+ help='Optional regex pattern specifying valid logits processor '
379
+ 'qualified names that can be passed with the `logits_processors` '
380
+ 'extra completion argument. Defaults to None, which allows no '
381
+ 'processors.')
382
+ parser.add_argument(
383
+ '--model-impl',
384
+ type=str,
385
+ default=EngineArgs.model_impl,
386
+ choices=[f.value for f in ModelImpl],
387
+ help='Which implementation of the model to use.\n\n'
388
+ '* "auto" will try to use the vLLM implementation if it exists '
389
+ 'and fall back to the Transformers implementation if no vLLM '
390
+ 'implementation is available.\n'
391
+ '* "vllm" will use the vLLM model implementation.\n'
392
+ '* "transformers" will use the Transformers model '
393
+ 'implementation.\n')
394
+ # Parallel arguments
395
+ parser.add_argument(
396
+ '--distributed-executor-backend',
397
+ choices=['ray', 'mp', 'uni', 'external_launcher'],
398
+ default=EngineArgs.distributed_executor_backend,
399
+ help='Backend to use for distributed model '
400
+ 'workers, either "ray" or "mp" (multiprocessing). If the product '
401
+ 'of pipeline_parallel_size and tensor_parallel_size is less than '
402
+ 'or equal to the number of GPUs available, "mp" will be used to '
403
+ 'keep processing on a single host. Otherwise, this will default '
404
+ 'to "ray" if Ray is installed and fail otherwise. Note that tpu '
405
+ 'only supports Ray for distributed inference.')
406
+
407
+ parser.add_argument('--pipeline-parallel-size',
408
+ '-pp',
409
+ type=int,
410
+ default=EngineArgs.pipeline_parallel_size,
411
+ help='Number of pipeline stages.')
412
+ parser.add_argument('--tensor-parallel-size',
413
+ '-tp',
414
+ type=int,
415
+ default=EngineArgs.tensor_parallel_size,
416
+ help='Number of tensor parallel replicas.')
417
+ parser.add_argument(
418
+ '--max-parallel-loading-workers',
419
+ type=int,
420
+ default=EngineArgs.max_parallel_loading_workers,
421
+ help='Load model sequentially in multiple batches, '
422
+ 'to avoid RAM OOM when using tensor '
423
+ 'parallel and large models.')
424
+ parser.add_argument(
425
+ '--ray-workers-use-nsight',
426
+ action='store_true',
427
+ help='If specified, use nsight to profile Ray workers.')
428
+ # KV cache arguments
429
+ parser.add_argument('--block-size',
430
+ type=int,
431
+ default=EngineArgs.block_size,
432
+ choices=[8, 16, 32, 64, 128],
433
+ help='Token block size for contiguous chunks of '
434
+ 'tokens. This is ignored on neuron devices and '
435
+ 'set to ``--max-model-len``. On CUDA devices, '
436
+ 'only block sizes up to 32 are supported. '
437
+ 'On HPU devices, block size defaults to 128.')
438
+
439
+ parser.add_argument(
440
+ "--enable-prefix-caching",
441
+ action=argparse.BooleanOptionalAction,
442
+ default=EngineArgs.enable_prefix_caching,
443
+ help="Enables automatic prefix caching. "
444
+ "Use ``--no-enable-prefix-caching`` to disable explicitly.",
445
+ )
446
+ parser.add_argument('--disable-sliding-window',
447
+ action='store_true',
448
+ help='Disables sliding window, '
449
+ 'capping to sliding window size.')
450
+ parser.add_argument('--use-v2-block-manager',
451
+ action='store_true',
452
+ default=True,
453
+ help='[DEPRECATED] block manager v1 has been '
454
+ 'removed and SelfAttnBlockSpaceManager (i.e. '
455
+ 'block manager v2) is now the default. '
456
+ 'Setting this flag to True or False'
457
+ ' has no effect on vLLM behavior.')
458
+ parser.add_argument(
459
+ '--num-lookahead-slots',
460
+ type=int,
461
+ default=EngineArgs.num_lookahead_slots,
462
+ help='Experimental scheduling config necessary for '
463
+ 'speculative decoding. This will be replaced by '
464
+ 'speculative config in the future; it is present '
465
+ 'to enable correctness tests until then.')
466
+
467
+ parser.add_argument('--seed',
468
+ type=int,
469
+ default=EngineArgs.seed,
470
+ help='Random seed for operations.')
471
+ parser.add_argument('--swap-space',
472
+ type=float,
473
+ default=EngineArgs.swap_space,
474
+ help='CPU swap space size (GiB) per GPU.')
475
+ parser.add_argument(
476
+ '--cpu-offload-gb',
477
+ type=float,
478
+ default=0,
479
+ help='The space in GiB to offload to CPU, per GPU. '
480
+ 'Default is 0, which means no offloading. Intuitively, '
481
+ 'this argument can be seen as a virtual way to increase '
482
+ 'the GPU memory size. For example, if you have one 24 GB '
483
+ 'GPU and set this to 10, virtually you can think of it as '
484
+ 'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
485
+ 'which requires at least 26GB GPU memory. Note that this '
486
+ 'requires fast CPU-GPU interconnect, as part of the model is '
487
+ 'loaded from CPU memory to GPU memory on the fly in each '
488
+ 'model forward pass.')
489
+ parser.add_argument(
490
+ '--gpu-memory-utilization',
491
+ type=float,
492
+ default=EngineArgs.gpu_memory_utilization,
493
+ help='The fraction of GPU memory to be used for the model '
494
+ 'executor, which can range from 0 to 1. For example, a value of '
495
+ '0.5 would imply 50%% GPU memory utilization. If unspecified, '
496
+ 'will use the default value of 0.9. This is a per-instance '
497
+ 'limit, and only applies to the current vLLM instance.'
498
+ 'It does not matter if you have another vLLM instance running '
499
+ 'on the same GPU. For example, if you have two vLLM instances '
500
+ 'running on the same GPU, you can set the GPU memory utilization '
501
+ 'to 0.5 for each instance.')
502
+ parser.add_argument(
503
+ '--num-gpu-blocks-override',
504
+ type=int,
505
+ default=None,
506
+ help='If specified, ignore GPU profiling result and use this number'
507
+ ' of GPU blocks. Used for testing preemption.')
508
+ parser.add_argument('--max-num-batched-tokens',
509
+ type=int,
510
+ default=EngineArgs.max_num_batched_tokens,
511
+ help='Maximum number of batched tokens per '
512
+ 'iteration.')
513
+ parser.add_argument('--max-num-seqs',
514
+ type=int,
515
+ default=EngineArgs.max_num_seqs,
516
+ help='Maximum number of sequences per iteration.')
517
+ parser.add_argument(
518
+ '--max-logprobs',
519
+ type=int,
520
+ default=EngineArgs.max_logprobs,
521
+ help=('Max number of log probs to return logprobs is specified in'
522
+ ' SamplingParams.'))
523
+ parser.add_argument('--disable-log-stats',
524
+ action='store_true',
525
+ help='Disable logging statistics.')
526
+ # Quantization settings.
527
+ parser.add_argument('--quantization',
528
+ '-q',
529
+ type=nullable_str,
530
+ choices=[*QUANTIZATION_METHODS, None],
531
+ default=EngineArgs.quantization,
532
+ help='Method used to quantize the weights. If '
533
+ 'None, we first check the `quantization_config` '
534
+ 'attribute in the model config file. If that is '
535
+ 'None, we assume the model weights are not '
536
+ 'quantized and use `dtype` to determine the data '
537
+ 'type of the weights.')
538
+ parser.add_argument(
539
+ '--rope-scaling',
540
+ default=None,
541
+ type=json.loads,
542
+ help='RoPE scaling configuration in JSON format. '
543
+ 'For example, ``{"rope_type":"dynamic","factor":2.0}``')
544
+ parser.add_argument('--rope-theta',
545
+ default=None,
546
+ type=float,
547
+ help='RoPE theta. Use with `rope_scaling`. In '
548
+ 'some cases, changing the RoPE theta improves the '
549
+ 'performance of the scaled model.')
550
+ parser.add_argument('--hf-overrides',
551
+ type=json.loads,
552
+ default=EngineArgs.hf_overrides,
553
+ help='Extra arguments for the HuggingFace config. '
554
+ 'This should be a JSON string that will be '
555
+ 'parsed into a dictionary.')
556
+ parser.add_argument('--enforce-eager',
557
+ action='store_true',
558
+ help='Always use eager-mode PyTorch. If False, '
559
+ 'will use eager mode and CUDA graph in hybrid '
560
+ 'for maximal performance and flexibility.')
561
+ parser.add_argument('--max-seq-len-to-capture',
562
+ type=int,
563
+ default=EngineArgs.max_seq_len_to_capture,
564
+ help='Maximum sequence length covered by CUDA '
565
+ 'graphs. When a sequence has context length '
566
+ 'larger than this, we fall back to eager mode. '
567
+ 'Additionally for encoder-decoder models, if the '
568
+ 'sequence length of the encoder input is larger '
569
+ 'than this, we fall back to the eager mode.')
570
+ parser.add_argument('--disable-custom-all-reduce',
571
+ action='store_true',
572
+ default=EngineArgs.disable_custom_all_reduce,
573
+ help='See ParallelConfig.')
574
+ parser.add_argument('--tokenizer-pool-size',
575
+ type=int,
576
+ default=EngineArgs.tokenizer_pool_size,
577
+ help='Size of tokenizer pool to use for '
578
+ 'asynchronous tokenization. If 0, will '
579
+ 'use synchronous tokenization.')
580
+ parser.add_argument('--tokenizer-pool-type',
581
+ type=str,
582
+ default=EngineArgs.tokenizer_pool_type,
583
+ help='Type of tokenizer pool to use for '
584
+ 'asynchronous tokenization. Ignored '
585
+ 'if tokenizer_pool_size is 0.')
586
+ parser.add_argument('--tokenizer-pool-extra-config',
587
+ type=nullable_str,
588
+ default=EngineArgs.tokenizer_pool_extra_config,
589
+ help='Extra config for tokenizer pool. '
590
+ 'This should be a JSON string that will be '
591
+ 'parsed into a dictionary. Ignored if '
592
+ 'tokenizer_pool_size is 0.')
593
+
594
+ # Multimodal related configs
595
+ parser.add_argument(
596
+ '--limit-mm-per-prompt',
597
+ type=nullable_kvs,
598
+ default=EngineArgs.limit_mm_per_prompt,
599
+ # The default value is given in
600
+ # MultiModalRegistry.init_mm_limits_per_prompt
601
+ help=('For each multimodal plugin, limit how many '
602
+ 'input instances to allow for each prompt. '
603
+ 'Expects a comma-separated list of items, '
604
+ 'e.g.: `image=16,video=2` allows a maximum of 16 '
605
+ 'images and 2 videos per prompt. Defaults to 1 for '
606
+ 'each modality.'))
607
+ parser.add_argument(
608
+ '--mm-processor-kwargs',
609
+ default=None,
610
+ type=json.loads,
611
+ help=('Overrides for the multimodal input mapping/processing, '
612
+ 'e.g., image processor. For example: ``{"num_crops": 4}``.'))
613
+ parser.add_argument(
614
+ '--disable-mm-preprocessor-cache',
615
+ action='store_true',
616
+ help='If true, then disables caching of the multi-modal '
617
+ 'preprocessor/mapper. (not recommended)')
618
+
619
+ # LoRA related configs
620
+ parser.add_argument('--enable-lora',
621
+ action='store_true',
622
+ help='If True, enable handling of LoRA adapters.')
623
+ parser.add_argument('--enable-lora-bias',
624
+ action='store_true',
625
+ help='If True, enable bias for LoRA adapters.')
626
+ parser.add_argument('--max-loras',
627
+ type=int,
628
+ default=EngineArgs.max_loras,
629
+ help='Max number of LoRAs in a single batch.')
630
+ parser.add_argument('--max-lora-rank',
631
+ type=int,
632
+ default=EngineArgs.max_lora_rank,
633
+ help='Max LoRA rank.')
634
+ parser.add_argument(
635
+ '--lora-extra-vocab-size',
636
+ type=int,
637
+ default=EngineArgs.lora_extra_vocab_size,
638
+ help=('Maximum size of extra vocabulary that can be '
639
+ 'present in a LoRA adapter (added to the base '
640
+ 'model vocabulary).'))
641
+ parser.add_argument(
642
+ '--lora-dtype',
643
+ type=str,
644
+ default=EngineArgs.lora_dtype,
645
+ choices=['auto', 'float16', 'bfloat16'],
646
+ help=('Data type for LoRA. If auto, will default to '
647
+ 'base model dtype.'))
648
+ parser.add_argument(
649
+ '--long-lora-scaling-factors',
650
+ type=nullable_str,
651
+ default=EngineArgs.long_lora_scaling_factors,
652
+ help=('Specify multiple scaling factors (which can '
653
+ 'be different from base model scaling factor '
654
+ '- see eg. Long LoRA) to allow for multiple '
655
+ 'LoRA adapters trained with those scaling '
656
+ 'factors to be used at the same time. If not '
657
+ 'specified, only adapters trained with the '
658
+ 'base model scaling factor are allowed.'))
659
+ parser.add_argument(
660
+ '--max-cpu-loras',
661
+ type=int,
662
+ default=EngineArgs.max_cpu_loras,
663
+ help=('Maximum number of LoRAs to store in CPU memory. '
664
+ 'Must be >= than max_loras. '
665
+ 'Defaults to max_loras.'))
666
+ parser.add_argument(
667
+ '--fully-sharded-loras',
668
+ action='store_true',
669
+ help=('By default, only half of the LoRA computation is '
670
+ 'sharded with tensor parallelism. '
671
+ 'Enabling this will use the fully sharded layers. '
672
+ 'At high sequence length, max rank or '
673
+ 'tensor parallel size, this is likely faster.'))
674
+ parser.add_argument('--enable-prompt-adapter',
675
+ action='store_true',
676
+ help='If True, enable handling of PromptAdapters.')
677
+ parser.add_argument('--max-prompt-adapters',
678
+ type=int,
679
+ default=EngineArgs.max_prompt_adapters,
680
+ help='Max number of PromptAdapters in a batch.')
681
+ parser.add_argument('--max-prompt-adapter-token',
682
+ type=int,
683
+ default=EngineArgs.max_prompt_adapter_token,
684
+ help='Max number of PromptAdapters tokens')
685
+ parser.add_argument("--device",
686
+ type=str,
687
+ default=EngineArgs.device,
688
+ choices=DEVICE_OPTIONS,
689
+ help='Device type for vLLM execution.')
690
+ parser.add_argument('--num-scheduler-steps',
691
+ type=int,
692
+ default=1,
693
+ help=('Maximum number of forward steps per '
694
+ 'scheduler call.'))
695
+
696
+ parser.add_argument(
697
+ '--multi-step-stream-outputs',
698
+ action=StoreBoolean,
699
+ default=EngineArgs.multi_step_stream_outputs,
700
+ nargs="?",
701
+ const="True",
702
+ help='If False, then multi-step will stream outputs at the end '
703
+ 'of all steps')
704
+ parser.add_argument(
705
+ '--scheduler-delay-factor',
706
+ type=float,
707
+ default=EngineArgs.scheduler_delay_factor,
708
+ help='Apply a delay (of delay factor multiplied by previous '
709
+ 'prompt latency) before scheduling next prompt.')
710
+ parser.add_argument(
711
+ '--enable-chunked-prefill',
712
+ action=StoreBoolean,
713
+ default=EngineArgs.enable_chunked_prefill,
714
+ nargs="?",
715
+ const="True",
716
+ help='If set, the prefill requests can be chunked based on the '
717
+ 'max_num_batched_tokens.')
718
+
719
+ parser.add_argument(
720
+ '--speculative-model',
721
+ type=nullable_str,
722
+ default=EngineArgs.speculative_model,
723
+ help=
724
+ 'The name of the draft model to be used in speculative decoding.')
725
+ # Quantization settings for speculative model.
726
+ parser.add_argument(
727
+ '--speculative-model-quantization',
728
+ type=nullable_str,
729
+ choices=[*QUANTIZATION_METHODS, None],
730
+ default=EngineArgs.speculative_model_quantization,
731
+ help='Method used to quantize the weights of speculative model. '
732
+ 'If None, we first check the `quantization_config` '
733
+ 'attribute in the model config file. If that is '
734
+ 'None, we assume the model weights are not '
735
+ 'quantized and use `dtype` to determine the data '
736
+ 'type of the weights.')
737
+ parser.add_argument(
738
+ '--num-speculative-tokens',
739
+ type=int,
740
+ default=EngineArgs.num_speculative_tokens,
741
+ help='The number of speculative tokens to sample from '
742
+ 'the draft model in speculative decoding.')
743
+ parser.add_argument(
744
+ '--speculative-disable-mqa-scorer',
745
+ action='store_true',
746
+ help=
747
+ 'If set to True, the MQA scorer will be disabled in speculative '
748
+ ' and fall back to batch expansion')
749
+ parser.add_argument(
750
+ '--speculative-draft-tensor-parallel-size',
751
+ '-spec-draft-tp',
752
+ type=int,
753
+ default=EngineArgs.speculative_draft_tensor_parallel_size,
754
+ help='Number of tensor parallel replicas for '
755
+ 'the draft model in speculative decoding.')
756
+
757
+ parser.add_argument(
758
+ '--speculative-max-model-len',
759
+ type=int,
760
+ default=EngineArgs.speculative_max_model_len,
761
+ help='The maximum sequence length supported by the '
762
+ 'draft model. Sequences over this length will skip '
763
+ 'speculation.')
764
+
765
+ parser.add_argument(
766
+ '--speculative-disable-by-batch-size',
767
+ type=int,
768
+ default=EngineArgs.speculative_disable_by_batch_size,
769
+ help='Disable speculative decoding for new incoming requests '
770
+ 'if the number of enqueue requests is larger than this value.')
771
+
772
+ parser.add_argument(
773
+ '--ngram-prompt-lookup-max',
774
+ type=int,
775
+ default=EngineArgs.ngram_prompt_lookup_max,
776
+ help='Max size of window for ngram prompt lookup in speculative '
777
+ 'decoding.')
778
+
779
+ parser.add_argument(
780
+ '--ngram-prompt-lookup-min',
781
+ type=int,
782
+ default=EngineArgs.ngram_prompt_lookup_min,
783
+ help='Min size of window for ngram prompt lookup in speculative '
784
+ 'decoding.')
785
+
786
+ parser.add_argument(
787
+ '--spec-decoding-acceptance-method',
788
+ type=str,
789
+ default=EngineArgs.spec_decoding_acceptance_method,
790
+ choices=['rejection_sampler', 'typical_acceptance_sampler'],
791
+ help='Specify the acceptance method to use during draft token '
792
+ 'verification in speculative decoding. Two types of acceptance '
793
+ 'routines are supported: '
794
+ '1) RejectionSampler which does not allow changing the '
795
+ 'acceptance rate of draft tokens, '
796
+ '2) TypicalAcceptanceSampler which is configurable, allowing for '
797
+ 'a higher acceptance rate at the cost of lower quality, '
798
+ 'and vice versa.')
799
+
800
+ parser.add_argument(
801
+ '--typical-acceptance-sampler-posterior-threshold',
802
+ type=float,
803
+ default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
804
+ help='Set the lower bound threshold for the posterior '
805
+ 'probability of a token to be accepted. This threshold is '
806
+ 'used by the TypicalAcceptanceSampler to make sampling decisions '
807
+ 'during speculative decoding. Defaults to 0.09')
808
+
809
+ parser.add_argument(
810
+ '--typical-acceptance-sampler-posterior-alpha',
811
+ type=float,
812
+ default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
813
+ help='A scaling factor for the entropy-based threshold for token '
814
+ 'acceptance in the TypicalAcceptanceSampler. Typically defaults '
815
+ 'to sqrt of --typical-acceptance-sampler-posterior-threshold '
816
+ 'i.e. 0.3')
817
+
818
+ parser.add_argument(
819
+ '--disable-logprobs-during-spec-decoding',
820
+ action=StoreBoolean,
821
+ default=EngineArgs.disable_logprobs_during_spec_decoding,
822
+ nargs="?",
823
+ const="True",
824
+ help='If set to True, token log probabilities are not returned '
825
+ 'during speculative decoding. If set to False, log probabilities '
826
+ 'are returned according to the settings in SamplingParams. If '
827
+ 'not specified, it defaults to True. Disabling log probabilities '
828
+ 'during speculative decoding reduces latency by skipping logprob '
829
+ 'calculation in proposal sampling, target sampling, and after '
830
+ 'accepted tokens are determined.')
831
+
832
+ parser.add_argument('--model-loader-extra-config',
833
+ type=nullable_str,
834
+ default=EngineArgs.model_loader_extra_config,
835
+ help='Extra config for model loader. '
836
+ 'This will be passed to the model loader '
837
+ 'corresponding to the chosen load_format. '
838
+ 'This should be a JSON string that will be '
839
+ 'parsed into a dictionary.')
840
+ parser.add_argument(
841
+ '--ignore-patterns',
842
+ action="append",
843
+ type=str,
844
+ default=[],
845
+ help="The pattern(s) to ignore when loading the model."
846
+ "Default to `original/**/*` to avoid repeated loading of llama's "
847
+ "checkpoints.")
848
+ parser.add_argument(
849
+ '--preemption-mode',
850
+ type=str,
851
+ default=None,
852
+ help='If \'recompute\', the engine performs preemption by '
853
+ 'recomputing; If \'swap\', the engine performs preemption by '
854
+ 'block swapping.')
855
+
856
+ parser.add_argument(
857
+ "--served-model-name",
858
+ nargs="+",
859
+ type=str,
860
+ default=None,
861
+ help="The model name(s) used in the API. If multiple "
862
+ "names are provided, the server will respond to any "
863
+ "of the provided names. The model name in the model "
864
+ "field of a response will be the first name in this "
865
+ "list. If not specified, the model name will be the "
866
+ "same as the ``--model`` argument. Noted that this name(s) "
867
+ "will also be used in `model_name` tag content of "
868
+ "prometheus metrics, if multiple names provided, metrics "
869
+ "tag will take the first one.")
870
+ parser.add_argument('--qlora-adapter-name-or-path',
871
+ type=str,
872
+ default=None,
873
+ help='Name or path of the QLoRA adapter.')
874
+
875
+ parser.add_argument(
876
+ '--otlp-traces-endpoint',
877
+ type=str,
878
+ default=None,
879
+ help='Target URL to which OpenTelemetry traces will be sent.')
880
+ parser.add_argument(
881
+ '--collect-detailed-traces',
882
+ type=str,
883
+ default=None,
884
+ help="Valid choices are " +
885
+ ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
886
+ ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
887
+ " set. If set, it will collect detailed traces for the specified "
888
+ "modules. This involves use of possibly costly and or blocking "
889
+ "operations and hence might have a performance impact.")
890
+
891
+ parser.add_argument(
892
+ '--disable-async-output-proc',
893
+ action='store_true',
894
+ default=EngineArgs.disable_async_output_proc,
895
+ help="Disable async output processing. This may result in "
896
+ "lower performance.")
897
+
898
+ parser.add_argument(
899
+ '--scheduling-policy',
900
+ choices=['fcfs', 'priority'],
901
+ default="fcfs",
902
+ help='The scheduling policy to use. "fcfs" (first come first served'
903
+ ', i.e. requests are handled in order of arrival; default) '
904
+ 'or "priority" (requests are handled based on given '
905
+ 'priority (lower value means earlier handling) and time of '
906
+ 'arrival deciding any ties).')
907
+
908
+ parser.add_argument(
909
+ '--override-neuron-config',
910
+ type=json.loads,
911
+ default=None,
912
+ help="Override or set neuron device configuration. "
913
+ "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
914
+ parser.add_argument(
915
+ '--override-pooler-config',
916
+ type=PoolerConfig.from_json,
917
+ default=None,
918
+ help="Override or set the pooling method for pooling models. "
919
+ "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
920
+
921
+ parser.add_argument('--compilation-config',
922
+ '-O',
923
+ type=CompilationConfig.from_cli,
924
+ default=None,
925
+ help='torch.compile configuration for the model.'
926
+ 'When it is a number (0, 1, 2, 3), it will be '
927
+ 'interpreted as the optimization level.\n'
928
+ 'NOTE: level 0 is the default level without '
929
+ 'any optimization. level 1 and 2 are for internal '
930
+ 'testing only. level 3 is the recommended level '
931
+ 'for production.\n'
932
+ 'To specify the full compilation config, '
933
+ 'use a JSON string.\n'
934
+ 'Following the convention of traditional '
935
+ 'compilers, using -O without space is also '
936
+ 'supported. -O3 is equivalent to -O 3.')
937
+
938
+ parser.add_argument('--kv-transfer-config',
939
+ type=KVTransferConfig.from_cli,
940
+ default=None,
941
+ help='The configurations for distributed KV cache '
942
+ 'transfer. Should be a JSON string.')
943
+
944
+ parser.add_argument(
945
+ '--worker-cls',
946
+ type=str,
947
+ default="auto",
948
+ help='The worker class to use for distributed execution.')
949
+ parser.add_argument(
950
+ "--generation-config",
951
+ type=nullable_str,
952
+ default=None,
953
+ help="The folder path to the generation config. "
954
+ "Defaults to None, no generation config is loaded, vLLM defaults "
955
+ "will be used. If set to 'auto', the generation config will be "
956
+ "loaded from model path. If set to a folder path, the generation "
957
+ "config will be loaded from the specified folder path. If "
958
+ "`max_new_tokens` is specified in generation config, then "
959
+ "it sets a server-wide limit on the number of output tokens "
960
+ "for all requests.")
961
+
962
+ parser.add_argument(
963
+ "--override-generation-config",
964
+ type=json.loads,
965
+ default=None,
966
+ help="Overrides or sets generation config in JSON format. "
967
+ "e.g. ``{\"temperature\": 0.5}``. If used with "
968
+ "--generation-config=auto, the override parameters will be merged "
969
+ "with the default config from the model. If generation-config is "
970
+ "None, only the override parameters are used.")
971
+
972
+ parser.add_argument("--enable-sleep-mode",
973
+ action="store_true",
974
+ default=False,
975
+ help="Enable sleep mode for the engine. "
976
+ "(only cuda platform is supported)")
977
+
978
+ parser.add_argument(
979
+ '--calculate-kv-scales',
980
+ action='store_true',
981
+ help='This enables dynamic calculation of '
982
+ 'k_scale and v_scale when kv-cache-dtype is fp8. '
983
+ 'If calculate-kv-scales is false, the scales will '
984
+ 'be loaded from the model checkpoint if available. '
985
+ 'Otherwise, the scales will default to 1.0.')
986
+
987
+ return parser
988
+
989
+ @classmethod
990
+ def from_cli_args(cls, args: argparse.Namespace):
991
+ # Get the list of attributes of this dataclass.
992
+ attrs = [attr.name for attr in dataclasses.fields(cls)]
993
+ # Set the attributes from the parsed arguments.
994
+ engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
995
+ return engine_args
996
+
997
+ def create_model_config(self) -> ModelConfig:
998
+ return ModelConfig(
999
+ model=self.model,
1000
+ task=self.task,
1001
+ # We know this is not None because we set it in __post_init__
1002
+ tokenizer=cast(str, self.tokenizer),
1003
+ tokenizer_mode=self.tokenizer_mode,
1004
+ trust_remote_code=self.trust_remote_code,
1005
+ allowed_local_media_path=self.allowed_local_media_path,
1006
+ dtype=self.dtype,
1007
+ seed=self.seed,
1008
+ revision=self.revision,
1009
+ code_revision=self.code_revision,
1010
+ rope_scaling=self.rope_scaling,
1011
+ rope_theta=self.rope_theta,
1012
+ hf_overrides=self.hf_overrides,
1013
+ tokenizer_revision=self.tokenizer_revision,
1014
+ max_model_len=self.max_model_len,
1015
+ quantization=self.quantization,
1016
+ enforce_eager=self.enforce_eager,
1017
+ max_seq_len_to_capture=self.max_seq_len_to_capture,
1018
+ max_logprobs=self.max_logprobs,
1019
+ disable_sliding_window=self.disable_sliding_window,
1020
+ skip_tokenizer_init=self.skip_tokenizer_init,
1021
+ served_model_name=self.served_model_name,
1022
+ limit_mm_per_prompt=self.limit_mm_per_prompt,
1023
+ use_async_output_proc=not self.disable_async_output_proc,
1024
+ config_format=self.config_format,
1025
+ mm_processor_kwargs=self.mm_processor_kwargs,
1026
+ disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
1027
+ override_neuron_config=self.override_neuron_config,
1028
+ override_pooler_config=self.override_pooler_config,
1029
+ logits_processor_pattern=self.logits_processor_pattern,
1030
+ generation_config=self.generation_config,
1031
+ override_generation_config=self.override_generation_config,
1032
+ enable_sleep_mode=self.enable_sleep_mode,
1033
+ model_impl=self.model_impl,
1034
+ )
1035
+
1036
+ def create_load_config(self) -> LoadConfig:
1037
+ return LoadConfig(
1038
+ load_format=self.load_format,
1039
+ download_dir=self.download_dir,
1040
+ model_loader_extra_config=self.model_loader_extra_config,
1041
+ ignore_patterns=self.ignore_patterns,
1042
+ )
1043
+
1044
    def create_engine_config(self,
                             usage_context: Optional[UsageContext] = None
                             ) -> VllmConfig:
        """Build the complete VllmConfig from these engine arguments.

        Validates cross-argument consistency (quantization vs. load format,
        chunked prefill, multi-step scheduling, trace modules), mutates some
        of this object's fields to resolved defaults, and assembles the
        sub-configs (model, cache, parallel, scheduler, lora, speculative,
        ...) into a single VllmConfig.

        Raises:
            ValueError: on inconsistent argument combinations.
        """
        # V1 engine: apply V1-specific argument overrides first.
        if envs.VLLM_USE_V1:
            self._override_v1_engine_args(usage_context)

        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

        # Prefix caching is not supported for multimodal models on the V0
        # engine; silently disable it rather than fail.
        if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
                and self.enable_prefix_caching):
            logger.warning("--enable-prefix-caching is currently not "
                           "supported for multimodal models in v0 and "
                           "has been disabled.")
            self.enable_prefix_caching = False

        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
            calculate_kv_scales=self.calculate_kv_scales,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
        )

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.

            # For multimodal models, chunked prefill is disabled by default in
            # V0, but enabled by design in V1
            if model_config.is_multimodal_model:
                self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)

            elif use_long_context:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                from vllm.platforms import current_platform
                # Only auto-enable on CUDA GPUs, and only when no feature
                # known to conflict with chunked prefill is active.
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"
                        and not current_platform.is_rocm()):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            # Fall back to disabled if no branch above decided.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # Returns None when no speculative decoding args are set.
        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization = \
                self.speculative_model_quantization,
            speculative_draft_tensor_parallel_size = \
                self.speculative_draft_tensor_parallel_size,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # If the feature combo become valid
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

        if not self.use_v2_block_manager:
            logger.warning(
                "[DEPRECATED] Block manager v1 has been removed, "
                "and setting --use-v2-block-manager to True or False has "
                "no effect on vLLM behavior. Please remove "
                "--use-v2-block-manager in your engine argument. "
                "If your use case is not supported by "
                "SelfAttnBlockSpaceManager (i.e. block manager v2),"
                " please file an issue with detailed information.")

        scheduler_config = SchedulerConfig(
            runner_type=model_config.runner_type,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            is_multimodal_model=model_config.is_multimodal_model,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
            multi_step_stream_outputs=self.multi_step_stream_outputs,
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
            policy=self.scheduling_policy)
        # LoRA config only exists when LoRA is enabled; max_cpu_loras <= 0
        # means "no explicit CPU cache limit".
        lora_config = LoRAConfig(
            bias_enabled=self.enable_lora_bias,
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        # Thread the QLoRA adapter path through to the model loader.
        if self.qlora_adapter_name_or_path is not None and \
                self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

        load_config = self.create_load_config()

        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
            if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        # Validate --collect-detailed-traces before building the
        # observability config from it.
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
            compilation_config=self.compilation_config,
            kv_transfer_config=self.kv_transfer_config,
        )

        if envs.VLLM_USE_V1:
            self._override_v1_engine_config(config)
        return config
1295
+
1296
+ def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
1297
+ """
1298
+ Override the EngineArgs's args based on the usage context for V1.
1299
+ """
1300
+ assert envs.VLLM_USE_V1, "V1 is not enabled"
1301
+
1302
+ # V1 always uses chunked prefills.
1303
+ self.enable_chunked_prefill = True
1304
+ # When no user override, set the default values based on the usage
1305
+ # context.
1306
+ # Use different default values for different hardware.
1307
+ from vllm.platforms import current_platform
1308
+ device_name = current_platform.get_device_name().lower()
1309
+ if "h100" in device_name or "h200" in device_name:
1310
+ # For H100 and H200, we use larger default values.
1311
+ default_max_num_batched_tokens = {
1312
+ UsageContext.LLM_CLASS: 16384,
1313
+ UsageContext.OPENAI_API_SERVER: 8192,
1314
+ }
1315
+ else:
1316
+ # TODO(woosuk): Tune the default values for other hardware.
1317
+ default_max_num_batched_tokens = {
1318
+ UsageContext.LLM_CLASS: 8192,
1319
+ UsageContext.OPENAI_API_SERVER: 2048,
1320
+ }
1321
+
1322
+ if (self.max_num_batched_tokens is None
1323
+ and usage_context in default_max_num_batched_tokens):
1324
+ self.max_num_batched_tokens = default_max_num_batched_tokens[
1325
+ usage_context]
1326
+ logger.warning(
1327
+ "Setting max_num_batched_tokens to %d for %s usage context.",
1328
+ self.max_num_batched_tokens, usage_context.value)
1329
+
1330
    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
        """
        Override the EngineConfig's configs based on the usage context for V1.
        """
        # Currently only a sanity check; kept as a hook for future
        # V1-specific config overrides (engine_config intentionally unused).
        assert envs.VLLM_USE_V1, "V1 is not enabled"
1335
+
1336
+
1337
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # When True, suppress per-request logging (--disable-log-requests).
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register async-engine CLI flags on ``parser``.

        Unless ``async_args_only`` is True, all synchronous EngineArgs flags
        are registered first.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return parser
1351
+
1352
+
1353
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a parser populated with every EngineArgs CLI flag."""
    parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(parser)
1356
+
1357
+
1358
def _async_engine_args_parser():
    """Return a parser with only the async-specific AsyncEngineArgs flags."""
    parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)
.venv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py ADDED
@@ -0,0 +1,1198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ import copy
5
+ import time
6
+ import weakref
7
+ from functools import partial
8
+ from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
9
+ List, Mapping, Optional, Set, Tuple, Type, Union, overload)
10
+ from weakref import ReferenceType
11
+
12
+ from typing_extensions import deprecated
13
+
14
+ import vllm.envs as envs
15
+ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
16
+ ParallelConfig, SchedulerConfig, VllmConfig)
17
+ from vllm.core.scheduler import SchedulerOutputs
18
+ from vllm.engine.arg_utils import AsyncEngineArgs
19
+ from vllm.engine.async_timeout import asyncio_timeout
20
+ from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
21
+ from vllm.engine.metrics_types import StatLoggerBase
22
+ from vllm.engine.protocol import EngineClient
23
+ from vllm.executor.executor_base import ExecutorBase
24
+ from vllm.inputs import PromptType
25
+ from vllm.inputs.preprocess import InputPreprocessor
26
+ from vllm.logger import init_logger
27
+ from vllm.lora.request import LoRARequest
28
+ from vllm.model_executor.guided_decoding import (
29
+ get_guided_decoding_logits_processor)
30
+ from vllm.model_executor.layers.sampler import SamplerOutput
31
+ from vllm.outputs import PoolingRequestOutput, RequestOutput
32
+ from vllm.pooling_params import PoolingParams
33
+ from vllm.prompt_adapter.request import PromptAdapterRequest
34
+ from vllm.sampling_params import SamplingParams
35
+ from vllm.sequence import ExecuteModelRequest
36
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
37
+ from vllm.usage.usage_lib import UsageContext
38
+ from vllm.utils import deprecate_kwargs, weak_bind
39
+
40
logger = init_logger(__name__)
# Timeout budget for one engine iteration, read from the
# VLLM_ENGINE_ITERATION_TIMEOUT_S environment setting (see vllm.envs).
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
42
+
43
+
44
class AsyncEngineDeadError(RuntimeError):
    """Raised when the background engine loop has terminated unexpectedly."""
46
+
47
+
48
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """This function is only intended for the `engine.run_engine_loop()` task.

    In particular, that task runs a `while True` loop that can only exit if
    there is an exception.
    """

    exception = None
    try:
        return_value = task.result()
        # A clean return is itself a bug: the loop must never fall through.
        # The AssertionError raised here is caught by the generic handler
        # below, so the error callback still fires for this case.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        exception = e
        logger.error("Engine background task failed", exc_info=e)
        # Notify the engine owner before re-raising so it can mark itself
        # dead / errored.
        error_callback(exception)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e
74
+
75
+
76
# Sentinel enqueued by AsyncStream.finish() to mark normal end-of-stream;
# AsyncStream.generator() checks for it and returns instead of raising.
STOP_ITERATION = Exception()  # Sentinel
77
+
78
+
79
class AsyncStream:
    """A stream of RequestOutputs or PoolingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator."""

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Called with request_id when the consumer stops iterating early.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                              Exception]) -> None:
        # Items arriving after finish() are silently dropped.
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        # Enqueue the given exception (instance or class) if one was
        # supplied; otherwise enqueue the STOP_ITERATION sentinel so the
        # generator terminates normally.
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        # True once finish() has been called; further put() calls are no-ops.
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    if result == STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            # The consumer abandoned the generator; propagate cancellation
            # upstream so the request can be aborted.
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any):
        # True for exception instances and exception classes alike.
        return isinstance(value, BaseException) or \
                (isinstance(value, type) and \
                 issubclass(value, BaseException))
127
+
128
+
129
+ class RequestTracker:
130
+ """Synchronous abstraction for tracking requests."""
131
+
132
    def __init__(self) -> None:
        # request_id -> its active output stream.
        self._request_streams: Dict[str, AsyncStream] = {}
        # request_ids that were aborted (presumably drained by the engine
        # loop — the consumer is not visible in this chunk).
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # Newly added (stream, engine add_request kwargs) pairs awaiting
        # pickup.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Signalled when a new request is queued.
        self.new_requests_event = asyncio.Event()
138
+
139
    def __contains__(self, item):
        # A request is "in" the tracker while its stream is still active.
        return item in self._request_streams
141
+
142
    def __len__(self) -> int:
        # Number of requests with an active stream.
        return len(self._request_streams)
144
+
145
+ def propagate_exception(self,
146
+ exc: Exception,
147
+ request_id: Optional[str] = None) -> None:
148
+ """Propagate an exception to request streams
149
+ (all if request_id is None)."""
150
+ if request_id is not None:
151
+ self.abort_request(request_id, exception=exc)
152
+ else:
153
+ # NB: tuple() used here because self.abort_request pops the stream
154
+ # out of self._request_streams, so we can't iterate on it directly
155
+ for rid in tuple(self._request_streams.keys()):
156
+ self.abort_request(rid, exception=exc)
157
+
158
+ def process_request_output(self,
159
+ request_output: Union[RequestOutput,
160
+ PoolingRequestOutput],
161
+ *,
162
+ verbose: bool = False) -> None:
163
+ """Process a request output from the engine."""
164
+ request_id = request_output.request_id
165
+ finished = request_output.finished
166
+
167
+ if finished:
168
+ stream = self._request_streams.pop(request_id, None)
169
+ else:
170
+ stream = self._request_streams.get(request_id)
171
+ # Guard against a KeyError which can occur if the request was aborted
172
+ # while the output was generated
173
+ if stream is not None:
174
+ stream.put(request_output)
175
+ if finished:
176
+ stream.finish()
177
+
178
+ if verbose and finished:
179
+ logger.info("Finished request %s.", request_id)
180
+
181
+ def process_exception(self,
182
+ request_id: str,
183
+ exception: BaseException,
184
+ *,
185
+ verbose: bool = False) -> None:
186
+ """Propagate an exception from the engine."""
187
+ if verbose:
188
+ logger.info("Finished request %s.", request_id)
189
+ self.abort_request(request_id, exception=exception)
190
+
191
+ def add_request(self,
192
+ request_id: str,
193
+ *,
194
+ verbose: bool = False,
195
+ **engine_add_request_kwargs) -> AsyncStream:
196
+ """Add a request to be sent to the engine on the next background
197
+ loop iteration."""
198
+ if request_id in self._request_streams:
199
+ raise KeyError(f"Request {request_id} already exists.")
200
+
201
+ abort_request = partial(self.abort_request, verbose=verbose)
202
+ stream = AsyncStream(request_id, abort_request)
203
+ self._new_requests.put_nowait((stream, {
204
+ "request_id": request_id,
205
+ **engine_add_request_kwargs
206
+ }))
207
+
208
+ self.new_requests_event.set()
209
+
210
+ if verbose:
211
+ logger.info("Added request %s.", request_id)
212
+
213
+ return stream
214
+
215
+ def abort_request(self,
216
+ request_id: str,
217
+ *,
218
+ exception: Optional[Union[BaseException,
219
+ Type[BaseException]]] = None,
220
+ verbose: bool = False) -> None:
221
+ """Abort a request during next background loop iteration."""
222
+ if verbose:
223
+ logger.info("Aborted request %s.", request_id)
224
+
225
+ self._aborted_requests.put_nowait(request_id)
226
+
227
+ stream = self._request_streams.pop(request_id, None)
228
+ if stream is not None:
229
+ stream.finish(exception=exception)
230
+
231
+ def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
232
+ """Get the new requests and finished requests to be
233
+ sent to the engine."""
234
+ new_requests: List[Dict] = []
235
+ finished_requests: Set[str] = set()
236
+
237
+ while not self._aborted_requests.empty():
238
+ request_id = self._aborted_requests.get_nowait()
239
+ finished_requests.add(request_id)
240
+
241
+ while not self._new_requests.empty():
242
+ stream, new_request = self._new_requests.get_nowait()
243
+ request_id = stream.request_id
244
+ if request_id in finished_requests:
245
+ # The request has already been aborted.
246
+ stream.finish(asyncio.CancelledError)
247
+ finished_requests.discard(request_id)
248
+ else:
249
+ self._request_streams[request_id] = stream
250
+ new_requests.append(new_request)
251
+
252
+ return new_requests, finished_requests
253
+
254
+ async def wait_for_new_requests(self):
255
+ if not self.has_new_requests():
256
+ await self.new_requests_event.wait()
257
+ self.new_requests_event.clear()
258
+
259
+ def has_new_requests(self):
260
+ return not self._new_requests.empty()
261
+
262
+
263
class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        # Pure pass-through; all real construction happens in LLMEngine.
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are ran asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)
        else:
            # Mid multi-step batch: no new scheduling, so no newly finished
            # request ids to report to the executor this iteration.
            finished_requests_ids = list()

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing was scheduled: drain any queued async outputs and
            # report an empty result set.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        # Resolve the (possibly LoRA-specific) tokenizer without blocking.
        return await (
            self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))

    @overload
    @deprecated("'inputs' will be renamed to 'prompt")
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of :meth:`add_request`."""
        # Honor the deprecated 'inputs' alias for 'prompt'.
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and not self.scheduler_config.policy == "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        if self.tokenizer is not None:
            tokenizer = await self.get_tokenizer_async(lora_request)
            self._validate_token_prompt(prompt, tokenizer=tokenizer)

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=await self.get_tokenizer_async(lora_request),
                default_guided_backend=self.decoding_config.
                guided_decoding_backend,
                model_config=self.model_config)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Raise if the tokenizer or the model executor is unhealthy."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
529
+
530
+
531
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str,
        model_config: ModelConfig) -> SamplingParams:
    """Constructs logits processors based on the guided_decoding,
    logits_bias, and allowed_token_ids fields in sampling_params. Deletes
    those fields and adds the constructed logits processors to the
    logits_processors field. Modifies sampling params in-place and returns
    the modified sampling params."""
    if sampling_params.guided_decoding is None:
        return sampling_params

    # Work on a shallow copy: guided-decoding logits processors carry
    # per-request state, so requests must not share them.
    sampling_params = copy.copy(sampling_params)
    guided_decoding = sampling_params.guided_decoding

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_decoding)

    guided_decoding.backend = (guided_decoding.backend
                               or default_guided_backend)

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding,
        tokenizer=tokenizer,
        model_config=model_config)

    if processor:
        processors = sampling_params.logits_processors or []
        processors.append(processor)
        sampling_params.logits_processors = processors

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params
567
+
568
+
569
+ class AsyncLLMEngine(EngineClient):
570
+ """An asynchronous wrapper for :class:`LLMEngine`.
571
+
572
+ This class is used to wrap the :class:`LLMEngine` class to make it
573
+ asynchronous. It uses asyncio to create a background loop that keeps
574
+ processing incoming requests. The :class:`LLMEngine` is kicked by the
575
+ generate method when there are requests in the waiting queue. The generate
576
+ method yields the outputs from the :class:`LLMEngine` to the caller.
577
+
578
+ Args:
579
+ log_requests: Whether to log the requests.
580
+ start_engine_loop: If True, the background task to run the engine
581
+ will be automatically started in the generate call.
582
+ *args: Arguments for :class:`LLMEngine`.
583
+ **kwargs: Arguments for :class:`LLMEngine`.
584
+ """
585
+
586
+ _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
587
+
588
    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Construct the async wrapper around a new ``_engine_class``.

        *args/**kwargs are forwarded verbatim to the underlying engine.
        """
        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
        self.use_process_request_outputs_callback = (
            self.engine.model_config.use_async_output_proc)

        if self.use_process_request_outputs_callback:
            # weak_bind avoids a reference cycle engine -> callback -> self.
            self.engine.process_request_outputs_callback = \
                weak_bind(self.process_request_outputs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields
        self._request_tracker: RequestTracker
616
+
617
+ def __del__(self):
618
+ if rt := getattr(self, "request_tracker", None):
619
+ # Wake up engine loop so that it will exit cleanly
620
+ rt.new_requests_event.set()
621
+
622
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Delegate executor-class selection to the synchronous LLMEngine."""
        return LLMEngine._get_executor_cls(engine_config)
626
+
627
+ @classmethod
628
+ def from_engine_args(
629
+ cls,
630
+ engine_args: AsyncEngineArgs,
631
+ engine_config: Optional[VllmConfig] = None,
632
+ start_engine_loop: bool = True,
633
+ usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
634
+ stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
635
+ ) -> "AsyncLLMEngine":
636
+ """Creates an async LLM engine from the engine arguments."""
637
+ # Create the engine configs.
638
+ if engine_config is None:
639
+ engine_config = engine_args.create_engine_config(usage_context)
640
+
641
+ executor_class = cls._get_executor_cls(engine_config)
642
+
643
+ # Create the async LLM engine.
644
+ engine = cls(
645
+ vllm_config=engine_config,
646
+ executor_class=executor_class,
647
+ log_requests=not engine_args.disable_log_requests,
648
+ log_stats=not engine_args.disable_log_stats,
649
+ start_engine_loop=start_engine_loop,
650
+ usage_context=usage_context,
651
+ stat_loggers=stat_loggers,
652
+ )
653
+ return engine
654
+
655
+ @property
656
+ def is_running(self) -> bool:
657
+ return (self.background_loop is not None
658
+ and self._background_loop_unshielded is not None
659
+ and not self._background_loop_unshielded.done())
660
+
661
+ @property
662
+ def is_stopped(self) -> bool:
663
+ return self.errored or (self.background_loop is not None and
664
+ self._background_loop_unshielded is not None
665
+ and self._background_loop_unshielded.done())
666
+
667
    @property
    def errored(self) -> bool:
        """Whether the background loop has died with an exception."""
        # _errored_with is set via set_errored() when the loop fails.
        return self._errored_with is not None
670
+
671
+ @property
672
+ def dead_error(self) -> BaseException:
673
+ return AsyncEngineDeadError(
674
+ "Background loop is not running. If it was running, "
675
+ "inspect the output to find the stacktrace of the "
676
+ "error that caused the background loop to stop "
677
+ "(AsyncEngineDeadError).")
678
+
679
    def set_errored(self, exc: Exception) -> None:
        """Record *exc* as the fatal error; flips :attr:`errored` to True."""
        self._errored_with = exc
681
+
682
    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored and fail every in-flight request stream."""
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)
685
+
686
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the underlying engine's input preprocessor."""
        return self.engine.input_preprocessor
688
+
689
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the (possibly LoRA-specific) tokenizer, resolved async."""
        return await self.engine.get_tokenizer_async(lora_request)
694
+
695
    def start_background_loop(self) -> None:
        """Start the background loop."""
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        # The loop holds only a weakref to self so the engine can be GC'd
        # while the task is alive.
        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Shield so awaiting background_loop can be cancelled without
        # cancelling the engine loop itself.
        self.background_loop = asyncio.shield(self._background_loop_unshielded)
710
+
711
+ def shutdown_background_loop(self) -> None:
712
+ """
713
+ Shut down the background loop.
714
+
715
+ This method needs to be called during cleanup to remove
716
+ references to `self` and properly GC the resources held
717
+ by the async LLM engine (e.g., the executors as well as
718
+ their resources).
719
+ """
720
+ if self._background_loop_unshielded is not None:
721
+ self._background_loop_unshielded.cancel()
722
+ self._background_loop_unshielded = None
723
+ self.background_loop = None
724
+
725
+ async def engine_step(self, virtual_engine: int) -> bool:
726
+ """Kick the engine to process the waiting requests.
727
+
728
+ Returns True if there are in-progress requests."""
729
+
730
+ new_requests, aborted_requests = (
731
+ self._request_tracker.get_new_and_aborted_requests())
732
+
733
+ for new_request in new_requests:
734
+ # Add the request into the vLLM engine's waiting queue.
735
+ try:
736
+ await self.engine.add_request_async(**new_request)
737
+ except ValueError as e:
738
+ # TODO: use a vLLM specific error for failed validation
739
+ self._request_tracker.process_exception(
740
+ new_request["request_id"],
741
+ e,
742
+ verbose=self.log_requests,
743
+ )
744
+
745
+ if aborted_requests:
746
+ await self._engine_abort(aborted_requests)
747
+
748
+ request_outputs = await self.engine.step_async(virtual_engine)
749
+
750
+ # Put the outputs into the corresponding streams.
751
+ # If used as a callback, then already invoked inside
752
+ # LLMEngine's _process_model_outputs
753
+ if not self.use_process_request_outputs_callback:
754
+ all_finished = self.process_request_outputs(request_outputs)
755
+ else:
756
+ # For callback case, we only need to detect when all
757
+ # requests are finished
758
+ all_finished = all(request_output.finished
759
+ for request_output in request_outputs)
760
+
761
+ return not all_finished
762
+
763
+ def process_request_outputs(self, request_outputs) -> bool:
764
+ # Put the outputs into the corresponding streams.
765
+ all_finished = True
766
+ for request_output in request_outputs:
767
+ self._request_tracker.process_request_output(
768
+ request_output, verbose=self.log_requests)
769
+ all_finished = all_finished and request_output.finished
770
+
771
+ return all_finished
772
+
773
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward an abort for *request_ids* to the underlying engine."""
        self.engine.abort_request(request_ids)
775
+
776
    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected."""
        engine: Optional[AsyncLLMEngine] = engine_ref()
        if not engine:
            return

        pipeline_parallel_size = \
            engine.engine.parallel_config.pipeline_parallel_size
        # One in-progress flag per virtual (pipeline-parallel) engine.
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                # Re-resolve the weakref; the engine may have been GC'd
                # while we were parked.
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                requests_in_progress = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    # Yield once per virtual engine so queued callbacks run.
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                    for task in done:
                        result = task.result()
                        virtual_engine = requests_in_progress.index(task)
                        has_unfinished_requests = (
                            engine.engine.
                            has_unfinished_requests_for_virtual_engine(
                                virtual_engine))
                        if result or has_unfinished_requests:
                            # Reschedule this virtual engine for another step.
                            requests_in_progress[virtual_engine] = (
                                asyncio.create_task(
                                    engine.engine_step(virtual_engine)))
                            has_requests_in_progress[virtual_engine] = True
                        else:
                            has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)
844
+
845
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    @overload
    @deprecated("'inputs' will be renamed to 'prompt")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        Starts the background loop on first use (if ``start_engine_loop``).
        """
        # Honor the deprecated 'inputs' alias for 'prompt'.
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        if (priority != 0
                and not self.engine.scheduler_config.policy == "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time or time.time(),
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

        return stream.generator()
928
+
929
+ async def generate(
930
+ self,
931
+ prompt: PromptType,
932
+ sampling_params: SamplingParams,
933
+ request_id: str,
934
+ lora_request: Optional[LoRARequest] = None,
935
+ trace_headers: Optional[Mapping[str, str]] = None,
936
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
937
+ priority: int = 0,
938
+ ) -> AsyncGenerator[RequestOutput, None]:
939
+ """Generate outputs for a request.
940
+
941
+ Generate outputs for a request. This method is a coroutine. It adds the
942
+ request into the waiting queue of the LLMEngine and streams the outputs
943
+ from the LLMEngine to the caller.
944
+
945
+ Args:
946
+ prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
947
+ for more details about the format of each input.
948
+ sampling_params: The sampling parameters of the request.
949
+ request_id: The unique id of the request.
950
+ lora_request: LoRA request to use for generation, if any.
951
+ trace_headers: OpenTelemetry trace headers.
952
+ prompt_adapter_request: Prompt Adapter request to use
953
+ for generation, if any.
954
+ priority: The priority of the request.
955
+ Only applicable with priority scheduling.
956
+
957
+ Yields:
958
+ The output `RequestOutput` objects from the LLMEngine
959
+ for the request.
960
+
961
+ Details:
962
+ - If the engine is not running, start the background loop,
963
+ which iteratively invokes
964
+ :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
965
+ to process the waiting requests.
966
+ - Add the request to the engine's `RequestTracker`.
967
+ On the next background loop, this request will be sent to
968
+ the underlying engine.
969
+ Also, a corresponding `AsyncStream` will be created.
970
+ - Wait for the request outputs from `AsyncStream` and yield them.
971
+
972
+ Example:
973
+ >>> # Please refer to entrypoints/api_server.py for
974
+ >>> # the complete example.
975
+ >>>
976
+ >>> # initialize the engine and the example input
977
+ >>> # note that engine_args here is AsyncEngineArgs instance
978
+ >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
979
+ >>> example_input = {
980
+ >>> "prompt": "What is LLM?",
981
+ >>> "stream": False, # assume the non-streaming case
982
+ >>> "temperature": 0.0,
983
+ >>> "request_id": 0,
984
+ >>> }
985
+ >>>
986
+ >>> # start the generation
987
+ >>> results_generator = engine.generate(
988
+ >>> example_input["prompt"],
989
+ >>> SamplingParams(temperature=example_input["temperature"]),
990
+ >>> example_input["request_id"])
991
+ >>>
992
+ >>> # get the results
993
+ >>> final_output = None
994
+ >>> async for request_output in results_generator:
995
+ >>> if await request.is_disconnected():
996
+ >>> # Abort the request if the client disconnects.
997
+ >>> await engine.abort(request_id)
998
+ >>> # Return or raise an error
999
+ >>> ...
1000
+ >>> final_output = request_output
1001
+ >>>
1002
+ >>> # Process and return the final output
1003
+ >>> ...
1004
+ """
1005
+ try:
1006
+ async for output in await self.add_request(
1007
+ request_id,
1008
+ prompt,
1009
+ sampling_params,
1010
+ lora_request=lora_request,
1011
+ trace_headers=trace_headers,
1012
+ prompt_adapter_request=prompt_adapter_request,
1013
+ priority=priority,
1014
+ ):
1015
+ yield LLMEngine.validate_output(output, RequestOutput)
1016
+ except asyncio.CancelledError:
1017
+ await self.abort(request_id)
1018
+ raise
1019
+
1020
+ async def encode(
1021
+ self,
1022
+ prompt: PromptType,
1023
+ pooling_params: PoolingParams,
1024
+ request_id: str,
1025
+ lora_request: Optional[LoRARequest] = None,
1026
+ trace_headers: Optional[Mapping[str, str]] = None,
1027
+ priority: int = 0,
1028
+ ) -> AsyncGenerator[PoolingRequestOutput, None]:
1029
+ """Generate outputs for a request from a pooling model.
1030
+
1031
+ Generate outputs for a request. This method is a coroutine. It adds the
1032
+ request into the waiting queue of the LLMEngine and streams the outputs
1033
+ from the LLMEngine to the caller.
1034
+
1035
+ Args:
1036
+ prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
1037
+ for more details about the format of each input.
1038
+ pooling_params: The pooling parameters of the request.
1039
+ request_id: The unique id of the request.
1040
+ lora_request: LoRA request to use for generation, if any.
1041
+ trace_headers: OpenTelemetry trace headers.
1042
+ priority: The priority of the request.
1043
+ Only applicable with priority scheduling.
1044
+
1045
+ Yields:
1046
+ The output `PoolingRequestOutput` objects from the LLMEngine
1047
+ for the request.
1048
+
1049
+ Details:
1050
+ - If the engine is not running, start the background loop,
1051
+ which iteratively invokes
1052
+ :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
1053
+ to process the waiting requests.
1054
+ - Add the request to the engine's `RequestTracker`.
1055
+ On the next background loop, this request will be sent to
1056
+ the underlying engine.
1057
+ Also, a corresponding `AsyncStream` will be created.
1058
+ - Wait for the request outputs from `AsyncStream` and yield them.
1059
+
1060
+ Example:
1061
+ >>> # Please refer to entrypoints/api_server.py for
1062
+ >>> # the complete example.
1063
+ >>>
1064
+ >>> # initialize the engine and the example input
1065
+ >>> # note that engine_args here is AsyncEngineArgs instance
1066
+ >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
1067
+ >>> example_input = {
1068
+ >>> "input": "What is LLM?",
1069
+ >>> "request_id": 0,
1070
+ >>> }
1071
+ >>>
1072
+ >>> # start the generation
1073
+ >>> results_generator = engine.encode(
1074
+ >>> example_input["input"],
1075
+ >>> PoolingParams(),
1076
+ >>> example_input["request_id"])
1077
+ >>>
1078
+ >>> # get the results
1079
+ >>> final_output = None
1080
+ >>> async for request_output in results_generator:
1081
+ >>> if await request.is_disconnected():
1082
+ >>> # Abort the request if the client disconnects.
1083
+ >>> await engine.abort(request_id)
1084
+ >>> # Return or raise an error
1085
+ >>> ...
1086
+ >>> final_output = request_output
1087
+ >>>
1088
+ >>> # Process and return the final output
1089
+ >>> ...
1090
+ """
1091
+ try:
1092
+ async for output in await self.add_request(
1093
+ request_id,
1094
+ prompt,
1095
+ pooling_params,
1096
+ lora_request=lora_request,
1097
+ trace_headers=trace_headers,
1098
+ priority=priority,
1099
+ ):
1100
+ yield LLMEngine.validate_output(output, PoolingRequestOutput)
1101
+ except asyncio.CancelledError:
1102
+ await self.abort(request_id)
1103
+ raise
1104
+
1105
+ async def abort(self, request_id: str) -> None:
1106
+ """Abort a request.
1107
+
1108
+ Abort a submitted request. If the request is finished or not found,
1109
+ this method will be a no-op.
1110
+
1111
+ Args:
1112
+ request_id: The unique id of the request.
1113
+ """
1114
+ if not self.is_running:
1115
+ raise AsyncEngineDeadError(
1116
+ "Background loop is not running. If it was running, "
1117
+ "inspect the output to find the stacktrace of the "
1118
+ "error that caused the background loop to stop "
1119
+ "(AsyncEngineDeadError).")
1120
+
1121
+ return self._abort(request_id)
1122
+
1123
+ def _abort(self, request_id: str) -> None:
1124
+ """Abort a request.
1125
+
1126
+ Abort a submitted request. If the request is finished or not found,
1127
+ this method will be a no-op.
1128
+
1129
+ Args:
1130
+ request_id: The unique id of the request.
1131
+ """
1132
+ self._request_tracker.abort_request(request_id,
1133
+ exception=asyncio.CancelledError,
1134
+ verbose=self.log_requests)
1135
+
1136
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        # Async facade: the wrapped engine call is synchronous.
        return self.engine.get_model_config()
1139
+
1140
    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        # Async facade: the wrapped engine call is synchronous.
        return self.engine.get_parallel_config()
1143
+
1144
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        # Async facade: the wrapped engine call is synchronous.
        return self.engine.get_decoding_config()
1147
+
1148
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        # Async facade: the wrapped engine call is synchronous.
        return self.engine.get_scheduler_config()
1151
+
1152
    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        # Async facade: the wrapped engine call is synchronous.
        return self.engine.get_lora_config()
1155
+
1156
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its statistics.

        NOTE(review): `scheduler_outputs` and `model_output` are accepted
        (presumably for interface compatibility) but are not forwarded to
        the underlying call — confirm this is intentional.
        """
        self.engine.do_log_stats()
1161
+
1162
    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        # Timed only for debug logging below.
        t = time.perf_counter()
        logger.debug("Starting health check...")
        # A stopped background loop means the engine can no longer serve.
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        # Delegate the actual health probe to the wrapped engine.
        await self.engine.check_health_async()
        logger.debug("Health check took %fs", time.perf_counter() - t)
1171
+
1172
    async def is_tracing_enabled(self) -> bool:
        """Return whether OpenTelemetry tracing is enabled in the engine."""
        return self.engine.is_tracing_enabled()
1174
+
1175
    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger with the wrapped engine under `logger_name`."""
        # NOTE: the `logger` parameter shadows the module-level logger here.
        self.engine.add_logger(logger_name=logger_name, logger=logger)
1177
+
1178
    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger named `logger_name` from the engine."""
        self.engine.remove_logger(logger_name=logger_name)
1180
+
1181
    async def start_profile(self) -> None:
        """Start profiling in the wrapped engine."""
        self.engine.start_profile()
1183
+
1184
    async def stop_profile(self) -> None:
        """Stop profiling in the wrapped engine."""
        self.engine.stop_profile()
1186
+
1187
    async def reset_prefix_cache(self) -> None:
        """Reset the prefix cache of the wrapped engine."""
        self.engine.reset_prefix_cache()
1189
+
1190
    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a LoRA adapter into the wrapped engine."""
        self.engine.add_lora(lora_request)
1192
+
1193
+
1194
# TODO(v1): Remove this class proxy when V1 goes default.
if envs.VLLM_USE_V1:
    from vllm.v1.engine.async_llm import AsyncLLM

    # Rebind the public name so importers transparently get the V1 engine.
    AsyncLLMEngine = AsyncLLM  # type: ignore
.venv/lib/python3.11/site-packages/vllm/engine/async_timeout.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Workaround for https://github.com/python/cpython/issues/86296
4
+ #
5
+ # From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
6
+ # Licensed under the Apache License (Apache-2.0)
7
+
8
+ import asyncio
9
+ import enum
10
+ import sys
11
+ import warnings
12
+ from types import TracebackType
13
+ from typing import Any, Optional, Type
14
+
15
# On 3.11+ the stdlib provides a native, correct asyncio.timeout; the
# backport below is only needed for older interpreters
# (workaround for https://github.com/python/cpython/issues/86296).
if sys.version_info[:2] >= (3, 11):
    from asyncio import timeout as asyncio_timeout
else:

    def asyncio_timeout(delay: Optional[float]) -> "Timeout":
        """timeout context manager.
        Useful in cases when you want to apply timeout logic around block
        of code or in cases when asyncio.wait_for is not suitable. For example:
        >>> async with timeout(0.001):
        ...     async with aiohttp.get('https://github.com') as r:
        ...         await r.text()
        delay - value in seconds or None to disable timeout logic
        """
        # Convert the relative delay into an absolute loop-clock deadline.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + delay if delay is not None else None
        return Timeout(deadline, loop)
31
+
32
    class _State(enum.Enum):
        """Lifecycle states of a Timeout context manager."""
        INIT = "INIT"  # created, not yet entered
        ENTER = "ENTER"  # inside the (async) with block
        TIMEOUT = "TIMEOUT"  # deadline fired; enclosing task was cancelled
        EXIT = "EXIT"  # exited without expiring
37
+
38
    class Timeout:
        # Internal class, please don't instantiate it directly
        # Use timeout() and timeout_at() public factories instead.
        #
        # Implementation note: `async with timeout()` is preferred
        # over `with timeout()`.
        # While technically the Timeout class implementation
        # doesn't need to be async at all,
        # the `async with` statement explicitly points that
        # the context manager should be used from async function context.
        #
        # This design allows to avoid many silly misusages.
        #
        # TimeoutError is raised immediately when scheduled
        # if the deadline is passed.
        # The purpose is to time out as soon as possible
        # without waiting for the next await expression.

        __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")

        def __init__(self, deadline: Optional[float],
                     loop: asyncio.AbstractEventLoop) -> None:
            """Create a timeout bound to `loop`; `deadline` is absolute
            loop time, or None to disable timeout logic entirely."""
            self._loop = loop
            self._state = _State.INIT

            self._timeout_handler = None  # type: Optional[asyncio.Handle]
            if deadline is None:
                self._deadline = None  # type: Optional[float]
            else:
                # update() stores the deadline (scheduling happens on enter).
                self.update(deadline)

        def __enter__(self) -> "Timeout":
            # Sync `with` is supported only for backward compatibility.
            warnings.warn(
                "with timeout() is deprecated, use async with timeout()",
                DeprecationWarning,
                stacklevel=2,
            )
            self._do_enter()
            return self

        def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
        ) -> Optional[bool]:
            self._do_exit(exc_type)
            return None

        async def __aenter__(self) -> "Timeout":
            self._do_enter()
            return self

        async def __aexit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
        ) -> Optional[bool]:
            self._do_exit(exc_type)
            return None

        @property
        def expired(self) -> bool:
            """Is timeout expired during execution?"""
            return self._state == _State.TIMEOUT

        @property
        def deadline(self) -> Optional[float]:
            # Absolute loop-clock deadline, or None when disabled.
            return self._deadline

        def reject(self) -> None:
            """Reject scheduled timeout if any."""
            # cancel is maybe better name but
            # task.cancel() raises CancelledError in asyncio world.
            if self._state not in (_State.INIT, _State.ENTER):
                raise RuntimeError(f"invalid state {self._state.value}")
            self._reject()

        def _reject(self) -> None:
            # Cancel the pending loop callback, if one was scheduled.
            if self._timeout_handler is not None:
                self._timeout_handler.cancel()
                self._timeout_handler = None

        def shift(self, delay: float) -> None:
            """Advance timeout on delay seconds.
            The delay can be negative.
            Raise RuntimeError if shift is called when deadline is not scheduled
            """
            deadline = self._deadline
            if deadline is None:
                raise RuntimeError(
                    "cannot shift timeout if deadline is not scheduled")
            self.update(deadline + delay)

        def update(self, deadline: float) -> None:
            """Set deadline to absolute value.
            deadline argument points on the time in the same clock system
            as loop.time().
            If new deadline is in the past the timeout is raised immediately.
            Please note: it is not POSIX time but a time with
            undefined starting base, e.g. the time of the system power on.
            """
            if self._state == _State.EXIT:
                raise RuntimeError(
                    "cannot reschedule after exit from context manager")
            if self._state == _State.TIMEOUT:
                raise RuntimeError("cannot reschedule expired timeout")
            if self._timeout_handler is not None:
                self._timeout_handler.cancel()
            self._deadline = deadline
            # Only arm the callback once the context has been entered.
            if self._state != _State.INIT:
                self._reschedule()

        def _reschedule(self) -> None:
            """(Re)arm the loop callback that cancels the current task."""
            assert self._state == _State.ENTER
            deadline = self._deadline
            if deadline is None:
                return

            now = self._loop.time()
            if self._timeout_handler is not None:
                self._timeout_handler.cancel()

            task = asyncio.current_task()
            # Deadline already passed: fire on the next loop iteration
            # rather than waiting for call_at.
            if deadline <= now:
                self._timeout_handler = self._loop.call_soon(
                    self._on_timeout, task)
            else:
                self._timeout_handler = self._loop.call_at(
                    deadline, self._on_timeout, task)

        def _do_enter(self) -> None:
            # A Timeout instance is single-use: re-entering is an error.
            if self._state != _State.INIT:
                raise RuntimeError(f"invalid state {self._state.value}")
            self._state = _State.ENTER
            self._reschedule()

        def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
            # Translate the cancellation we ourselves triggered into
            # TimeoutError for the caller.
            if exc_type is asyncio.CancelledError and \
                    self._state == _State.TIMEOUT:
                self._timeout_handler = None
                raise asyncio.TimeoutError
            # timeout has not expired
            self._state = _State.EXIT
            self._reject()
            return None

        def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None:
            """Loop callback: cancel the task that entered the context."""
            if task:
                task.cancel()
            self._state = _State.TIMEOUT
            # drop the reference early
            self._timeout_handler = None
.venv/lib/python3.11/site-packages/vllm/engine/llm_engine.py ADDED
@@ -0,0 +1,2025 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import copy
4
+ import time
5
+ from collections import Counter as collectionsCounter
6
+ from collections import deque
7
+ from contextlib import contextmanager
8
+ from dataclasses import dataclass
9
+ from functools import partial
10
+ from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
11
+ List, Mapping, NamedTuple, Optional)
12
+ from typing import Sequence as GenericSequence
13
+ from typing import Set, Type, Union, cast, overload
14
+
15
+ import torch
16
+ from typing_extensions import TypeVar, deprecated
17
+
18
+ import vllm.envs as envs
19
+ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
20
+ ObservabilityConfig, ParallelConfig, SchedulerConfig,
21
+ VllmConfig)
22
+ from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
23
+ SchedulerOutputs)
24
+ from vllm.engine.arg_utils import EngineArgs
25
+ from vllm.engine.metrics_types import StatLoggerBase, Stats
26
+ from vllm.engine.output_processor.interfaces import (
27
+ SequenceGroupOutputProcessor)
28
+ from vllm.engine.output_processor.stop_checker import StopChecker
29
+ from vllm.engine.output_processor.util import create_output_by_sequence_group
30
+ from vllm.entrypoints.openai.logits_processors import (
31
+ get_logits_processors as get_openai_logits_processors)
32
+ from vllm.executor.executor_base import ExecutorBase
33
+ from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
34
+ PromptType, SingletonInputsAdapter)
35
+ from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
36
+ from vllm.inputs.preprocess import InputPreprocessor
37
+ from vllm.logger import init_logger
38
+ from vllm.logits_process import get_bad_words_logits_processors
39
+ from vllm.lora.request import LoRARequest
40
+ from vllm.model_executor.guided_decoding import (
41
+ get_local_guided_decoding_logits_processor)
42
+ from vllm.model_executor.layers.sampler import SamplerOutput
43
+ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
44
+ from vllm.outputs import (PoolingRequestOutput, RequestOutput,
45
+ RequestOutputFactory)
46
+ from vllm.pooling_params import PoolingParams
47
+ from vllm.prompt_adapter.request import PromptAdapterRequest
48
+ from vllm.sampling_params import RequestOutputKind, SamplingParams
49
+ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
50
+ PoolingSequenceGroupOutput, Sequence, SequenceGroup,
51
+ SequenceGroupBase, SequenceGroupMetadata,
52
+ SequenceGroupOutput, SequenceStatus)
53
+ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
54
+ init_tracer)
55
+ from vllm.transformers_utils.detokenizer import Detokenizer
56
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
57
+ from vllm.transformers_utils.tokenizer_group import (
58
+ BaseTokenizerGroup, init_tokenizer_from_configs)
59
+ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
60
+ usage_message)
61
+ from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
62
+ from vllm.version import __version__ as VLLM_VERSION
63
+
64
logger = init_logger(__name__)
# Interval (seconds) between local stat-logging flushes.
_LOCAL_LOGGING_INTERVAL_SEC = 5

# Tokenizer-group type var; defaults to the base implementation.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
# Request-output type var, constrained to the two concrete output classes.
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
69
+
70
+
71
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
    # Metadata for the sequence groups scheduled in the cached step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # The scheduler's decisions for the cached step.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # Whether async output processing may be used for this step.
    allow_async_output_proc: bool = False
    # Sampler output of the most recent step, if any.
    last_output: Optional[SamplerOutput] = None
78
+
79
+
80
class OutputData(NamedTuple):
    """One queued batch of model outputs awaiting output processing."""
    # Raw sampler outputs for the step(s) covered by this entry.
    outputs: List[SamplerOutput]
    # Metadata of the sequence groups these outputs belong to.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Scheduler decisions that produced this step.
    scheduler_outputs: SchedulerOutputs
    is_async: bool
    is_last_step: bool
    # Indicates if this output is from the first step of the
    # multi-step. When multi-step is disabled, this is always
    # set to True.
    # is_first_step_output is invalid when `outputs` has
    # outputs from multiple steps.
    is_first_step_output: Optional[bool]
    # Presumably indices to skip during output processing — confirm at
    # the consumer of `output_queue`.
    skip: List[int]
93
+
94
+
95
class SchedulerContext:
    """Mutable per-virtual-engine state shared with output processing."""

    def __init__(self, multi_step_stream_outputs: bool = False):
        # FIFO of step outputs awaiting processing.
        self.output_queue: Deque[OutputData] = deque()
        # Fully processed request outputs ready to be returned to callers.
        self.request_outputs: List[Union[RequestOutput,
                                         PoolingRequestOutput]] = []
        self.seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
        self.scheduler_outputs: Optional[SchedulerOutputs] = None

        self.multi_step_stream_outputs: bool = multi_step_stream_outputs

    def append_output(self, outputs: List[SamplerOutput],
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      scheduler_outputs: SchedulerOutputs, is_async: bool,
                      is_last_step: bool,
                      is_first_step_output: Optional[bool]):
        """Queue one step's outputs (with an empty skip list) for later
        processing."""
        entry = OutputData(outputs=outputs,
                           seq_group_metadata_list=seq_group_metadata_list,
                           scheduler_outputs=scheduler_outputs,
                           is_async=is_async,
                           is_last_step=is_last_step,
                           is_first_step_output=is_first_step_output,
                           skip=[])
        self.output_queue.append(entry)
120
+
121
+
122
+ class LLMEngine:
123
+ """An LLM engine that receives requests and generates texts.
124
+
125
+ This is the main class for the vLLM engine. It receives requests
126
+ from clients and generates texts from the LLM. It includes a tokenizer, a
127
+ language model (possibly distributed across multiple GPUs), and GPU memory
128
+ space allocated for intermediate states (aka KV cache). This class utilizes
129
+ iteration-level scheduling and efficient memory management to maximize the
130
+ serving throughput.
131
+
132
+ The :class:`~vllm.LLM` class wraps this class for offline batched inference
133
+ and the :class:`AsyncLLMEngine` class wraps this class for online serving.
134
+
135
+ The config arguments are derived from :class:`~vllm.EngineArgs`. (See
136
+ :ref:`engine-args`)
137
+
138
+ Args:
139
+ model_config: The configuration related to the LLM model.
140
+ cache_config: The configuration related to the KV cache memory
141
+ management.
142
+ parallel_config: The configuration related to distributed execution.
143
+ scheduler_config: The configuration related to the request scheduler.
144
+ device_config: The configuration related to the device.
145
+ lora_config (Optional): The configuration related to serving multi-LoRA.
146
+ speculative_config (Optional): The configuration related to speculative
147
+ decoding.
148
+ executor_class: The model executor class for managing distributed
149
+ execution.
150
+ prompt_adapter_config (Optional): The configuration related to serving
151
+ prompt adapters.
152
+ log_stats: Whether to log statistics.
153
+ usage_context: Specified entry point, used for usage info collection.
154
+ """
155
+
156
+ DO_VALIDATE_OUTPUT: ClassVar[bool] = False
157
+ """A flag to toggle whether to validate the type of request output."""
158
+
159
+ @classmethod
160
+ @contextmanager
161
+ def enable_output_validation(cls):
162
+ cls.DO_VALIDATE_OUTPUT = True
163
+
164
+ yield
165
+
166
+ cls.DO_VALIDATE_OUTPUT = False
167
+
168
+ @classmethod
169
+ def validate_output(
170
+ cls,
171
+ output: object,
172
+ output_type: Type[_O],
173
+ ) -> _O:
174
+ do_validate = cls.DO_VALIDATE_OUTPUT
175
+
176
+ if ((TYPE_CHECKING or do_validate)
177
+ and not isinstance(output, output_type)):
178
+ raise TypeError(f"Expected output of type {output_type}, "
179
+ f"but found type {type(output)}")
180
+
181
+ return cast(_O, output)
182
+
183
    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        """Optionally type-check a sequence of request outputs.

        When validation is enabled, each item is checked and a new list is
        built; otherwise the input sequence is returned as-is (no copy).
        NOTE(review): in the no-validate branch the `List` return
        annotation is only nominally satisfied — callers presumably treat
        the result read-only; confirm before relying on list mutation.
        """
        do_validate = cls.DO_VALIDATE_OUTPUT

        outputs_: List[_O]
        if TYPE_CHECKING or do_validate:
            outputs_ = []
            for output in outputs:
                if not isinstance(output, output_type):
                    raise TypeError(f"Expected output of type {output_type}, "
                                    f"but found type {type(output)}")

                outputs_.append(output)
        else:
            outputs_ = outputs

        return outputs_
204
+
205
+ tokenizer: Optional[BaseTokenizerGroup]
206
+
207
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Wire together tokenizer, executor, KV caches, schedulers,
        stat loggers, tracer and the output processor from `vllm_config`.

        Args:
            vllm_config: Aggregated engine configuration.
            executor_class: Executor type used to run the model.
            log_stats: Whether to enable stat logging.
            usage_context: Entry point, used for usage info collection.
            stat_loggers: Optional pre-built stat loggers (overrides the
                defaults when log_stats is True).
            input_registry: Registry used to build the input processor.
            mm_registry: Multi-modal registry for input preprocessing.
            use_cached_outputs: Passed through for output caching.
        """

        # Unpack the aggregated config into per-area attributes.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
        )
        self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
        )

        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        # Tokenizer/detokenizer are optional: skipped entirely when
        # skip_tokenizer_init is set.
        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(vllm_config=vllm_config, )

        # Pooling models don't generate tokens and need no KV cache.
        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prompt_adapter":
                    bool(self.prompt_adapter_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # One cached-output slot per virtual engine (pipeline stage).
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        # One scheduler context per virtual engine.
        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            # weak_bind avoids the callback keeping the engine alive.
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        # Tracing is only enabled when an OTLP endpoint is configured.
        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))

        # Maps sequence ids to their parallel-sampling group bookkeeping.
        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
407
+
408
+ def _initialize_kv_caches(self) -> None:
409
+ """Initialize the KV cache in the worker(s).
410
+
411
+ The workers will determine the number of blocks in both the GPU cache
412
+ and the swap CPU cache.
413
+ """
414
+ start = time.time()
415
+ num_gpu_blocks, num_cpu_blocks = (
416
+ self.model_executor.determine_num_available_blocks())
417
+
418
+ if self.cache_config.num_gpu_blocks_override is not None:
419
+ num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
420
+ logger.info(
421
+ "Overriding num_gpu_blocks=%d with "
422
+ "num_gpu_blocks_override=%d", num_gpu_blocks,
423
+ num_gpu_blocks_override)
424
+ num_gpu_blocks = num_gpu_blocks_override
425
+
426
+ self.cache_config.num_gpu_blocks = num_gpu_blocks
427
+ self.cache_config.num_cpu_blocks = num_cpu_blocks
428
+
429
+ self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
430
+ elapsed = time.time() - start
431
+ logger.info(("init engine (profile, create kv cache, "
432
+ "warmup model) took %.2f seconds"), elapsed)
433
+
434
+ @classmethod
435
+ def _get_executor_cls(cls,
436
+ engine_config: VllmConfig) -> Type[ExecutorBase]:
437
+ distributed_executor_backend = (
438
+ engine_config.parallel_config.distributed_executor_backend)
439
+ # Initialize the cluster and specify the executor class.
440
+ if isinstance(distributed_executor_backend, type):
441
+ if not issubclass(distributed_executor_backend, ExecutorBase):
442
+ raise TypeError(
443
+ "distributed_executor_backend must be a subclass of "
444
+ f"ExecutorBase. Got {distributed_executor_backend}.")
445
+ executor_class = distributed_executor_backend
446
+ elif engine_config.parallel_config.world_size > 1:
447
+ if distributed_executor_backend == "ray":
448
+ from vllm.executor.ray_distributed_executor import (
449
+ RayDistributedExecutor)
450
+ executor_class = RayDistributedExecutor
451
+ elif distributed_executor_backend == "mp":
452
+ from vllm.executor.mp_distributed_executor import (
453
+ MultiprocessingDistributedExecutor)
454
+ assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
455
+ "multiprocessing distributed executor backend does not "
456
+ "support VLLM_USE_RAY_SPMD_WORKER=1")
457
+ executor_class = MultiprocessingDistributedExecutor
458
+ elif distributed_executor_backend == "uni":
459
+ # JAX-style, single-process, multi-device executor.
460
+ from vllm.executor.uniproc_executor import UniProcExecutor
461
+ executor_class = UniProcExecutor
462
+ elif distributed_executor_backend == "external_launcher":
463
+ # executor with external launcher
464
+ from vllm.executor.uniproc_executor import ( # noqa
465
+ ExecutorWithExternalLauncher)
466
+ executor_class = ExecutorWithExternalLauncher
467
+ else:
468
+ from vllm.executor.uniproc_executor import UniProcExecutor
469
+ executor_class = UniProcExecutor
470
+ return executor_class
471
+
472
+ @classmethod
473
+ def from_engine_args(
474
+ cls,
475
+ engine_args: EngineArgs,
476
+ usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
477
+ stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
478
+ ) -> "LLMEngine":
479
+ """Creates an LLM engine from the engine arguments."""
480
+ # Create the engine configs.
481
+ engine_config = engine_args.create_engine_config(usage_context)
482
+ executor_class = cls._get_executor_cls(engine_config)
483
+ # Create the LLM engine.
484
+ engine = cls(
485
+ vllm_config=engine_config,
486
+ executor_class=executor_class,
487
+ log_stats=not engine_args.disable_log_stats,
488
+ usage_context=usage_context,
489
+ stat_loggers=stat_loggers,
490
+ )
491
+
492
+ return engine
493
+
494
+ def __reduce__(self):
495
+ # This is to ensure that the LLMEngine is not referenced in
496
+ # the closure used to initialize Ray worker actors
497
+ raise RuntimeError("LLMEngine should not be pickled!")
498
+
499
+ def __del__(self):
500
+ # Shutdown model executor when engine is garbage collected
501
+ # Use getattr since __init__ can fail before the field is set
502
+ if model_executor := getattr(self, "model_executor", None):
503
+ model_executor.shutdown()
504
+
505
+ def get_tokenizer_group(
506
+ self,
507
+ group_type: Type[_G] = BaseTokenizerGroup,
508
+ ) -> _G:
509
+ tokenizer_group = self.tokenizer
510
+
511
+ if tokenizer_group is None:
512
+ raise ValueError("Unable to get tokenizer because "
513
+ "skip_tokenizer_init is True")
514
+ if not isinstance(tokenizer_group, group_type):
515
+ raise TypeError("Invalid type of tokenizer group. "
516
+ f"Expected type: {group_type}, but "
517
+ f"found type: {type(tokenizer_group)}")
518
+
519
+ return tokenizer_group
520
+
521
+ def get_tokenizer(
522
+ self,
523
+ lora_request: Optional[LoRARequest] = None,
524
+ ) -> AnyTokenizer:
525
+ return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
526
+
527
+ def _init_tokenizer(self) -> BaseTokenizerGroup:
528
+ return init_tokenizer_from_configs(
529
+ model_config=self.model_config,
530
+ scheduler_config=self.scheduler_config,
531
+ parallel_config=self.parallel_config,
532
+ lora_config=self.lora_config)
533
+
534
+ def _verify_args(self) -> None:
535
+ self.model_config.verify_with_parallel_config(self.parallel_config)
536
+ self.cache_config.verify_with_parallel_config(self.parallel_config)
537
+ if self.lora_config:
538
+ self.lora_config.verify_with_model_config(self.model_config)
539
+ self.lora_config.verify_with_scheduler_config(
540
+ self.scheduler_config)
541
+ if self.prompt_adapter_config:
542
+ self.prompt_adapter_config.verify_with_model_config(
543
+ self.model_config)
544
+
545
+ def _add_processed_request(
546
+ self,
547
+ request_id: str,
548
+ processed_inputs: ProcessorInputs,
549
+ params: Union[SamplingParams, PoolingParams],
550
+ arrival_time: float,
551
+ lora_request: Optional[LoRARequest],
552
+ prompt_adapter_request: Optional[PromptAdapterRequest],
553
+ trace_headers: Optional[Mapping[str, str]] = None,
554
+ priority: int = 0,
555
+ ) -> Optional[SequenceGroup]:
556
+ """Add a processed request to the engine's request pool.
557
+ return the created sequence group.
558
+ """
559
+ if isinstance(params, SamplingParams) and params.n > 1:
560
+ ParallelSampleSequenceGroup.add_request(
561
+ request_id,
562
+ self,
563
+ params,
564
+ processed_inputs=processed_inputs,
565
+ arrival_time=arrival_time,
566
+ lora_request=lora_request,
567
+ trace_headers=trace_headers,
568
+ prompt_adapter_request=prompt_adapter_request,
569
+ priority=priority,
570
+ )
571
+ return None
572
+
573
+ self._validate_model_inputs(processed_inputs, lora_request)
574
+ # Create the sequences.
575
+ block_size = self.cache_config.block_size
576
+ seq_id = next(self.seq_counter)
577
+ eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
578
+
579
+ if is_encoder_decoder_inputs(processed_inputs):
580
+ decoder_inputs = processed_inputs["decoder"]
581
+ encoder_inputs = processed_inputs["encoder"]
582
+ else:
583
+ decoder_inputs = processed_inputs
584
+ encoder_inputs = None
585
+
586
+ seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
587
+ lora_request, prompt_adapter_request)
588
+
589
+ encoder_seq = (None if encoder_inputs is None else Sequence(
590
+ seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
591
+ prompt_adapter_request))
592
+
593
+ # Create a SequenceGroup based on SamplingParams or PoolingParams
594
+ if isinstance(params, SamplingParams):
595
+ seq_group = self._create_sequence_group_with_sampling(
596
+ request_id,
597
+ seq,
598
+ params,
599
+ arrival_time=arrival_time,
600
+ lora_request=lora_request,
601
+ trace_headers=trace_headers,
602
+ prompt_adapter_request=prompt_adapter_request,
603
+ encoder_seq=encoder_seq,
604
+ priority=priority)
605
+ elif isinstance(params, PoolingParams):
606
+ seq_group = self._create_sequence_group_with_pooling(
607
+ request_id,
608
+ seq,
609
+ params,
610
+ arrival_time=arrival_time,
611
+ lora_request=lora_request,
612
+ prompt_adapter_request=prompt_adapter_request,
613
+ encoder_seq=encoder_seq,
614
+ priority=priority)
615
+ else:
616
+ raise ValueError(
617
+ "Either SamplingParams or PoolingParams must be provided.")
618
+
619
+ # Add the sequence group to the scheduler with least unfinished seqs.
620
+ costs = [
621
+ scheduler.get_num_unfinished_seq_groups()
622
+ for scheduler in self.scheduler
623
+ ]
624
+ min_cost_scheduler = self.scheduler[costs.index(min(costs))]
625
+ min_cost_scheduler.add_seq_group(seq_group)
626
+
627
+ return seq_group
628
+
629
+ def stop_remote_worker_execution_loop(self) -> None:
630
+ self.model_executor.stop_remote_worker_execution_loop()
631
+
632
+ @overload
633
+ def add_request(
634
+ self,
635
+ request_id: str,
636
+ prompt: PromptType,
637
+ params: Union[SamplingParams, PoolingParams],
638
+ arrival_time: Optional[float] = None,
639
+ lora_request: Optional[LoRARequest] = None,
640
+ trace_headers: Optional[Mapping[str, str]] = None,
641
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
642
+ priority: int = 0,
643
+ ) -> None:
644
+ ...
645
+
646
+ @overload
647
+ @deprecated("'inputs' will be renamed to 'prompt")
648
+ def add_request(
649
+ self,
650
+ request_id: str,
651
+ *,
652
+ inputs: PromptType,
653
+ params: Union[SamplingParams, PoolingParams],
654
+ arrival_time: Optional[float] = None,
655
+ lora_request: Optional[LoRARequest] = None,
656
+ trace_headers: Optional[Mapping[str, str]] = None,
657
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
658
+ priority: int = 0,
659
+ ) -> None:
660
+ ...
661
+
662
+ @deprecate_kwargs(
663
+ "inputs",
664
+ additional_message="Please use the 'prompt' parameter instead.",
665
+ )
666
+ def add_request(
667
+ self,
668
+ request_id: str,
669
+ prompt: Optional[PromptType] = None,
670
+ params: Optional[Union[SamplingParams, PoolingParams]] = None,
671
+ arrival_time: Optional[float] = None,
672
+ lora_request: Optional[LoRARequest] = None,
673
+ trace_headers: Optional[Mapping[str, str]] = None,
674
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
675
+ priority: int = 0,
676
+ *,
677
+ inputs: Optional[PromptType] = None, # DEPRECATED
678
+ ) -> None:
679
+ """Add a request to the engine's request pool.
680
+
681
+ The request is added to the request pool and will be processed by the
682
+ scheduler as `engine.step()` is called. The exact scheduling policy is
683
+ determined by the scheduler.
684
+
685
+ Args:
686
+ request_id: The unique ID of the request.
687
+ prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
688
+ for more details about the format of each input.
689
+ params: Parameters for sampling or pooling.
690
+ :class:`~vllm.SamplingParams` for text generation.
691
+ :class:`~vllm.PoolingParams` for pooling.
692
+ arrival_time: The arrival time of the request. If None, we use
693
+ the current monotonic time.
694
+ lora_request: The LoRA request to add.
695
+ trace_headers: OpenTelemetry trace headers.
696
+ prompt_adapter_request: The prompt adapter request to add.
697
+ priority: The priority of the request.
698
+ Only applicable with priority scheduling.
699
+
700
+ Details:
701
+ - Set arrival_time to the current time if it is None.
702
+ - Set prompt_token_ids to the encoded prompt if it is None.
703
+ - Create `n` number of :class:`~vllm.Sequence` objects.
704
+ - Create a :class:`~vllm.SequenceGroup` object
705
+ from the list of :class:`~vllm.Sequence`.
706
+ - Add the :class:`~vllm.SequenceGroup` object to the scheduler.
707
+
708
+ Example:
709
+ >>> # initialize engine
710
+ >>> engine = LLMEngine.from_engine_args(engine_args)
711
+ >>> # set request arguments
712
+ >>> example_prompt = "Who is the president of the United States?"
713
+ >>> sampling_params = SamplingParams(temperature=0.0)
714
+ >>> request_id = 0
715
+ >>>
716
+ >>> # add the request to the engine
717
+ >>> engine.add_request(
718
+ >>> str(request_id),
719
+ >>> example_prompt,
720
+ >>> SamplingParams(temperature=0.0))
721
+ >>> # continue the request processing
722
+ >>> ...
723
+ """
724
+ if inputs is not None:
725
+ prompt = inputs
726
+ assert prompt is not None and params is not None
727
+
728
+ if lora_request is not None and not self.lora_config:
729
+ raise ValueError(f"Got lora_request {lora_request} but LoRA is "
730
+ "not enabled!")
731
+
732
+ if priority != 0 and not self.scheduler_config.policy == "priority":
733
+ raise ValueError(f"Got priority {priority} but "
734
+ "Priority scheduling is not enabled.")
735
+
736
+ if isinstance(params, SamplingParams) \
737
+ and (params.guided_decoding or params.logits_processors) \
738
+ and self.scheduler_config.num_scheduler_steps > 1:
739
+ raise ValueError(
740
+ "Guided decoding and logits processors are not supported "
741
+ "in multi-step decoding")
742
+
743
+ if arrival_time is None:
744
+ arrival_time = time.time()
745
+
746
+ if self.tokenizer is not None:
747
+ self._validate_token_prompt(
748
+ prompt,
749
+ tokenizer=self.get_tokenizer(lora_request=lora_request))
750
+
751
+ preprocessed_inputs = self.input_preprocessor.preprocess(
752
+ prompt,
753
+ request_id=request_id,
754
+ lora_request=lora_request,
755
+ prompt_adapter_request=prompt_adapter_request,
756
+ )
757
+ processed_inputs = self.input_processor(preprocessed_inputs)
758
+
759
+ self._add_processed_request(
760
+ request_id=request_id,
761
+ processed_inputs=processed_inputs,
762
+ params=params,
763
+ arrival_time=arrival_time,
764
+ lora_request=lora_request,
765
+ prompt_adapter_request=prompt_adapter_request,
766
+ trace_headers=trace_headers,
767
+ priority=priority,
768
+ )
769
+
770
+ def _validate_token_prompt(self, prompt: PromptType,
771
+ tokenizer: AnyTokenizer):
772
+ # Guard against out-of-vocab tokens.
773
+ # For some tokenizers, tokenizer.decode will happily return empty text
774
+ # for token ids that are out of vocab, and we don't detect token ids
775
+ # that are greater than the max token id before running the model.
776
+ # However, these token ids will later crash a cuda kernel at runtime
777
+ # with an index out of bounds error. This will crash the entire engine.
778
+ # This needs to happen before multimodal input pre-processing, which
779
+ # may add dummy <image> tokens that aren't part of the tokenizer's
780
+ # vocabulary.
781
+ if is_token_prompt(prompt):
782
+ prompt_ids = prompt["prompt_token_ids"]
783
+ if len(prompt_ids) == 0:
784
+ # Empty prompt check is handled later
785
+ return
786
+ max_input_id = max(prompt_ids)
787
+ if max_input_id > tokenizer.max_token_id:
788
+ raise ValueError(
789
+ "Token id {} is out of vocabulary".format(max_input_id))
790
+
791
+ def _create_sequence_group_with_sampling(
792
+ self,
793
+ request_id: str,
794
+ seq: Sequence,
795
+ sampling_params: SamplingParams,
796
+ arrival_time: float,
797
+ lora_request: Optional[LoRARequest],
798
+ trace_headers: Optional[Mapping[str, str]] = None,
799
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
800
+ encoder_seq: Optional[Sequence] = None,
801
+ priority: int = 0,
802
+ ) -> SequenceGroup:
803
+ """Creates a SequenceGroup with SamplingParams."""
804
+ max_logprobs = self.get_model_config().max_logprobs
805
+ if (sampling_params.logprobs
806
+ and sampling_params.logprobs > max_logprobs) or (
807
+ sampling_params.prompt_logprobs
808
+ and sampling_params.prompt_logprobs > max_logprobs):
809
+ raise ValueError(f"Cannot request more than "
810
+ f"{max_logprobs} logprobs.")
811
+
812
+ sampling_params = self._build_logits_processors(
813
+ sampling_params, lora_request)
814
+
815
+ # Defensive copy of SamplingParams, which are used by the sampler,
816
+ # this doesn't deep-copy LogitsProcessor objects
817
+ sampling_params = sampling_params.clone()
818
+
819
+ sampling_params.update_from_generation_config(
820
+ self.generation_config_fields, seq.eos_token_id)
821
+
822
+ # Create the sequence group.
823
+ seq_group = SequenceGroup(
824
+ request_id=request_id,
825
+ seqs=[seq],
826
+ arrival_time=arrival_time,
827
+ sampling_params=sampling_params,
828
+ lora_request=lora_request,
829
+ trace_headers=trace_headers,
830
+ prompt_adapter_request=prompt_adapter_request,
831
+ encoder_seq=encoder_seq,
832
+ priority=priority)
833
+
834
+ return seq_group
835
+
836
+ def _create_sequence_group_with_pooling(
837
+ self,
838
+ request_id: str,
839
+ seq: Sequence,
840
+ pooling_params: PoolingParams,
841
+ arrival_time: float,
842
+ lora_request: Optional[LoRARequest],
843
+ prompt_adapter_request: Optional[PromptAdapterRequest],
844
+ encoder_seq: Optional[Sequence] = None,
845
+ priority: int = 0,
846
+ ) -> SequenceGroup:
847
+ """Creates a SequenceGroup with PoolingParams."""
848
+ # Defensive copy of PoolingParams, which are used by the pooler
849
+ pooling_params = pooling_params.clone()
850
+ # Create the sequence group.
851
+ seq_group = SequenceGroup(
852
+ request_id=request_id,
853
+ seqs=[seq],
854
+ arrival_time=arrival_time,
855
+ lora_request=lora_request,
856
+ pooling_params=pooling_params,
857
+ prompt_adapter_request=prompt_adapter_request,
858
+ encoder_seq=encoder_seq,
859
+ priority=priority)
860
+ return seq_group
861
+
862
+ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
863
+ """Aborts a request(s) with the given ID.
864
+
865
+ Args:
866
+ request_id: The ID(s) of the request to abort.
867
+
868
+ Details:
869
+ - Refer to the
870
+ :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
871
+ from class :class:`~vllm.core.scheduler.Scheduler`.
872
+
873
+ Example:
874
+ >>> # initialize engine and add a request with request_id
875
+ >>> request_id = str(0)
876
+ >>> # abort the request
877
+ >>> engine.abort_request(request_id)
878
+ """
879
+ for scheduler in self.scheduler:
880
+ scheduler.abort_seq_group(request_id)
881
+
882
+ def get_model_config(self) -> ModelConfig:
883
+ """Gets the model configuration."""
884
+ return self.model_config
885
+
886
+ def get_parallel_config(self) -> ParallelConfig:
887
+ """Gets the parallel configuration."""
888
+ return self.parallel_config
889
+
890
+ def get_decoding_config(self) -> DecodingConfig:
891
+ """Gets the decoding configuration."""
892
+ return self.decoding_config
893
+
894
+ def get_scheduler_config(self) -> SchedulerConfig:
895
+ """Gets the scheduler configuration."""
896
+ return self.scheduler_config
897
+
898
+ def get_lora_config(self) -> LoRAConfig:
899
+ """Gets the LoRA configuration."""
900
+ return self.lora_config
901
+
902
+ def get_num_unfinished_requests(self) -> int:
903
+ """Gets the number of unfinished requests."""
904
+ return sum(scheduler.get_num_unfinished_seq_groups()
905
+ for scheduler in self.scheduler)
906
+
907
+ def has_unfinished_requests(self) -> bool:
908
+ """Returns True if there are unfinished requests."""
909
+ return any(scheduler.has_unfinished_seqs()
910
+ for scheduler in self.scheduler)
911
+
912
+ def has_unfinished_requests_for_virtual_engine(
913
+ self, virtual_engine: int) -> bool:
914
+ """
915
+ Returns True if there are unfinished requests for the virtual engine.
916
+ """
917
+ return self.scheduler[virtual_engine].has_unfinished_seqs()
918
+
919
+ def reset_prefix_cache(self) -> bool:
920
+ """Reset prefix cache for all devices."""
921
+
922
+ success = True
923
+ for scheduler in self.scheduler:
924
+ success = success and scheduler.reset_prefix_cache()
925
+ return success
926
+
927
+ @staticmethod
928
+ def _process_sequence_group_outputs(
929
+ seq_group: SequenceGroup,
930
+ outputs: List[PoolingSequenceGroupOutput],
931
+ ) -> None:
932
+ seq_group.pooled_data = outputs[0].data
933
+
934
+ for seq in seq_group.get_seqs():
935
+ seq.status = SequenceStatus.FINISHED_STOPPED
936
+
937
+ return
938
+
939
+ def _update_num_computed_tokens_for_multi_step_prefill(
940
+ self, seq_group: SequenceGroup,
941
+ seq_group_meta: SequenceGroupMetadata,
942
+ is_first_step_output: Optional[bool]):
943
+ """
944
+ This function updates num_computed_tokens for prompt sequences
945
+ when Multi-Step is enabled.
946
+
947
+ seq_group: SequenceGroup to update the num_computed_tokens for.
948
+ seq_group_meta: Metadata of the given SequenceGroup.
949
+ is_first_step_output: Optional[bool] -
950
+ When available, is_first_step_output indicates if the appended
951
+ output token is the output of the first-step in multi-step.
952
+ A value of None indicates that outputs from all steps in
953
+ in multi-step are submitted in a single burst.
954
+ """
955
+
956
+ assert self.scheduler_config.is_multi_step
957
+
958
+ if not seq_group_meta.is_prompt:
959
+ # num_computed_token updates for multi-step decodes happen after
960
+ # the tokens are appended to the sequence.
961
+ return
962
+
963
+ do_update: bool = False
964
+ if self.scheduler_config.chunked_prefill_enabled:
965
+ # In multi-step + chunked-prefill case, the prompt sequences
966
+ # that are scheduled are fully processed in the first step.
967
+ do_update = is_first_step_output is None or is_first_step_output
968
+ else:
969
+ # Normal multi-step decoding case. In this case prompt-sequences
970
+ # are actually single-stepped. Always update in this case.
971
+ assert seq_group.state.num_steps == 1
972
+ do_update = True
973
+
974
+ if do_update:
975
+ seq_group.update_num_computed_tokens(
976
+ seq_group_meta.token_chunk_size)
977
+
978
+ def _process_model_outputs(self,
979
+ ctx: SchedulerContext,
980
+ request_id: Optional[str] = None) -> None:
981
+ """Apply the model output to the sequences in the scheduled seq groups
982
+ and return responses.
983
+
984
+ ctx: The virtual engine context to work on
985
+ request_id: If provided, then only this request is going to be processed
986
+ """
987
+
988
+ now = time.time()
989
+
990
+ if len(ctx.output_queue) == 0:
991
+ return None
992
+
993
+ # Get pending async postprocessor
994
+ if request_id:
995
+ # When we process only one request, no pop is required
996
+ # (since later we will process all of the rest)
997
+ (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
998
+ is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
999
+ else:
1000
+ (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
1001
+ is_last_step, is_first_step_output,
1002
+ skip) = ctx.output_queue.popleft()
1003
+
1004
+ # Sanity check
1005
+ assert len(seq_group_metadata_list) == len(
1006
+ scheduler_outputs.scheduled_seq_groups)
1007
+
1008
+ has_multiple_outputs: bool = len(outputs) > 1
1009
+ outputs_by_sequence_group: List[List[SequenceGroupOutput]]
1010
+ if has_multiple_outputs:
1011
+ assert self.scheduler_config.is_multi_step or \
1012
+ self.speculative_config
1013
+ # Organize outputs by [step][sequence group] instead of
1014
+ # [sequence group][step].
1015
+ if self.scheduler_config.is_multi_step:
1016
+ outputs_by_sequence_group = create_output_by_sequence_group(
1017
+ outputs, len(seq_group_metadata_list))
1018
+ elif self.speculative_config:
1019
+ # Decodes are multi-steps while prefills are not, outputting at
1020
+ # most 1 token. Separate them so that we can trigger chunk
1021
+ # processing without having to pad or copy over prompts K times
1022
+ # to match decodes structure (costly with prompt_logprobs).
1023
+ num_prefills = sum(sg.is_prompt
1024
+ for sg in seq_group_metadata_list)
1025
+ prefills, decodes = outputs[:num_prefills], outputs[
1026
+ num_prefills:]
1027
+ outputs_by_sequence_group = create_output_by_sequence_group(
1028
+ decodes,
1029
+ num_seq_groups=len(seq_group_metadata_list) - num_prefills)
1030
+ outputs_by_sequence_group = [p.outputs for p in prefills
1031
+ ] + outputs_by_sequence_group
1032
+ # We have outputs for multiple steps submitted in a single burst,
1033
+ # so invalidate is_first_step_output.
1034
+ is_first_step_output = None
1035
+ else:
1036
+ outputs_by_sequence_group = outputs
1037
+
1038
+ # Determine the requests we need to operate on
1039
+ if request_id:
1040
+ indices = []
1041
+ for i, seq_group_meta in enumerate(seq_group_metadata_list):
1042
+ if seq_group_meta.request_id == request_id:
1043
+ assert i not in skip # Cannot be called twice
1044
+ indices.append(i)
1045
+ break
1046
+
1047
+ # If the request_id was not found, then it means that
1048
+ # this is a new request that has no pending async
1049
+ # postprocessor
1050
+ if not indices:
1051
+ return
1052
+ else:
1053
+ indices = range(len(seq_group_metadata_list)) # type: ignore
1054
+
1055
+ finished_before: List[int] = []
1056
+ finished_now: List[int] = []
1057
+ for i in indices:
1058
+ if i in skip:
1059
+ continue
1060
+
1061
+ seq_group_meta = seq_group_metadata_list[i]
1062
+ scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
1063
+
1064
+ seq_group: SequenceGroup = scheduled_seq_group.seq_group
1065
+
1066
+ if seq_group.is_finished():
1067
+ finished_before.append(i)
1068
+ continue
1069
+
1070
+ output: List[SequenceGroupOutput]
1071
+ if has_multiple_outputs:
1072
+ output = outputs_by_sequence_group[i]
1073
+ else:
1074
+ output = [outputs_by_sequence_group[0][i]]
1075
+
1076
+ if not is_async:
1077
+ if self.scheduler_config.is_multi_step:
1078
+ # Updates happen only if the sequence is prefill
1079
+ self._update_num_computed_tokens_for_multi_step_prefill(
1080
+ seq_group, seq_group_meta, is_first_step_output)
1081
+ else:
1082
+ seq_group.update_num_computed_tokens(
1083
+ seq_group_meta.token_chunk_size or 0)
1084
+
1085
+ if outputs:
1086
+ for o in outputs:
1087
+ if (isinstance(o, SamplerOutput)
1088
+ and seq_group.metrics is not None):
1089
+ if seq_group.metrics.model_forward_time is not None:
1090
+ seq_group.metrics.model_forward_time += (
1091
+ o.model_forward_time or 0)
1092
+ else:
1093
+ seq_group.metrics.model_forward_time = (
1094
+ o.model_forward_time)
1095
+ if seq_group.metrics.model_execute_time is not None:
1096
+ seq_group.metrics.model_execute_time += (
1097
+ o.model_execute_time or 0)
1098
+ else:
1099
+ seq_group.metrics.model_execute_time = (
1100
+ o.model_execute_time)
1101
+
1102
+ if self.model_config.runner_type == "pooling":
1103
+ self._process_sequence_group_outputs(seq_group, output)
1104
+ else:
1105
+ self.output_processor.process_prompt_logprob(seq_group, output)
1106
+ if seq_group_meta.do_sample:
1107
+ self.output_processor.process_outputs(
1108
+ seq_group, output, is_async)
1109
+
1110
+ if seq_group.is_finished():
1111
+ finished_now.append(i)
1112
+
1113
+ # Generate outputs for the requests that finished this iteration
1114
+ for i in finished_now:
1115
+ scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
1116
+
1117
+ seq_group = scheduled_seq_group.seq_group
1118
+ seq_group.maybe_set_first_token_time(now)
1119
+ if not seq_group.is_prefill():
1120
+ seq_group.set_last_token_time(now)
1121
+ request_output = RequestOutputFactory.create(
1122
+ seq_group,
1123
+ self.seq_id_to_seq_group,
1124
+ use_cache=self.use_cached_outputs)
1125
+ if request_output:
1126
+ ctx.request_outputs.append(request_output)
1127
+
1128
+ # When we process a single request, we skip it for the next time,
1129
+ # and invoke the request output callback (if there was final output)
1130
+ if request_id:
1131
+ assert len(indices) == 1
1132
+ skip.append(indices[0])
1133
+
1134
+ if (finished_now
1135
+ and self.process_request_outputs_callback is not None):
1136
+ self.process_request_outputs_callback(ctx.request_outputs)
1137
+ ctx.request_outputs.clear()
1138
+ return
1139
+
1140
+ # Free currently finished requests
1141
+ if finished_now:
1142
+ for scheduler in self.scheduler:
1143
+ scheduler.free_finished_seq_groups()
1144
+
1145
+ # For multi-step without streaming, don't create outputs each iteration
1146
+ if not is_last_step and not ctx.multi_step_stream_outputs:
1147
+ # Immediately process request outputs here (if callback is given)
1148
+ if (finished_now
1149
+ and self.process_request_outputs_callback is not None):
1150
+ self.process_request_outputs_callback(ctx.request_outputs)
1151
+ ctx.request_outputs.clear()
1152
+ return
1153
+
1154
+ # Create the outputs
1155
+ for i in indices:
1156
+ if i in skip or i in finished_before or i in finished_now:
1157
+ continue # Avoids double processing
1158
+
1159
+ scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
1160
+
1161
+ seq_group = scheduled_seq_group.seq_group
1162
+ seq_group.maybe_set_first_token_time(now)
1163
+ if not seq_group.is_prefill():
1164
+ seq_group.set_last_token_time(now)
1165
+ request_output = RequestOutputFactory.create(
1166
+ seq_group,
1167
+ self.seq_id_to_seq_group,
1168
+ use_cache=self.use_cached_outputs)
1169
+ if request_output:
1170
+ ctx.request_outputs.append(request_output)
1171
+
1172
+ # For multi-step with streaming, create outputs each iteration
1173
+ if not is_last_step and ctx.multi_step_stream_outputs:
1174
+ # Immediately process request outputs here (if callback is given)
1175
+ if self.process_request_outputs_callback is not None:
1176
+ self.process_request_outputs_callback(ctx.request_outputs)
1177
+ ctx.request_outputs.clear()
1178
+ return
1179
+
1180
+ for seq_group in scheduler_outputs.ignored_seq_groups:
1181
+ params = seq_group.sampling_params
1182
+ if params is not None and params.output_kind == (
1183
+ RequestOutputKind.DELTA) and not seq_group.is_finished():
1184
+ continue
1185
+
1186
+ request_output = RequestOutputFactory.create(
1187
+ seq_group,
1188
+ self.seq_id_to_seq_group,
1189
+ use_cache=self.use_cached_outputs,
1190
+ )
1191
+ if request_output:
1192
+ ctx.request_outputs.append(request_output)
1193
+
1194
+ # Immediately process request outputs here (if callback is given)
1195
+ if (ctx.request_outputs
1196
+ and self.process_request_outputs_callback is not None):
1197
+ self.process_request_outputs_callback(ctx.request_outputs)
1198
+ ctx.request_outputs.clear()
1199
+
1200
+ # For async case, we need to record the stats here.
1201
+ # For non-async case, the stats are done in the
1202
+ # LLMEngine/AsyncLLMEngine directly
1203
+ if is_async:
1204
+ # Log stats.
1205
+ self.do_log_stats(scheduler_outputs, outputs, finished_before,
1206
+ skip)
1207
+
1208
+ # Tracing
1209
+ self.do_tracing(scheduler_outputs, finished_before)
1210
+
1211
+ return None
1212
+
1213
+ def _advance_to_next_step(
1214
+ self, output: List[SamplerOutput],
1215
+ seq_group_metadata_list: List[SequenceGroupMetadata],
1216
+ scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
1217
+ """Given model output from a single run, append the tokens to the
1218
+ sequences. This is normally done inside output processor, but it is
1219
+ required if the worker is to perform async forward pass to next step.
1220
+ """
1221
+ for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
1222
+ zip(seq_group_metadata_list, output, scheduled_seq_groups):
1223
+ seq_group = scheduled_seq_group.seq_group
1224
+
1225
+ if seq_group.is_finished():
1226
+ continue
1227
+
1228
+ if self.scheduler_config.is_multi_step:
1229
+ # Updates happen only if the sequence is prefill
1230
+ self._update_num_computed_tokens_for_multi_step_prefill(
1231
+ seq_group, seq_group_metadata,
1232
+ seq_group.state.num_steps == 1)
1233
+ else:
1234
+ token_chunk_size = (seq_group_metadata.token_chunk_size
1235
+ if seq_group_metadata.token_chunk_size
1236
+ is not None else 0)
1237
+ seq_group.update_num_computed_tokens(token_chunk_size)
1238
+
1239
+ if seq_group_metadata.do_sample:
1240
+ assert len(sequence_group_outputs.samples) == 1, (
1241
+ "Async output processor expects a single sample"
1242
+ " (i.e sampling_params.n == 1)")
1243
+ sample = sequence_group_outputs.samples[0]
1244
+
1245
+ assert len(seq_group.seqs) == 1
1246
+ seq = seq_group.seqs[0]
1247
+
1248
+ if self.scheduler_config.is_multi_step:
1249
+ is_prefill_append = seq.data.get_num_uncomputed_tokens(
1250
+ ) == 0
1251
+ seq.append_token_id(sample.output_token, sample.logprobs)
1252
+ if not is_prefill_append:
1253
+ seq_group.update_num_computed_tokens(1)
1254
+ else:
1255
+ seq.append_token_id(sample.output_token, sample.logprobs)
1256
+
1257
+ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
1258
+ """Performs one decoding iteration and returns newly generated results.
1259
+
1260
+ .. figure:: https://i.imgur.com/sv2HssD.png
1261
+ :alt: Overview of the step function
1262
+ :align: center
1263
+
1264
+ Overview of the step function.
1265
+
1266
+ Details:
1267
+ - Step 1: Schedules the sequences to be executed in the next
1268
+ iteration and the token blocks to be swapped in/out/copy.
1269
+
1270
+ - Depending on the scheduling policy,
1271
+ sequences may be `preempted/reordered`.
1272
+ - A Sequence Group (SG) refer to a group of sequences
1273
+ that are generated from the same prompt.
1274
+
1275
+ - Step 2: Calls the distributed executor to execute the model.
1276
+ - Step 3: Processes the model output. This mainly includes:
1277
+
1278
+ - Decodes the relevant outputs.
1279
+ - Updates the scheduled sequence groups with model outputs
1280
+ based on its `sampling parameters` (`use_beam_search` or not).
1281
+ - Frees the finished sequence groups.
1282
+
1283
+ - Finally, it creates and returns the newly generated results.
1284
+
1285
+ Example:
1286
+ >>> # Please see the example/ folder for more detailed examples.
1287
+ >>>
1288
+ >>> # initialize engine and request arguments
1289
+ >>> engine = LLMEngine.from_engine_args(engine_args)
1290
+ >>> example_inputs = [(0, "What is LLM?",
1291
+ >>> SamplingParams(temperature=0.0))]
1292
+ >>>
1293
+ >>> # Start the engine with an event loop
1294
+ >>> while True:
1295
+ >>> if example_inputs:
1296
+ >>> req_id, prompt, sampling_params = example_inputs.pop(0)
1297
+ >>> engine.add_request(str(req_id),prompt,sampling_params)
1298
+ >>>
1299
+ >>> # continue the request processing
1300
+ >>> request_outputs = engine.step()
1301
+ >>> for request_output in request_outputs:
1302
+ >>> if request_output.finished:
1303
+ >>> # return or show the request output
1304
+ >>>
1305
+ >>> if not (engine.has_unfinished_requests() or example_inputs):
1306
+ >>> break
1307
+ """
1308
+ if self.parallel_config.pipeline_parallel_size > 1:
1309
+ raise NotImplementedError(
1310
+ "Pipeline parallelism is only supported through AsyncLLMEngine "
1311
+ "as performance will be severely degraded otherwise.")
1312
+
1313
+ # For llm_engine, there is no pipeline parallel support, so the engine
1314
+ # used is always 0.
1315
+ virtual_engine = 0
1316
+
1317
+ # These are cached outputs from previous iterations. None if on first
1318
+ # iteration
1319
+ cached_outputs = self.cached_scheduler_outputs[virtual_engine]
1320
+ seq_group_metadata_list = cached_outputs.seq_group_metadata_list
1321
+ scheduler_outputs = cached_outputs.scheduler_outputs
1322
+ allow_async_output_proc = cached_outputs.allow_async_output_proc
1323
+
1324
+ ctx = self.scheduler_contexts[virtual_engine]
1325
+
1326
+ # Clear outputs for each new scheduler iteration
1327
+ ctx.request_outputs.clear()
1328
+
1329
+ # Skip the scheduler if there are any remaining steps in the seq groups.
1330
+ # This ensures that the scheduler is only called again when the current
1331
+ # batch has completed.
1332
+ if not self._has_remaining_steps(seq_group_metadata_list):
1333
+ # Schedule iteration
1334
+ (seq_group_metadata_list, scheduler_outputs,
1335
+ allow_async_output_proc
1336
+ ) = self.scheduler[virtual_engine].schedule()
1337
+
1338
+ ctx.seq_group_metadata_list = seq_group_metadata_list
1339
+ ctx.scheduler_outputs = scheduler_outputs
1340
+
1341
+ finished_requests_ids = self.scheduler[
1342
+ virtual_engine].get_and_reset_finished_requests_ids()
1343
+
1344
+ # Maybe switch from async mode to sync mode
1345
+ if not allow_async_output_proc and len(ctx.output_queue) > 0:
1346
+ self._process_model_outputs(ctx=ctx)
1347
+
1348
+ if (self.scheduler_config.is_multi_step
1349
+ and scheduler_outputs.num_lookahead_slots > 0):
1350
+ # cache the scheduler outputs for the next iteration if we have
1351
+ # lookahead slots
1352
+ self._cache_scheduler_outputs_for_multi_step(
1353
+ virtual_engine, seq_group_metadata_list, scheduler_outputs,
1354
+ allow_async_output_proc)
1355
+ else:
1356
+ finished_requests_ids = list()
1357
+
1358
+ assert seq_group_metadata_list is not None
1359
+ assert scheduler_outputs is not None
1360
+
1361
+ if not scheduler_outputs.is_empty():
1362
+
1363
+ # Check if we have a cached last_output from the previous iteration.
1364
+ # For supporting PP this is probably the best way to pass the
1365
+ # sampled_token_ids, as a separate broadcast over all the PP stages
1366
+ # will cause one virtual engine's microbatch to block the pipeline.
1367
+ last_sampled_token_ids = \
1368
+ self._get_last_sampled_token_ids(virtual_engine)
1369
+
1370
+ execute_model_req = ExecuteModelRequest(
1371
+ seq_group_metadata_list=seq_group_metadata_list,
1372
+ blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
1373
+ blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
1374
+ blocks_to_copy=scheduler_outputs.blocks_to_copy,
1375
+ num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
1376
+ running_queue_size=scheduler_outputs.running_queue_size,
1377
+ finished_requests_ids=finished_requests_ids,
1378
+ # We use ExecuteModelRequest to pass the last sampled_token_ids
1379
+ # to each of the non-last PP stages for in-place prepare_input.
1380
+ last_sampled_token_ids=last_sampled_token_ids)
1381
+
1382
+ if allow_async_output_proc:
1383
+ execute_model_req.async_callback = self.async_callbacks[
1384
+ virtual_engine]
1385
+
1386
+ outputs = self.model_executor.execute_model(
1387
+ execute_model_req=execute_model_req)
1388
+
1389
+ # We need to do this here so that last step's sampled_token_ids can
1390
+ # be passed to the next iteration for PP.
1391
+ if self.scheduler_config.is_multi_step:
1392
+ self._update_cached_scheduler_output(virtual_engine, outputs)
1393
+ else:
1394
+ # Nothing scheduled => If there is pending async postprocessor,
1395
+ # then finish it here.
1396
+ if len(ctx.output_queue) > 0:
1397
+ self._process_model_outputs(ctx=ctx)
1398
+ # No outputs in this case
1399
+ outputs = []
1400
+
1401
+ # Finish the current step for all the sequence groups.
1402
+ if self.scheduler_config.is_multi_step:
1403
+ for seq_group in seq_group_metadata_list:
1404
+ seq_group.finish_step()
1405
+
1406
+ if not self._has_remaining_steps(seq_group_metadata_list):
1407
+ # clear the cache if we have finished all the steps.
1408
+ if self.scheduler_config.is_multi_step:
1409
+ self.cached_scheduler_outputs[0] = SchedulerOutputState()
1410
+
1411
+ # is_first_step_output is True only when the num_steps of all
1412
+ # the sequences are 1. When the num_steps > 1,
1413
+ # multi_step_model_runner does the first-step output append.
1414
+ is_first_step_output: bool = False if not seq_group_metadata_list \
1415
+ else seq_group_metadata_list[0].state.num_steps == 1
1416
+
1417
+ # Add results to the output_queue
1418
+ ctx.append_output(outputs=outputs,
1419
+ seq_group_metadata_list=seq_group_metadata_list,
1420
+ scheduler_outputs=scheduler_outputs,
1421
+ is_async=allow_async_output_proc,
1422
+ is_last_step=True,
1423
+ is_first_step_output=is_first_step_output)
1424
+
1425
+ if outputs and allow_async_output_proc:
1426
+ assert len(outputs) == 1, (
1427
+ "Async postprocessor expects only a single output set")
1428
+
1429
+ self._advance_to_next_step(
1430
+ outputs[0], seq_group_metadata_list,
1431
+ scheduler_outputs.scheduled_seq_groups)
1432
+
1433
+ # Check if need to run the usual non-async path
1434
+ if not allow_async_output_proc:
1435
+ self._process_model_outputs(ctx=ctx)
1436
+
1437
+ # Log stats.
1438
+ self.do_log_stats(scheduler_outputs, outputs)
1439
+
1440
+ # Tracing
1441
+ self.do_tracing(scheduler_outputs)
1442
+ else:
1443
+ # Multi-step case
1444
+ return ctx.request_outputs
1445
+
1446
+ if not self.has_unfinished_requests():
1447
+ # Drain async postprocessor (if exists)
1448
+ if len(ctx.output_queue) > 0:
1449
+ self._process_model_outputs(ctx=ctx)
1450
+ assert len(ctx.output_queue) == 0
1451
+
1452
+ # Stop the execute model loop in parallel workers until there are
1453
+ # more requests to process. This avoids waiting indefinitely in
1454
+ # torch.distributed ops which may otherwise timeout, and unblocks
1455
+ # the RPC thread in the workers so that they can process any other
1456
+ # queued control plane messages, such as add/remove lora adapters.
1457
+ logger.debug("Stopping remote worker execution loop.")
1458
+ self.model_executor.stop_remote_worker_execution_loop()
1459
+
1460
+ return ctx.request_outputs
1461
+
1462
+ def _has_remaining_steps(
1463
+ self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
1464
+ ) -> bool:
1465
+ if (not self.scheduler_config.is_multi_step
1466
+ or not seq_group_metadata_list):
1467
+ return False
1468
+
1469
+ # TODO(will) this is a sanity check for nowto make sure that all the
1470
+ # seqs are on the same steps. Eventually we will want to do some sort of
1471
+ # dynamic scheduling when doing multi-step decoding.
1472
+ ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
1473
+ if any([
1474
+ seq_group.state.remaining_steps != ref_remaining_steps
1475
+ for seq_group in seq_group_metadata_list[1:]
1476
+ ]):
1477
+ raise AssertionError("All running sequence groups should "
1478
+ "have the same remaining steps.")
1479
+
1480
+ return ref_remaining_steps > 0
1481
+
1482
+ def _cache_scheduler_outputs_for_multi_step(
1483
+ self, virtual_engine: int,
1484
+ seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
1485
+ scheduler_outputs: SchedulerOutputs,
1486
+ allow_async_output_proc: bool) -> None:
1487
+ co = self.cached_scheduler_outputs[virtual_engine]
1488
+
1489
+ co.seq_group_metadata_list = seq_group_metadata_list
1490
+ co.scheduler_outputs = scheduler_outputs
1491
+ co.allow_async_output_proc = allow_async_output_proc
1492
+ co.last_output = None
1493
+
1494
+ def _update_cached_scheduler_output(
1495
+ self, virtual_engine: int,
1496
+ output: List[Optional[SamplerOutput]]) -> None:
1497
+ if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
1498
+ and output[0] is not None):
1499
+ last_output = output[-1]
1500
+ assert last_output is not None
1501
+ assert last_output.sampled_token_ids_cpu is not None
1502
+ assert last_output.sampled_token_ids is None
1503
+ assert last_output.sampled_token_probs is None
1504
+ self.cached_scheduler_outputs[
1505
+ virtual_engine].last_output = last_output
1506
+
1507
+ def _get_last_sampled_token_ids(
1508
+ self, virtual_engine: int) -> Optional[torch.Tensor]:
1509
+ cached_last_output = self.cached_scheduler_outputs[
1510
+ virtual_engine].last_output
1511
+ if (self.scheduler_config.is_multi_step
1512
+ and self.parallel_config.pipeline_parallel_size > 1
1513
+ and cached_last_output is not None
1514
+ and cached_last_output.sampled_token_ids_cpu is not None):
1515
+ return cached_last_output.sampled_token_ids_cpu
1516
+ return None
1517
+
1518
+ def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
1519
+ if not self.log_stats:
1520
+ raise RuntimeError(
1521
+ "Stat logging is disabled. Set `disable_log_stats=False` "
1522
+ "argument to enable.")
1523
+ if logger_name in self.stat_loggers:
1524
+ raise KeyError(f"Logger with name {logger_name} already exists.")
1525
+ self.stat_loggers[logger_name] = logger
1526
+
1527
+ def remove_logger(self, logger_name: str) -> None:
1528
+ if not self.log_stats:
1529
+ raise RuntimeError(
1530
+ "Stat logging is disabled. Set `disable_log_stats=False` "
1531
+ "argument to enable.")
1532
+ if logger_name not in self.stat_loggers:
1533
+ raise KeyError(f"Logger with name {logger_name} does not exist.")
1534
+ del self.stat_loggers[logger_name]
1535
+
1536
+ def do_log_stats(self,
1537
+ scheduler_outputs: Optional[SchedulerOutputs] = None,
1538
+ model_output: Optional[List[SamplerOutput]] = None,
1539
+ finished_before: Optional[List[int]] = None,
1540
+ skip: Optional[List[int]] = None) -> None:
1541
+ """Forced log when no requests active."""
1542
+ if self.log_stats:
1543
+ stats = self._get_stats(scheduler_outputs, model_output,
1544
+ finished_before, skip)
1545
+ for logger in self.stat_loggers.values():
1546
+ logger.log(stats)
1547
+
1548
+ def _get_stats(self,
1549
+ scheduler_outputs: Optional[SchedulerOutputs],
1550
+ model_output: Optional[List[SamplerOutput]] = None,
1551
+ finished_before: Optional[List[int]] = None,
1552
+ skip: Optional[List[int]] = None) -> Stats:
1553
+ """Get Stats to be Logged to Prometheus.
1554
+
1555
+ Args:
1556
+ scheduler_outputs: Optional, used to populate metrics related to
1557
+ the scheduled batch,
1558
+ model_output: Optional, used to emit speculative decoding metrics
1559
+ which are created by the workers.
1560
+ finished_before: Optional, indices of sequences that were finished
1561
+ before. These sequences will be ignored.
1562
+ skip: Optional, indices of sequences that were preempted. These
1563
+ sequences will be ignored.
1564
+ """
1565
+ now = time.time()
1566
+
1567
+ # System State
1568
+ # Scheduler State
1569
+ num_running_sys = sum(
1570
+ len(scheduler.running) for scheduler in self.scheduler)
1571
+ num_swapped_sys = sum(
1572
+ len(scheduler.swapped) for scheduler in self.scheduler)
1573
+ num_waiting_sys = sum(
1574
+ len(scheduler.waiting) for scheduler in self.scheduler)
1575
+
1576
+ # KV Cache Usage in %
1577
+ num_total_gpu = self.cache_config.num_gpu_blocks
1578
+ gpu_cache_usage_sys = 0.
1579
+ if num_total_gpu: # Guard against both None and 0
1580
+ num_free_gpu = sum(
1581
+ scheduler.block_manager.get_num_free_gpu_blocks()
1582
+ for scheduler in self.scheduler)
1583
+ gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
1584
+
1585
+ num_total_cpu = self.cache_config.num_cpu_blocks
1586
+ cpu_cache_usage_sys = 0.
1587
+ if num_total_cpu: # Guard against both None and 0
1588
+ num_free_cpu = sum(
1589
+ scheduler.block_manager.get_num_free_cpu_blocks()
1590
+ for scheduler in self.scheduler)
1591
+ cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
1592
+
1593
+ # Prefix Cache Hit Rate. Note that we always use
1594
+ # the cache hit rate of the first virtual engine.
1595
+ cpu_prefix_cache_hit_rate = self.scheduler[
1596
+ 0].get_prefix_cache_hit_rate(Device.CPU)
1597
+ gpu_prefix_cache_hit_rate = self.scheduler[
1598
+ 0].get_prefix_cache_hit_rate(Device.GPU)
1599
+
1600
+ # Iteration stats
1601
+ num_prompt_tokens_iter = 0
1602
+ num_generation_tokens_iter = 0
1603
+ num_tokens_iter = 0
1604
+ time_to_first_tokens_iter: List[float] = []
1605
+ time_per_output_tokens_iter: List[float] = []
1606
+ num_preemption_iter = (0 if scheduler_outputs is None else
1607
+ scheduler_outputs.preempted)
1608
+
1609
+ # Request stats
1610
+ # Latency
1611
+ time_e2e_requests: List[float] = []
1612
+ time_queue_requests: List[float] = []
1613
+ time_inference_requests: List[float] = []
1614
+ time_prefill_requests: List[float] = []
1615
+ time_decode_requests: List[float] = []
1616
+ time_in_queue_requests: List[float] = []
1617
+ model_forward_time_requests: List[float] = []
1618
+ model_execute_time_requests: List[float] = []
1619
+ # Metadata
1620
+ num_prompt_tokens_requests: List[int] = []
1621
+ num_generation_tokens_requests: List[int] = []
1622
+ n_requests: List[int] = []
1623
+ max_num_generation_tokens_requests: List[int] = []
1624
+ max_tokens_requests: List[int] = []
1625
+ finished_reason_requests: List[str] = []
1626
+
1627
+ # Lora requests
1628
+ running_lora_adapters = dict(
1629
+ collectionsCounter([
1630
+ running_request.lora_request.lora_name
1631
+ for scheduler in self.scheduler
1632
+ for running_request in scheduler.running
1633
+ if running_request.lora_request
1634
+ ]))
1635
+ waiting_lora_adapters = dict(
1636
+ collectionsCounter([
1637
+ waiting_request.lora_request.lora_name
1638
+ for scheduler in self.scheduler
1639
+ for waiting_request in scheduler.waiting
1640
+ if waiting_request.lora_request
1641
+ ]))
1642
+ max_lora_stat = "0"
1643
+ if self.lora_config:
1644
+ max_lora_stat = str(self.lora_config.max_loras)
1645
+
1646
+ # NOTE: This loop assumes prefill seq_groups are before
1647
+ # decode seq_groups in scheduled_seq_groups.
1648
+ if scheduler_outputs is not None:
1649
+ # For async postprocessor, already finished sequences need to be
1650
+ # not counted (to avoid double counting)
1651
+ actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore
1652
+
1653
+ num_generation_tokens_from_prefill_groups = 0
1654
+ # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
1655
+ # the len of scheduler_outputs.scheduled_seq_groups is !=
1656
+ # scheduler_outputs.num_prefill_groups, this means that
1657
+ # chunked prefills have been detected.
1658
+
1659
+ for idx, scheduled_seq_group in enumerate(
1660
+ scheduler_outputs.scheduled_seq_groups):
1661
+ # Skip double logging when using async output proc
1662
+ if finished_before and idx in finished_before:
1663
+ actual_num_batched_tokens -= 1
1664
+ continue
1665
+
1666
+ # Currently, skip == preempted sequences, so we need to skip
1667
+ # their log stats
1668
+ if skip and idx in skip:
1669
+ continue
1670
+
1671
+ group_was_prefill = idx < scheduler_outputs.num_prefill_groups
1672
+ seq_group = scheduled_seq_group.seq_group
1673
+
1674
+ # NOTE: a seq_group that completed all of its prefill tokens
1675
+ # in the last iteration will have seq_group.is_prefill() = False
1676
+ # with group_was_prefill = True
1677
+ if group_was_prefill:
1678
+ # Number of prompt tokens.
1679
+ num_prompt_tokens_iter += (
1680
+ scheduled_seq_group.token_chunk_size)
1681
+
1682
+ # If the seq_group just finished the prefill state
1683
+ # get TTFT.
1684
+ if not seq_group.is_prefill():
1685
+ latency = seq_group.get_last_token_latency()
1686
+ time_to_first_tokens_iter.append(latency)
1687
+
1688
+ # One generation token per finished prefill.
1689
+ num_generation_tokens_from_prefill_groups += (
1690
+ seq_group.num_seqs())
1691
+ else:
1692
+ # TPOTs.
1693
+ latency = seq_group.get_last_token_latency()
1694
+ time_per_output_tokens_iter.append(latency)
1695
+ if seq_group.state.current_step == 0:
1696
+ # For async_output_proc, the do_log_stats()
1697
+ # is called following init_multi_step(), which
1698
+ # sets the current_step to zero.
1699
+ actual_num_batched_tokens +=\
1700
+ seq_group.state.num_steps - 1
1701
+ else:
1702
+ actual_num_batched_tokens +=\
1703
+ seq_group.state.current_step - 1
1704
+
1705
+ # Because of chunked prefill, we can have a single sequence
1706
+ # group that does multiple prompt_runs. To prevent logging
1707
+ # the same metadata more than once per request, we standardize
1708
+ # on logging request level information for finished requests,
1709
+ # which can only happen once.
1710
+ if seq_group.is_finished():
1711
+ # Latency timings
1712
+ time_e2e_requests.append(now -
1713
+ seq_group.metrics.arrival_time)
1714
+ if (seq_group.metrics.first_scheduled_time is not None and
1715
+ seq_group.metrics.first_token_time is not None):
1716
+ time_queue_requests.append(
1717
+ seq_group.metrics.first_scheduled_time -
1718
+ seq_group.metrics.arrival_time)
1719
+ time_prefill_requests.append(
1720
+ seq_group.metrics.first_token_time -
1721
+ seq_group.metrics.first_scheduled_time)
1722
+ time_decode_requests.append(
1723
+ now - seq_group.metrics.first_token_time)
1724
+ time_inference_requests.append(
1725
+ now - seq_group.metrics.first_scheduled_time)
1726
+ if seq_group.metrics.time_in_queue is not None:
1727
+ time_in_queue_requests.append(
1728
+ seq_group.metrics.time_in_queue)
1729
+ if seq_group.metrics.model_forward_time is not None:
1730
+ model_forward_time_requests.append(
1731
+ seq_group.metrics.model_forward_time)
1732
+ if seq_group.metrics.model_execute_time is not None:
1733
+ model_execute_time_requests.append(
1734
+ seq_group.metrics.model_execute_time * 1000)
1735
+ # Metadata
1736
+ num_prompt_tokens_requests.append(
1737
+ len(seq_group.prompt_token_ids))
1738
+ num_generation_tokens_requests.extend([
1739
+ seq.get_output_len()
1740
+ for seq in seq_group.get_finished_seqs()
1741
+ ])
1742
+ max_num_generation_tokens_requests.append(
1743
+ max(seq.get_output_len()
1744
+ for seq in seq_group.get_seqs()))
1745
+ if seq_group.sampling_params is not None:
1746
+ n_requests.append(seq_group.sampling_params.n)
1747
+ max_tokens_requests.append(
1748
+ seq_group.sampling_params.max_tokens)
1749
+ finished_reason_requests.extend([
1750
+ SequenceStatus.get_finished_reason(seq.status)
1751
+ for seq in seq_group.get_finished_seqs()
1752
+ ])
1753
+
1754
+ # Number of generation tokens.
1755
+ # num_batched_tokens equals the number of prompt_tokens plus the
1756
+ # number of decode_tokens in a single iteration. So,
1757
+ # num_generation_tokens = num_batched_tokens - num_prompt_tokens
1758
+ # + num_generation_tokens_from_prefill_groups (since we generate
1759
+ # one token on prefills on iters where the prefill finishes).
1760
+ num_generation_tokens_iter = (
1761
+ actual_num_batched_tokens - num_prompt_tokens_iter +
1762
+ num_generation_tokens_from_prefill_groups)
1763
+ num_tokens_iter = (num_generation_tokens_iter +
1764
+ num_prompt_tokens_iter)
1765
+ # Spec decode, if enabled, emits specialized metrics from the worker in
1766
+ # sampler output.
1767
+ if model_output and isinstance(model_output[0], SamplerOutput) and (
1768
+ model_output[0].spec_decode_worker_metrics is not None):
1769
+ spec_decode_metrics = model_output[0].spec_decode_worker_metrics
1770
+ else:
1771
+ spec_decode_metrics = None
1772
+
1773
+ return Stats(
1774
+ now=now,
1775
+ # System stats
1776
+ # Scheduler State
1777
+ num_running_sys=num_running_sys,
1778
+ num_swapped_sys=num_swapped_sys,
1779
+ num_waiting_sys=num_waiting_sys,
1780
+ # KV Cache Usage in %
1781
+ gpu_cache_usage_sys=gpu_cache_usage_sys,
1782
+ cpu_cache_usage_sys=cpu_cache_usage_sys,
1783
+ # Prefix Cache Hit Rate
1784
+ cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
1785
+ gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
1786
+
1787
+ # Iteration stats
1788
+ num_prompt_tokens_iter=num_prompt_tokens_iter,
1789
+ num_generation_tokens_iter=num_generation_tokens_iter,
1790
+ num_tokens_iter=num_tokens_iter,
1791
+ time_to_first_tokens_iter=time_to_first_tokens_iter,
1792
+ time_per_output_tokens_iter=time_per_output_tokens_iter,
1793
+ spec_decode_metrics=spec_decode_metrics,
1794
+ num_preemption_iter=num_preemption_iter,
1795
+
1796
+ # Request stats
1797
+ # Latency
1798
+ time_e2e_requests=time_e2e_requests,
1799
+ time_queue_requests=time_queue_requests,
1800
+ time_inference_requests=time_inference_requests,
1801
+ time_prefill_requests=time_prefill_requests,
1802
+ time_decode_requests=time_decode_requests,
1803
+ time_in_queue_requests=time_in_queue_requests,
1804
+ model_forward_time_requests=model_forward_time_requests,
1805
+ model_execute_time_requests=model_execute_time_requests,
1806
+ # Metadata
1807
+ num_prompt_tokens_requests=num_prompt_tokens_requests,
1808
+ num_generation_tokens_requests=num_generation_tokens_requests,
1809
+ max_num_generation_tokens_requests=
1810
+ max_num_generation_tokens_requests,
1811
+ n_requests=n_requests,
1812
+ max_tokens_requests=max_tokens_requests,
1813
+ finished_reason_requests=finished_reason_requests,
1814
+ max_lora=str(max_lora_stat),
1815
+ waiting_lora_adapters=list(waiting_lora_adapters.keys()),
1816
+ running_lora_adapters=list(running_lora_adapters.keys()))
1817
+
1818
+ def add_lora(self, lora_request: LoRARequest) -> bool:
1819
+ return self.model_executor.add_lora(lora_request)
1820
+
1821
+ def remove_lora(self, lora_id: int) -> bool:
1822
+ return self.model_executor.remove_lora(lora_id)
1823
+
1824
+ def list_loras(self) -> Set[int]:
1825
+ return self.model_executor.list_loras()
1826
+
1827
+ def pin_lora(self, lora_id: int) -> bool:
1828
+ return self.model_executor.pin_lora(lora_id)
1829
+
1830
+ def add_prompt_adapter(
1831
+ self, prompt_adapter_request: PromptAdapterRequest) -> bool:
1832
+ return self.model_executor.add_prompt_adapter(prompt_adapter_request)
1833
+
1834
+ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
1835
+ return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
1836
+
1837
+ def list_prompt_adapters(self) -> List[int]:
1838
+ return self.model_executor.list_prompt_adapters()
1839
+
1840
+ def start_profile(self) -> None:
1841
+ self.model_executor.start_profile()
1842
+
1843
+ def stop_profile(self) -> None:
1844
+ self.model_executor.stop_profile()
1845
+
1846
+ def sleep(self, level: int = 1) -> None:
1847
+ assert self.vllm_config.model_config.enable_sleep_mode, (
1848
+ "Sleep mode is not enabled in the model config")
1849
+ self.model_executor.sleep(level=level)
1850
+
1851
+ def wake_up(self) -> None:
1852
+ assert self.vllm_config.model_config.enable_sleep_mode, (
1853
+ "Sleep mode is not enabled in the model config")
1854
+ self.model_executor.wake_up()
1855
+
1856
+ def check_health(self) -> None:
1857
+ if self.tokenizer:
1858
+ self.tokenizer.check_health()
1859
+ self.model_executor.check_health()
1860
+
1861
+ def is_tracing_enabled(self) -> bool:
1862
+ return self.tracer is not None
1863
+
1864
+ def do_tracing(self,
1865
+ scheduler_outputs: SchedulerOutputs,
1866
+ finished_before: Optional[List[int]] = None) -> None:
1867
+ if self.tracer is None:
1868
+ return
1869
+
1870
+ for idx, scheduled_seq_group in enumerate(
1871
+ scheduler_outputs.scheduled_seq_groups):
1872
+ # Skip double tracing when using async output proc
1873
+ if finished_before and idx in finished_before:
1874
+ continue
1875
+
1876
+ seq_group = scheduled_seq_group.seq_group
1877
+ if seq_group.is_finished():
1878
+ self.create_trace_span(seq_group)
1879
+
1880
+ def create_trace_span(self, seq_group: SequenceGroup) -> None:
1881
+ if self.tracer is None or seq_group.sampling_params is None:
1882
+ return
1883
+ arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)
1884
+
1885
+ trace_context = extract_trace_context(seq_group.trace_headers)
1886
+
1887
+ with self.tracer.start_as_current_span(
1888
+ "llm_request",
1889
+ kind=SpanKind.SERVER,
1890
+ context=trace_context,
1891
+ start_time=arrival_time_nano_seconds) as seq_span:
1892
+ metrics = seq_group.metrics
1893
+ ttft = metrics.first_token_time - metrics.arrival_time
1894
+ e2e_time = metrics.finished_time - metrics.arrival_time
1895
+ seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
1896
+ self.model_config.model)
1897
+ seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
1898
+ seq_group.request_id)
1899
+ seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
1900
+ seq_group.sampling_params.temperature)
1901
+ seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
1902
+ seq_group.sampling_params.top_p)
1903
+ seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
1904
+ seq_group.sampling_params.max_tokens)
1905
+ seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
1906
+ seq_group.sampling_params.n)
1907
+ seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
1908
+ seq_group.num_seqs())
1909
+ seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
1910
+ len(seq_group.prompt_token_ids))
1911
+ seq_span.set_attribute(
1912
+ SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
1913
+ sum([
1914
+ seq.get_output_len()
1915
+ for seq in seq_group.get_finished_seqs()
1916
+ ]))
1917
+ seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
1918
+ metrics.time_in_queue)
1919
+ seq_span.set_attribute(
1920
+ SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
1921
+ seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
1922
+ if metrics.scheduler_time is not None:
1923
+ seq_span.set_attribute(
1924
+ SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
1925
+ metrics.scheduler_time)
1926
+ if metrics.model_forward_time is not None:
1927
+ seq_span.set_attribute(
1928
+ SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
1929
+ metrics.model_forward_time / 1000.0)
1930
+ if metrics.model_execute_time is not None:
1931
+ seq_span.set_attribute(
1932
+ SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
1933
+ metrics.model_execute_time)
1934
+
1935
+ def _validate_model_inputs(self, inputs: ProcessorInputs,
1936
+ lora_request: Optional[LoRARequest]):
1937
+ if is_encoder_decoder_inputs(inputs):
1938
+ # For encoder-decoder multimodal models, the max_prompt_len
1939
+ # restricts the decoder prompt length
1940
+ prompt_inputs = inputs["decoder" if self.model_config.
1941
+ is_multimodal_model else "encoder"]
1942
+ else:
1943
+ prompt_inputs = inputs
1944
+
1945
+ prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
1946
+
1947
+ if prompt_ids is None or len(prompt_ids) == 0:
1948
+ raise ValueError("Prompt cannot be empty")
1949
+
1950
+ if self.model_config.is_multimodal_model:
1951
+ max_prompt_len = self.model_config.max_model_len
1952
+
1953
+ if len(prompt_ids) > max_prompt_len:
1954
+ raise ValueError(
1955
+ f"The prompt (total length {len(prompt_ids)}) is too long "
1956
+ f"to fit into the model (context length {max_prompt_len}). "
1957
+ "Make sure that `max_model_len` is no smaller than the "
1958
+ "number of text tokens plus multimodal tokens. For image "
1959
+ "inputs, the number of image tokens depends on the number "
1960
+ "of images, and possibly their aspect ratios as well.")
1961
+
1962
+ # TODO: Find out how many placeholder tokens are there so we can
1963
+ # check that chunked prefill does not truncate them
1964
+ # max_batch_len = self.scheduler_config.max_num_batched_tokens
1965
+
1966
+ def _build_logits_processors(
1967
+ self, sampling_params: SamplingParams,
1968
+ lora_request: Optional[LoRARequest]) -> SamplingParams:
1969
+ """Constructs logits processors based on the guided_decoding,
1970
+ logits_bias, and allowed_token_ids fields in sampling_params. Deletes
1971
+ those fields and adds the constructed logits processors to the
1972
+ logits_processors field. Returns the modified sampling params."""
1973
+
1974
+ logits_processors = []
1975
+
1976
+ if sampling_params.guided_decoding is not None:
1977
+ # Defensively copy sampling params since guided decoding logits
1978
+ # processors can have different state for each request
1979
+ sampling_params = copy.copy(sampling_params)
1980
+ guided_decoding = sampling_params.guided_decoding
1981
+
1982
+ logger.debug(
1983
+ "Building guided decoding logits processor in "
1984
+ "LLMEngine. Params: %s", guided_decoding)
1985
+
1986
+ tokenizer = self.get_tokenizer(lora_request=lora_request)
1987
+ guided_decoding.backend = guided_decoding.backend or \
1988
+ self.decoding_config.guided_decoding_backend
1989
+
1990
+ processor = get_local_guided_decoding_logits_processor(
1991
+ guided_params=guided_decoding,
1992
+ tokenizer=tokenizer,
1993
+ model_config=self.model_config)
1994
+ if processor:
1995
+ logits_processors.append(processor)
1996
+
1997
+ # Unset so this doesn't get passed down to the model
1998
+ sampling_params.guided_decoding = None
1999
+
2000
+ if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
2001
+ tokenizer = self.get_tokenizer(lora_request=lora_request)
2002
+
2003
+ processors = get_openai_logits_processors(
2004
+ logit_bias=sampling_params.logit_bias,
2005
+ allowed_token_ids=sampling_params.allowed_token_ids,
2006
+ tokenizer=tokenizer)
2007
+ logits_processors.extend(processors)
2008
+
2009
+ # Unset so these don't get passed down to the model
2010
+ sampling_params.logit_bias = None
2011
+ sampling_params.allowed_token_ids = None
2012
+
2013
+ if len(sampling_params.bad_words) > 0:
2014
+ tokenizer = self.get_tokenizer(lora_request)
2015
+ processors = get_bad_words_logits_processors(
2016
+ bad_words=sampling_params.bad_words, tokenizer=tokenizer)
2017
+ logits_processors.extend(processors)
2018
+
2019
+ if logits_processors:
2020
+ if sampling_params.logits_processors is None:
2021
+ sampling_params.logits_processors = logits_processors
2022
+ else:
2023
+ sampling_params.logits_processors.extend(logits_processors)
2024
+
2025
+ return sampling_params
.venv/lib/python3.11/site-packages/vllm/engine/metrics.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import time
4
+ from typing import TYPE_CHECKING
5
+ from typing import Counter as CollectionsCounter
6
+ from typing import Dict, List, Optional, Type, Union, cast
7
+
8
+ import numpy as np
9
+ import prometheus_client
10
+
11
+ from vllm.config import VllmConfig
12
+ from vllm.engine.metrics_types import (StatLoggerBase, Stats,
13
+ SupportsMetricsInfo)
14
+ from vllm.executor.ray_utils import ray
15
+ from vllm.logger import init_logger
16
+
17
+ if ray is not None:
18
+ from ray.util import metrics as ray_metrics
19
+ else:
20
+ ray_metrics = None
21
+
22
+ if TYPE_CHECKING:
23
+ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
24
+
25
+ logger = init_logger(__name__)
26
+
27
+ prometheus_client.disable_created_metrics()
28
+
29
+ # The begin-* and end* here are used by the documentation generator
30
+ # to extract the metrics definitions.
31
+
32
+
33
+ # begin-metrics-definitions
34
class Metrics:
    """
    Prometheus metric definitions for the vLLM engine.

    vLLM uses a multiprocessing-based frontend for the OpenAI server.
    This means that we need to run prometheus_client in multiprocessing mode
    See https://prometheus.github.io/client_python/multiprocess/ for more
    details on limitations.
    """

    labelname_finish_reason = "finished_reason"
    labelname_waiting_lora_adapters = "waiting_lora_adapters"
    labelname_running_lora_adapters = "running_lora_adapters"
    labelname_max_lora = "max_lora"
    # Metric classes are class attributes so subclasses (e.g. RayMetrics)
    # can swap in API-compatible implementations.
    _gauge_cls = prometheus_client.Gauge
    _counter_cls = prometheus_client.Counter
    _histogram_cls = prometheus_client.Histogram

    def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
        # Unregister any existing vLLM collectors (for CI/CD)
        self._unregister_vllm_metrics()

        max_model_len = vllm_config.model_config.max_model_len

        # System stats
        # Scheduler State
        self.gauge_scheduler_running = self._gauge_cls(
            name="vllm:num_requests_running",
            documentation="Number of requests currently running on GPU.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.gauge_scheduler_waiting = self._gauge_cls(
            name="vllm:num_requests_waiting",
            documentation="Number of requests waiting to be processed.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.gauge_lora_info = self._gauge_cls(
            name="vllm:lora_requests_info",
            documentation="Running stats on lora requests.",
            labelnames=[
                self.labelname_running_lora_adapters,
                self.labelname_max_lora,
                self.labelname_waiting_lora_adapters,
            ],
            multiprocess_mode="livemostrecent",
        )
        self.gauge_scheduler_swapped = self._gauge_cls(
            name="vllm:num_requests_swapped",
            documentation="Number of requests swapped to CPU.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        # KV Cache Usage in %
        self.gauge_gpu_cache_usage = self._gauge_cls(
            name="vllm:gpu_cache_usage_perc",
            documentation="GPU KV-cache usage. 1 means 100 percent usage.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.gauge_cpu_cache_usage = self._gauge_cls(
            name="vllm:cpu_cache_usage_perc",
            documentation="CPU KV-cache usage. 1 means 100 percent usage.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        # Prefix caching block hit rate
        self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
            name="vllm:cpu_prefix_cache_hit_rate",
            documentation="CPU prefix cache block hit rate.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
            name="vllm:gpu_prefix_cache_hit_rate",
            documentation="GPU prefix cache block hit rate.",
            labelnames=labelnames,
            multiprocess_mode="sum")

        # Iteration stats
        self.counter_num_preemption = self._counter_cls(
            name="vllm:num_preemptions_total",
            documentation="Cumulative number of preemption from the engine.",
            labelnames=labelnames)
        self.counter_prompt_tokens = self._counter_cls(
            name="vllm:prompt_tokens_total",
            documentation="Number of prefill tokens processed.",
            labelnames=labelnames)
        self.counter_generation_tokens = self._counter_cls(
            name="vllm:generation_tokens_total",
            documentation="Number of generation tokens processed.",
            labelnames=labelnames)
        self.counter_tokens = self._counter_cls(
            name="vllm:tokens_total",
            documentation="Number of prefill plus generation tokens processed.",
            labelnames=labelnames)
        buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
        if not vllm_config.model_config.enforce_eager:
            # With cudagraphs, iteration sizes track the capture sizes, so
            # those make better histogram bucket boundaries.
            buckets = vllm_config.compilation_config.\
                cudagraph_capture_sizes.copy()
            buckets.sort()
        self.histogram_iteration_tokens = self._histogram_cls(
            name="vllm:iteration_tokens_total",
            documentation="Histogram of number of tokens per engine_step.",
            labelnames=labelnames,
            buckets=buckets)
        self.histogram_time_to_first_token = self._histogram_cls(
            name="vllm:time_to_first_token_seconds",
            documentation="Histogram of time to first token in seconds.",
            labelnames=labelnames,
            buckets=[
                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
                0.75, 1.0, 2.5, 5.0, 7.5, 10.0
            ])
        self.histogram_time_per_output_token = self._histogram_cls(
            name="vllm:time_per_output_token_seconds",
            documentation="Histogram of time per output token in seconds.",
            labelnames=labelnames,
            buckets=[
                0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
                1.0, 2.5
            ])

        # Request stats
        # Latency
        request_latency_buckets = [
            0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
            40.0, 50.0, 60.0
        ]
        self.histogram_e2e_time_request = self._histogram_cls(
            name="vllm:e2e_request_latency_seconds",
            documentation="Histogram of end to end request latency in seconds.",
            labelnames=labelnames,
            buckets=request_latency_buckets)
        self.histogram_queue_time_request = self._histogram_cls(
            name="vllm:request_queue_time_seconds",
            documentation=
            "Histogram of time spent in WAITING phase for request.",
            labelnames=labelnames,
            buckets=request_latency_buckets)
        self.histogram_inference_time_request = self._histogram_cls(
            name="vllm:request_inference_time_seconds",
            documentation=
            "Histogram of time spent in RUNNING phase for request.",
            labelnames=labelnames,
            buckets=request_latency_buckets)
        self.histogram_prefill_time_request = self._histogram_cls(
            name="vllm:request_prefill_time_seconds",
            documentation=
            "Histogram of time spent in PREFILL phase for request.",
            labelnames=labelnames,
            buckets=request_latency_buckets)
        self.histogram_decode_time_request = self._histogram_cls(
            name="vllm:request_decode_time_seconds",
            documentation=
            "Histogram of time spent in DECODE phase for request.",
            labelnames=labelnames,
            buckets=request_latency_buckets)
        self.histogram_time_in_queue_request = self._histogram_cls(
            name="vllm:time_in_queue_requests",
            documentation=
            "Histogram of time the request spent in the queue in seconds.",
            labelnames=labelnames,
            buckets=request_latency_buckets)
        self.histogram_model_forward_time_request = self._histogram_cls(
            name="vllm:model_forward_time_milliseconds",
            documentation=
            "Histogram of time spent in the model forward pass in ms.",
            labelnames=labelnames,
            buckets=build_1_2_3_5_8_buckets(3000))
        self.histogram_model_execute_time_request = self._histogram_cls(
            name="vllm:model_execute_time_milliseconds",
            documentation=
            "Histogram of time spent in the model execute function in ms.",
            labelnames=labelnames,
            buckets=build_1_2_3_5_8_buckets(3000))
        # Metadata
        self.histogram_num_prompt_tokens_request = self._histogram_cls(
            name="vllm:request_prompt_tokens",
            documentation="Number of prefill tokens processed.",
            labelnames=labelnames,
            buckets=build_1_2_5_buckets(max_model_len),
        )
        self.histogram_num_generation_tokens_request = \
            self._histogram_cls(
                name="vllm:request_generation_tokens",
                documentation="Number of generation tokens processed.",
                labelnames=labelnames,
                buckets=build_1_2_5_buckets(max_model_len),
            )
        self.histogram_max_num_generation_tokens_request = self._histogram_cls(
            name="vllm:request_max_num_generation_tokens",
            documentation=
            "Histogram of maximum number of requested generation tokens.",
            labelnames=labelnames,
            buckets=build_1_2_5_buckets(max_model_len))
        self.histogram_n_request = self._histogram_cls(
            name="vllm:request_params_n",
            documentation="Histogram of the n request parameter.",
            labelnames=labelnames,
            buckets=[1, 2, 5, 10, 20],
        )
        self.histogram_max_tokens_request = self._histogram_cls(
            name="vllm:request_params_max_tokens",
            documentation="Histogram of the max_tokens request parameter.",
            labelnames=labelnames,
            buckets=build_1_2_5_buckets(max_model_len),
        )
        self.counter_request_success = self._counter_cls(
            name="vllm:request_success_total",
            documentation="Count of successfully processed requests.",
            labelnames=labelnames + [Metrics.labelname_finish_reason])

        # Speculative decoding stats
        self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
            name="vllm:spec_decode_draft_acceptance_rate",
            documentation="Speculative token acceptance rate.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.gauge_spec_decode_efficiency = self._gauge_cls(
            name="vllm:spec_decode_efficiency",
            documentation="Speculative decoding system efficiency.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens_total",
            documentation="Number of accepted tokens.",
            labelnames=labelnames))
        self.counter_spec_decode_num_draft_tokens = self._counter_cls(
            name="vllm:spec_decode_num_draft_tokens_total",
            documentation="Number of draft tokens.",
            labelnames=labelnames)
        self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
            name="vllm:spec_decode_num_emitted_tokens_total",
            documentation="Number of emitted tokens.",
            labelnames=labelnames))


# end-metrics-definitions

    def _unregister_vllm_metrics(self) -> None:
        """Drop any previously registered vLLM collectors so re-instantiation
        (e.g. in CI/CD) does not raise duplicate-registration errors."""
        for collector in list(prometheus_client.REGISTRY._collector_to_names):
            if hasattr(collector, "_name") and "vllm" in collector._name:
                prometheus_client.REGISTRY.unregister(collector)
271
+
272
+
273
class _RayGaugeWrapper:
    """Adapter exposing the prometheus_client.Gauge API on top of
    ray.util.metrics.Gauge."""

    def __init__(self,
                 name: str,
                 documentation: str = "",
                 labelnames: Optional[List[str]] = None,
                 multiprocess_mode: str = ""):
        # Accepted only for signature parity with prometheus_client.Gauge;
        # Ray metrics have no multiprocess modes.
        del multiprocess_mode
        tag_keys = tuple(labelnames) if labelnames else None
        self._gauge = ray_metrics.Gauge(name=name,
                                        description=documentation,
                                        tag_keys=tag_keys)

    def labels(self, **labels):
        self._gauge.set_default_tags(labels)
        return self

    def set(self, value: Union[int, float]):
        return self._gauge.set(value)

    def set_to_current_time(self):
        # Ray metrics has no set_to_current_time equivalent
        # (https://docs.ray.io/en/latest/_modules/ray/util/metrics.html),
        # so emulate it with the wall clock.
        return self._gauge.set(time.time())
298
+
299
+
300
class _RayCounterWrapper:
    """Adapter exposing the prometheus_client.Counter API on top of
    ray.util.metrics.Counter."""

    def __init__(self,
                 name: str,
                 documentation: str = "",
                 labelnames: Optional[List[str]] = None):
        tag_keys = tuple(labelnames) if labelnames else None
        self._counter = ray_metrics.Counter(name=name,
                                            description=documentation,
                                            tag_keys=tag_keys)

    def labels(self, **labels):
        self._counter.set_default_tags(labels)
        return self

    def inc(self, value: Union[int, float] = 1.0):
        # Skip no-op increments entirely.
        if value == 0:
            return
        return self._counter.inc(value)
321
+
322
+
323
class _RayHistogramWrapper:
    """Adapter exposing the prometheus_client.Histogram API on top of
    ray.util.metrics.Histogram."""

    def __init__(self,
                 name: str,
                 documentation: str = "",
                 labelnames: Optional[List[str]] = None,
                 buckets: Optional[List[float]] = None):
        tag_keys = tuple(labelnames) if labelnames else None
        # Ray calls bucket bounds "boundaries"; default to no boundaries.
        boundaries = buckets if buckets else []
        self._histogram = ray_metrics.Histogram(name=name,
                                                description=documentation,
                                                tag_keys=tag_keys,
                                                boundaries=boundaries)

    def labels(self, **labels):
        self._histogram.set_default_tags(labels)
        return self

    def observe(self, value: Union[int, float]):
        return self._histogram.observe(value)
345
+
346
+
347
class RayMetrics(Metrics):
    """
    RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
    Provides the same metrics as Metrics but uses Ray's util.metrics library.
    """
    # Casts keep the type checker happy: the wrappers are duck-typed
    # stand-ins for the prometheus_client classes.
    _gauge_cls: Type[prometheus_client.Gauge] = cast(
        Type[prometheus_client.Gauge], _RayGaugeWrapper)
    _counter_cls: Type[prometheus_client.Counter] = cast(
        Type[prometheus_client.Counter], _RayCounterWrapper)
    _histogram_cls: Type[prometheus_client.Histogram] = cast(
        Type[prometheus_client.Histogram], _RayHistogramWrapper)

    def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
        if ray_metrics is None:
            raise ImportError("RayMetrics requires Ray to be installed.")
        super().__init__(labelnames, vllm_config)

    def _unregister_vllm_metrics(self) -> None:
        # Intentionally a no-op: Ray owns metric registration.
        pass
367
+
368
+
369
def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
    """Return ascending histogram bucket bounds: each mantissa multiplied by
    successive powers of ten, stopping before a value exceeds max_value."""
    buckets: List[int] = []
    scale = 1
    while True:
        for mantissa in mantissa_lst:
            candidate = mantissa * scale
            if candidate > max_value:
                return buckets
            buckets.append(candidate)
        scale *= 10
385
+
386
+
387
def build_1_2_5_buckets(max_value: int) -> List[int]:
    """Bucket bounds using the 1-2-5 mantissa series up to *max_value*.

    Example:
        >>> build_1_2_5_buckets(100)
        [1, 2, 5, 10, 20, 50, 100]
    """
    return build_buckets([1, 2, 5], max_value)
394
+
395
+
396
def build_1_2_3_5_8_buckets(max_value: int) -> List[int]:
    """Bucket bounds using the 1-2-3-5-8 mantissa series up to *max_value*.

    Example:
        >>> build_1_2_3_5_8_buckets(100)
        [1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100]
    """
    return build_buckets([1, 2, 3, 5, 8], max_value)
403
+
404
+
405
def local_interval_elapsed(now: float, last_log: float,
                           local_interval: float) -> bool:
    """True when strictly more than local_interval seconds have passed
    since last_log."""
    return (now - last_log) > local_interval
409
+
410
+
411
def get_throughput(tracked_stats: List[int], now: float,
                   last_log: float) -> float:
    """Average tokens/sec over the window from last_log to now."""
    total = np.sum(tracked_stats)
    return float(total / (now - last_log))
414
+
415
+
416
class LoggingStatLogger(StatLoggerBase):
    """LoggingStatLogger is used in LLMEngine to log to Stdout."""

    def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
        super().__init__(local_interval, vllm_config)
        # Previous interval's throughputs, kept so one final line is still
        # logged after the system goes idle.
        self.last_prompt_throughput: Optional[float] = None
        self.last_generation_throughput: Optional[float] = None

    def log(self, stats: Stats) -> None:
        """Called by LLMEngine.
        Logs to Stdout every self.local_interval seconds."""

        # Accumulate per-iteration token counts for throughput computation.
        self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
        self.num_generation_tokens.append(stats.num_generation_tokens_iter)

        # Stash spec decode metrics if this iteration carried any.
        self.maybe_update_spec_decode_metrics(stats)

        # Only emit output once per local_interval.
        if not local_interval_elapsed(stats.now, self.last_local_log,
                                      self.local_interval):
            return

        # Summary throughputs over the elapsed interval (also logged to
        # prometheus where applicable).
        prompt_throughput = get_throughput(self.num_prompt_tokens,
                                           now=stats.now,
                                           last_log=self.last_local_log)
        generation_throughput = get_throughput(self.num_generation_tokens,
                                               now=stats.now,
                                               last_log=self.last_local_log)

        idle = not any((prompt_throughput, generation_throughput,
                        self.last_prompt_throughput,
                        self.last_generation_throughput))
        # Avoid log noise on an idle production system.
        log_fn = logger.debug if idle else logger.info

        log_fn(
            "Avg prompt throughput: %.1f tokens/s, "
            "Avg generation throughput: %.1f tokens/s, "
            "Running: %d reqs, Swapped: %d reqs, "
            "Pending: %d reqs, GPU KV cache usage: %.1f%%, "
            "CPU KV cache usage: %.1f%%.",
            prompt_throughput,
            generation_throughput,
            stats.num_running_sys,
            stats.num_swapped_sys,
            stats.num_waiting_sys,
            stats.gpu_cache_usage_sys * 100,
            stats.cpu_cache_usage_sys * 100,
        )
        # Negative hit rates mean "prefix caching disabled"; skip then.
        if (stats.cpu_prefix_cache_hit_rate >= 0
                or stats.gpu_prefix_cache_hit_rate >= 0):
            log_fn(
                "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
                stats.gpu_prefix_cache_hit_rate * 100,
                stats.cpu_prefix_cache_hit_rate * 100,
            )
        if self.spec_decode_metrics is not None:
            log_fn(
                self._format_spec_decode_metrics_str(self.spec_decode_metrics))

        self._reset(stats, prompt_throughput, generation_throughput)

    def _reset(self, stats, prompt_throughput, generation_throughput) -> None:
        """Clear tracked stats and remember this interval's throughputs."""
        self.num_prompt_tokens = []
        self.num_generation_tokens = []
        self.last_local_log = stats.now
        self.spec_decode_metrics = None
        self.last_prompt_throughput = prompt_throughput
        self.last_generation_throughput = generation_throughput

    def _format_spec_decode_metrics_str(
            self, metrics: "SpecDecodeWorkerMetrics") -> str:
        """Render speculative decoding metrics as one log line."""
        return ("Speculative metrics: "
                f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
                f"System efficiency: {metrics.system_efficiency:.3f}, "
                f"Number of speculative tokens: {metrics.num_spec_tokens}, "
                f"Number of accepted tokens: {metrics.accepted_tokens}, "
                f"Number of draft tokens: {metrics.draft_tokens}, "
                f"Number of emitted tokens: {metrics.emitted_tokens}.")

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        # Info-style metrics are only supported by the prometheus logger.
        raise NotImplementedError
505
+
506
+
507
class PrometheusStatLogger(StatLoggerBase):
    """PrometheusStatLogger is used by LLMEngine to log to Prometheus."""
    # Subclasses (e.g. RayPrometheusStatLogger) override these to swap the
    # metrics backend.
    _metrics_cls = Metrics
    _gauge_cls = prometheus_client.Gauge

    def __init__(self, local_interval: float, labels: Dict[str, str],
                 vllm_config: VllmConfig) -> None:
        super().__init__(local_interval, vllm_config)
        # Prometheus metrics
        self.labels = labels
        self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
                                         vllm_config=vllm_config)

    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
        # Convenience function for logging to gauge.
        gauge.labels(**self.labels).set(data)

    def _log_counter(self, counter, data: Union[int, float]) -> None:
        # Convenience function for logging to counter.
        # Prevent ValueError from negative increment
        if data < 0:
            logger.warning("Skipping negative increment of %g to %s", data,
                           counter)
            return
        counter.labels(**self.labels).inc(data)

    def _log_counter_labels(self, counter, data: CollectionsCounter,
                            label_key: str) -> None:
        # Convenience function for collection counter of labels.
        for label, count in data.items():
            counter.labels(**{**self.labels, label_key: label}).inc(count)

    def _log_histogram(self, histogram, data: Union[List[int],
                                                    List[float]]) -> None:
        # Convenience function for logging list to histogram.
        for datum in data:
            histogram.labels(**self.labels).observe(datum)

    def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None:
        # Set the gauge's value to "now" so stale label sets age out.
        gauge.labels(**data).set_to_current_time()

    def _log_prometheus(self, stats: Stats) -> None:
        """Push one iteration's worth of Stats into prometheus metrics."""
        # System state data
        self._log_gauge(self.metrics.gauge_scheduler_running,
                        stats.num_running_sys)
        self._log_gauge(self.metrics.gauge_scheduler_swapped,
                        stats.num_swapped_sys)
        self._log_gauge(self.metrics.gauge_scheduler_waiting,
                        stats.num_waiting_sys)
        self._log_gauge(self.metrics.gauge_gpu_cache_usage,
                        stats.gpu_cache_usage_sys)
        self._log_gauge(self.metrics.gauge_cpu_cache_usage,
                        stats.cpu_cache_usage_sys)
        self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
                        stats.cpu_prefix_cache_hit_rate)
        self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
                        stats.gpu_prefix_cache_hit_rate)
        # Including max-lora in metric, in future this property of lora
        # config maybe extended to be dynamic.
        lora_info = {
            self.metrics.labelname_running_lora_adapters:
            ",".join(stats.running_lora_adapters),
            self.metrics.labelname_waiting_lora_adapters:
            ",".join(stats.waiting_lora_adapters),
            self.metrics.labelname_max_lora:
            stats.max_lora,
        }
        self._log_gauge_string(self.metrics.gauge_lora_info, lora_info)
        # Iteration level data
        self._log_counter(self.metrics.counter_num_preemption,
                          stats.num_preemption_iter)
        self._log_counter(self.metrics.counter_prompt_tokens,
                          stats.num_prompt_tokens_iter)
        self._log_counter(self.metrics.counter_generation_tokens,
                          stats.num_generation_tokens_iter)
        self._log_histogram(self.metrics.histogram_iteration_tokens,
                            [stats.num_tokens_iter])
        self._log_histogram(self.metrics.histogram_time_to_first_token,
                            stats.time_to_first_tokens_iter)
        self._log_histogram(self.metrics.histogram_time_per_output_token,
                            stats.time_per_output_tokens_iter)

        # Request level data
        # Latency
        self._log_histogram(self.metrics.histogram_e2e_time_request,
                            stats.time_e2e_requests)
        self._log_histogram(self.metrics.histogram_queue_time_request,
                            stats.time_queue_requests)
        self._log_histogram(self.metrics.histogram_inference_time_request,
                            stats.time_inference_requests)
        self._log_histogram(self.metrics.histogram_prefill_time_request,
                            stats.time_prefill_requests)
        self._log_histogram(self.metrics.histogram_decode_time_request,
                            stats.time_decode_requests)
        self._log_histogram(self.metrics.histogram_time_in_queue_request,
                            stats.time_in_queue_requests)
        self._log_histogram(self.metrics.histogram_model_forward_time_request,
                            stats.model_forward_time_requests)
        self._log_histogram(self.metrics.histogram_model_execute_time_request,
                            stats.model_execute_time_requests)
        # Metadata
        finished_reason_counter = CollectionsCounter(
            stats.finished_reason_requests)
        self._log_counter_labels(self.metrics.counter_request_success,
                                 finished_reason_counter,
                                 Metrics.labelname_finish_reason)
        self._log_histogram(self.metrics.histogram_num_prompt_tokens_request,
                            stats.num_prompt_tokens_requests)
        self._log_histogram(
            self.metrics.histogram_num_generation_tokens_request,
            stats.num_generation_tokens_requests)
        self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
        self._log_histogram(
            self.metrics.histogram_max_num_generation_tokens_request,
            stats.max_num_generation_tokens_requests)
        self._log_histogram(self.metrics.histogram_max_tokens_request,
                            stats.max_tokens_requests)

    def log(self, stats: Stats):
        """Logs to prometheus and tracked stats every iteration."""
        # Log to prometheus.
        self._log_prometheus(stats)

        # Save tracked stats for token counters.
        self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
        self.num_generation_tokens.append(stats.num_generation_tokens_iter)

        # Update spec decode metrics
        self.maybe_update_spec_decode_metrics(stats)

        # Log locally every local_interval seconds.
        if local_interval_elapsed(stats.now, self.last_local_log,
                                  self.local_interval):
            if self.spec_decode_metrics is not None:
                self._log_gauge(
                    self.metrics.gauge_spec_decode_draft_acceptance_rate,
                    self.spec_decode_metrics.draft_acceptance_rate)
                self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
                                self.spec_decode_metrics.system_efficiency)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_accepted_tokens,
                    self.spec_decode_metrics.accepted_tokens)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_draft_tokens,
                    self.spec_decode_metrics.draft_tokens)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_emitted_tokens,
                    self.spec_decode_metrics.emitted_tokens)

            # Reset tracked stats for next interval.
            self.num_prompt_tokens = []
            self.num_generation_tokens = []
            self.last_local_log = stats.now
            self.spec_decode_metrics = None

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        # Info type metrics are syntactic sugar for a gauge permanently set to 1
        # Since prometheus multiprocessing mode does not support Info, emulate
        # info here with a gauge.
        if type == "cache_config":
            metrics_info = obj.metrics_info()
            info_gauge = self._gauge_cls(
                name="vllm:cache_config_info",
                documentation="Information of the LLMEngine CacheConfig",
                labelnames=metrics_info.keys(),
                multiprocess_mode="mostrecent")
            info_gauge.labels(**metrics_info).set(1)
674
+
675
+
676
class RayPrometheusStatLogger(PrometheusStatLogger):
    """PrometheusStatLogger variant backed by Ray's metrics library."""
    _metrics_cls = RayMetrics

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        # Info-style gauges are not emitted through Ray metrics.
        return None
.venv/lib/python3.11/site-packages/vllm/engine/metrics_types.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ These types are defined in this file to avoid importing vllm.engine.metrics
4
+ and therefore importing prometheus_client.
5
+
6
+ This is required due to usage of Prometheus multiprocess mode to enable
7
+ metrics after splitting out the uvicorn process from the engine process.
8
+
9
+ Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR
10
+ before prometheus_client is imported. Typically, this is done by setting
11
+ the env variable before launch, but since we are a library, we need to
12
+ do this in Python code and lazily import prometheus_client.
13
+ """
14
+
15
+ import time
16
+ from abc import ABC, abstractmethod
17
+ from dataclasses import dataclass
18
+ from typing import Dict, List, Optional, Protocol
19
+
20
+ from vllm.config import VllmConfig
21
+ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
22
+
23
+
24
@dataclass
class Stats:
    """Snapshot of engine state produced by LLMEngine for StatLogger use."""
    # Timestamp the snapshot was taken at.
    now: float

    # System stats (should have _sys suffix)
    # Scheduler State
    num_running_sys: int
    num_waiting_sys: int
    num_swapped_sys: int
    # KV Cache Usage in %
    gpu_cache_usage_sys: float
    cpu_cache_usage_sys: float
    # Prefix caching block hit rate
    cpu_prefix_cache_hit_rate: float
    gpu_prefix_cache_hit_rate: float

    # Iteration stats (should have _iter suffix)
    num_prompt_tokens_iter: int
    num_generation_tokens_iter: int
    num_tokens_iter: int
    time_to_first_tokens_iter: List[float]
    time_per_output_tokens_iter: List[float]
    num_preemption_iter: int

    # Request stats (should have _requests suffix)
    # Latency
    time_e2e_requests: List[float]
    time_queue_requests: List[float]
    time_inference_requests: List[float]
    time_prefill_requests: List[float]
    time_decode_requests: List[float]
    time_in_queue_requests: List[float]
    model_forward_time_requests: List[float]
    model_execute_time_requests: List[float]
    # Metadata
    num_prompt_tokens_requests: List[int]
    num_generation_tokens_requests: List[int]
    n_requests: List[int]
    max_num_generation_tokens_requests: List[int]
    max_tokens_requests: List[int]
    finished_reason_requests: List[str]
    waiting_lora_adapters: List[str]
    running_lora_adapters: List[str]
    max_lora: str

    # Optional speculative decoding metrics; only set on iterations that
    # produced them.
    spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
71
+
72
+
73
class SupportsMetricsInfo(Protocol):
    """Structural type for objects that expose a metrics_info() mapping."""

    def metrics_info(self) -> Dict[str, str]:
        ...
77
+
78
+
79
class StatLoggerBase(ABC):
    """Base class for StatLogger: shared interval state plus the log/info
    interface each concrete logger must implement."""

    def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
        self.local_interval = local_interval
        self.last_local_log = time.time()
        # Tracked stats over current local logging interval.
        self.num_prompt_tokens: List[int] = []
        self.num_generation_tokens: List[int] = []
        self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None

    @abstractmethod
    def log(self, stats: Stats) -> None:
        raise NotImplementedError

    @abstractmethod
    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        raise NotImplementedError

    def maybe_update_spec_decode_metrics(self, stats: Stats):
        """Save spec decode metrics (since they are unlikely
        to be emitted at same time as log interval)."""
        if stats.spec_decode_metrics is not None:
            self.spec_decode_metrics = stats.spec_decode_metrics
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__init__.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import uuid
4
+ from dataclasses import dataclass, field
5
+ from enum import Enum
6
+ from typing import List, Mapping, Optional, Union, overload
7
+
8
+ from typing_extensions import deprecated
9
+
10
+ from vllm import PoolingParams
11
+ from vllm.inputs import PromptType
12
+ from vllm.lora.request import LoRARequest
13
+ from vllm.outputs import RequestOutput
14
+ from vllm.prompt_adapter.request import PromptAdapterRequest
15
+ from vllm.sampling_params import SamplingParams
16
+ from vllm.utils import deprecate_kwargs
17
+
18
+ VLLM_RPC_SUCCESS_STR = "SUCCESS"
19
+
20
+ IPC_INPUT_EXT = "_input_socket"
21
+ IPC_OUTPUT_EXT = "_output_socket"
22
+ IPC_HEALTH_EXT = "_health_socket"
23
+ IPC_DATA_EXT = "_data_socket"
24
+
25
+
26
class MQEngineDeadError(RuntimeError):
    """Raised when the background engine loop has died and can no longer
    serve requests."""
28
+
29
+
30
@dataclass
class RPCProcessRequest:
    """RPC message submitting a prompt for processing by the engine process.

    Mirrors the arguments of ``LLMEngine.add_request``. The ``inputs``
    keyword is a deprecated alias for ``prompt`` and is kept for backward
    compatibility.
    """
    prompt: PromptType
    params: Union[SamplingParams, PoolingParams]
    request_id: str
    lora_request: Optional[LoRARequest] = None
    trace_headers: Optional[Mapping[str, str]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    priority: int = 0

    @overload
    def __init__(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def __init__(
        self,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    def __init__(
        self,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        request_id: Optional[str] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        if inputs is not None:
            prompt = inputs
        # prompt/params/request_id are Optional only to allow the deprecated
        # 'inputs' keyword form; all three must be supplied in practice.
        assert (prompt is not None and params is not None
                and request_id is not None)

        super().__init__()

        self.prompt = prompt
        self.params = params
        self.request_id = request_id
        self.lora_request = lora_request
        self.trace_headers = trace_headers
        self.prompt_adapter_request = prompt_adapter_request
        self.priority = priority
98
+
99
+
100
@dataclass
class RPCError:
    """Error reply sent from the engine process back to the client.

    ``request_id`` is ``None`` when the failure cannot be attributed to a
    single request; ``is_engine_errored`` marks an engine-fatal failure.
    """
    request_id: Optional[str]
    is_engine_errored: bool
    exception: BaseException
105
+
106
+
107
@dataclass
class RPCAbortRequest:
    """RPC message asking the engine to abort an in-flight request."""
    request_id: str
110
+
111
+
112
class RPCStartupRequest(Enum):
    """Handshake message sent on the data socket during client startup."""
    IS_SERVER_READY = 1
114
+
115
+
116
@dataclass
class RPCStartupResponse:
    """Handshake reply; reports whether tracing is enabled server-side."""
    tracing_enabled: bool
119
+
120
+
121
class RPCUProfileRequest(Enum):
    """Toggle messages controlling the engine-side profiler."""
    START_PROFILE = 1
    STOP_PROFILE = 2
124
+
125
+
126
class RPCResetPrefixCacheRequest(Enum):
    """Message asking the engine to reset its prefix cache."""
    RESET_PREFIX_CACHE = 1
128
+
129
+
130
@dataclass
class RPCLoadAdapterRequest:
    """Message asking the engine to load a LoRA adapter ahead of time."""
    lora_request: LoRARequest
    # A fresh UUID per request lets the reply be routed back through the
    # per-request output queues like any generate request.
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
135
+
136
+
137
@dataclass
class RPCAdapterLoadedResponse:
    """Reply confirming that an adapter-load request has completed."""
    request_id: str
140
+
141
+
142
+ RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
143
+ RPCUProfileRequest, RPCLoadAdapterRequest,
144
+ RPCResetPrefixCacheRequest]
145
+
146
+ REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
147
+ RPCError]
148
+
149
+
150
def ENGINE_DEAD_ERROR(
        error: Optional[BaseException] = None) -> MQEngineDeadError:
    """Build the error reported to callers once the engine loop has died.

    When the original failure is known, its repr is embedded in the
    message so the root cause survives the process boundary.
    """
    base_msg = ("Engine loop is not running. Inspect the stacktrace to "
                "find the original error")
    if error is None:
        return MQEngineDeadError(base_msg)
    return MQEngineDeadError(f"{base_msg}: {repr(error)}.")
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (7.71 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/client.cpython-311.pyc ADDED
Binary file (32.9 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/engine.cpython-311.pyc ADDED
Binary file (20.8 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/client.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ import copy
5
+ import pickle
6
+ from contextlib import contextmanager, suppress
7
+ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
8
+ Optional, Union, cast, overload)
9
+
10
+ import cloudpickle
11
+ import psutil
12
+ import zmq
13
+ import zmq.asyncio
14
+ from typing_extensions import deprecated
15
+ from zmq import Frame # type: ignore[attr-defined]
16
+ from zmq.asyncio import Socket
17
+
18
+ from vllm import PoolingParams
19
+ from vllm.config import DecodingConfig, ModelConfig, VllmConfig
20
+ from vllm.core.scheduler import SchedulerOutputs
21
+ from vllm.engine.arg_utils import AsyncEngineArgs
22
+ # yapf conflicts with isort for this block
23
+ # yapf: disable
24
+ from vllm.engine.async_llm_engine import (
25
+ build_guided_decoding_logits_processor_async)
26
+ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
27
+ IPC_HEALTH_EXT, IPC_INPUT_EXT,
28
+ IPC_OUTPUT_EXT, RPC_REQUEST_T,
29
+ VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
30
+ RPCAdapterLoadedResponse, RPCError,
31
+ RPCLoadAdapterRequest,
32
+ RPCProcessRequest,
33
+ RPCResetPrefixCacheRequest,
34
+ RPCStartupRequest, RPCStartupResponse,
35
+ RPCUProfileRequest)
36
+ from vllm.engine.protocol import EngineClient
37
+ # yapf: enable
38
+ from vllm.envs import VLLM_RPC_TIMEOUT
39
+ from vllm.inputs import PromptType
40
+ from vllm.inputs.preprocess import InputPreprocessor
41
+ from vllm.logger import init_logger
42
+ from vllm.lora.request import LoRARequest
43
+ from vllm.model_executor.layers.sampler import SamplerOutput
44
+ from vllm.outputs import PoolingRequestOutput, RequestOutput
45
+ from vllm.prompt_adapter.request import PromptAdapterRequest
46
+ from vllm.sampling_params import SamplingParams
47
+ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
48
+ from vllm.utils import deprecate_kwargs
49
+
50
+ logger = init_logger(__name__)
51
+
52
+
53
class MQClientClosedError(Exception):
    """Exception class raised when the client is used post-close.

    The client can be closed, which closes the ZMQ context. This normally
    happens on server shutdown. In some cases, methods like abort and
    do_log_stats will still be called and then try to open a socket, which
    causes a ZMQError and creates a huge stack trace.
    So, we throw this error such that we can suppress it.
    """
62
+
63
+
64
+ class MQLLMEngineClient(EngineClient):
65
+ """A client wrapper for MQLLMEngine that conforms to the
66
+ EngineClient protocol.
67
+
68
+ MQLLMEngine and MQLLMEngineClient are intended to run in separate
69
+ processes communicating via zeromq ipc sockets.
70
+
71
+ The entrypoint to MQLLMEngineClient is through the generate()
72
+ method. On generate() MQLLMEngine does three things:
73
+ - Creates an asyncio output queue
74
+ - Sends a RPCGenerateRequest to the MQLLMEngine via zmq
75
+ - Pulls RequestOutputs from its queue and yields them
76
+
77
+ MQLLMEngine runs two background loops:
78
+ - output_loop: the output loop pulls List[RequestOutput]
79
+ from the MQLLMEngine via zmq (each list is the output
80
+ of one engine_step in the LLMEngine). It then parses
81
+ the list and pushes individual request_outputs into
82
+ the corresponding output_queue such that they can be
83
+ consumed by the .generate() method.
84
+ - health_loop: the health loop queries the health socket
85
+ every N seconds, confirming the engine is healthy
86
+ """
87
+
88
+ def __init__(self, ipc_path: str, engine_config: VllmConfig,
89
+ engine_pid: int):
90
+ self.context = zmq.asyncio.Context()
91
+ self._errored_with: Optional[BaseException] = None
92
+
93
+ # Get the configs.
94
+ self.model_config = engine_config.model_config
95
+ self.decoding_config = engine_config.decoding_config
96
+
97
+ # Create the tokenizer group.
98
+ self.tokenizer = init_tokenizer_from_configs(
99
+ model_config=self.model_config,
100
+ scheduler_config=engine_config.scheduler_config,
101
+ parallel_config=engine_config.parallel_config,
102
+ lora_config=engine_config.lora_config)
103
+ self.input_preprocessor = InputPreprocessor(self.model_config,
104
+ self.tokenizer)
105
+
106
+ # Send RPCGenerateRequest to the MQLLMEngine.
107
+ self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
108
+ self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")
109
+
110
+ # Receive streams of RequestOutput from the MQLLMEngine.
111
+ self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
112
+ self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
113
+
114
+ # IPC path for acking heartbeats.
115
+ self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
116
+ self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
117
+
118
+ # IPC path for the data socket.
119
+ self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
120
+
121
+ # Stream for each individual request.
122
+ self.output_queues: Dict[str, asyncio.Queue] = {}
123
+
124
+ # Loop to handle output of the LLMEngine periodically.
125
+ # Started after the MQLLMEngine is ready so that we can
126
+ # build the Client in an executor to enable clean shutdown.
127
+ self.output_loop: Optional[asyncio.Task] = None
128
+
129
+ # Loop to check health of the LLMEngine periodically.
130
+ # Started after the MQLLMEngine is ready.
131
+ self.health_loop: Optional[asyncio.Task] = None
132
+ self._engine_process = psutil.Process(engine_pid)
133
+
134
+ @staticmethod
135
+ def is_unsupported_config(engine_args: AsyncEngineArgs):
136
+ # Pipeline parallel not yet supported
137
+ return engine_args.pipeline_parallel_size > 1
138
+
139
+ @contextmanager
140
+ def get_data_socket(self) -> Iterator[Socket]:
141
+ socket = self.context.socket(zmq.constants.DEALER)
142
+ try:
143
+ socket.connect(self.data_ipc_path)
144
+ yield socket
145
+ finally:
146
+ socket.close(linger=0)
147
+
148
+ async def run_heartbeat_loop(self, timeout: int):
149
+ """Background loop that continually checks to ensure the engine process
150
+ is still alive.
151
+ """
152
+ try:
153
+ while True:
154
+ # Check if the engine process is running:
155
+ if not self._engine_process.is_running() or (
156
+ self._engine_process.status() == psutil.STATUS_ZOMBIE):
157
+ # NB: is_running() returns True for zombies
158
+ self._set_errored(
159
+ RuntimeError(
160
+ f"Engine process (pid {self._engine_process.pid}) "
161
+ "died."))
162
+ break
163
+
164
+ if await self.heartbeat_socket.poll(timeout=timeout):
165
+ # Heartbeat received- check the message
166
+ await self._check_success(
167
+ error_message="Heartbeat failed.",
168
+ socket=self.heartbeat_socket)
169
+
170
+ logger.debug("Heartbeat successful.")
171
+
172
+ except asyncio.CancelledError:
173
+ logger.debug("Shutting down MQLLMEngineClient check health loop.")
174
+
175
+ except psutil.NoSuchProcess:
176
+ self._set_errored(
177
+ RuntimeError(
178
+ f"Engine process (pid {self._engine_process.pid}) died."))
179
+
180
+ except Exception as e:
181
+ self._set_errored(e)
182
+
183
+ async def run_output_handler_loop(self):
184
+ """Get RequestOutputs from Engine and stream to Request Queues"""
185
+
186
+ try:
187
+ while True:
188
+ # Poll, checking for ENGINE_DEAD
189
+ while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
190
+ ) == 0:
191
+ logger.debug("Waiting for output from MQLLMEngine.")
192
+
193
+ # If errored, alert all running requests.
194
+ if self.errored:
195
+ for queue_j in tuple(self.output_queues.values()):
196
+ queue_j.put_nowait(
197
+ ENGINE_DEAD_ERROR(self._errored_with))
198
+ return
199
+
200
+ message: Frame = await self.output_socket.recv(copy=False)
201
+ request_outputs = pickle.loads(message.buffer)
202
+
203
+ is_error = isinstance(request_outputs,
204
+ (BaseException, RPCError))
205
+ if is_error:
206
+ if isinstance(request_outputs, RPCError):
207
+ rpc_error: RPCError = request_outputs
208
+ request_id = rpc_error.request_id
209
+ exception = rpc_error.exception
210
+ is_engine_errored = rpc_error.is_engine_errored
211
+ else:
212
+ # MPLLMEngine should always return an RPCError to
213
+ # the output_socket when an issue arises.
214
+ # If we are here, we are in a bad state and
215
+ # should shut down the server.
216
+ error: BaseException = request_outputs
217
+ logger.error(
218
+ "Received Exception %s rather than RPCError from "
219
+ "MPLLMEngine. This should never happen.", error)
220
+ request_id = None
221
+ exception = error
222
+ is_engine_errored = True
223
+
224
+ # Set to error state only on engine critical error
225
+ # (and record only the first one)
226
+ if is_engine_errored and not self._errored_with:
227
+ self._errored_with = exception
228
+ # If engine is errored, no matter the type of exception
229
+ # it will no longer be able to receive new requests,
230
+ # therefore we have to inform that the current
231
+ # processed requests failed as well. Send back a dead
232
+ # engine error give this feedback and also give a
233
+ # 'hint' to the server to shutdown next.
234
+ exception = self.dead_error
235
+
236
+ if request_id is None:
237
+ # If request_id is None, then the engine raised an
238
+ # exception for a batch, and we may not know the
239
+ # request that caused it, neither if it was actually
240
+ # caused by any of them (e.g. CUDA OOM). Therefore we
241
+ # broadcast the same exception for all requests.
242
+ for queue_i in tuple(self.output_queues.values()):
243
+ queue_i.put_nowait(exception)
244
+ else:
245
+ queue = self.output_queues.get(request_id)
246
+ if queue is not None:
247
+ queue.put_nowait(exception)
248
+ # Put each output into the appropriate queue.
249
+ elif isinstance(request_outputs, RPCAdapterLoadedResponse):
250
+ self._add_output(request_outputs)
251
+ else:
252
+ for request_output in request_outputs:
253
+ self._add_output(request_output)
254
+
255
+ except asyncio.CancelledError:
256
+ logger.debug("Shutting down MQLLMEngineClient output handler.")
257
+
258
+ def _add_output(self, request_output: Union[RequestOutput,
259
+ RPCAdapterLoadedResponse]):
260
+ queue = self.output_queues.get(request_output.request_id)
261
+ if queue is not None:
262
+ queue.put_nowait(request_output)
263
+
264
+ async def setup(self):
265
+ """Setup the client before it starts sending server requests."""
266
+
267
+ # Start output_loop
268
+ if self.output_loop is None:
269
+ # only generate once to avoid multiple concurrent output_loops
270
+ # this will lead to race conditions and wrong orders of tokens
271
+ # returned by the engine
272
+ # setup will be called multiple times during the startup of
273
+ # the engine
274
+ self.output_loop = asyncio.create_task(
275
+ self.run_output_handler_loop())
276
+
277
+ with self.get_data_socket() as socket:
278
+ # Wait until server is ready.
279
+ response = await self._wait_for_server_rpc(socket)
280
+
281
+ self.tracing_flag = response.tracing_enabled
282
+
283
+ # Start health_loop.
284
+ if self.health_loop is None:
285
+ self.health_loop = asyncio.create_task(
286
+ self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
287
+
288
+ def close(self):
289
+ """Destroy the ZeroMQ Context."""
290
+ # Close all sockets and terminate the context.
291
+ self.context.destroy(linger=0)
292
+
293
+ # Cancel background tasks.
294
+ if self.health_loop is not None:
295
+ self.health_loop.cancel()
296
+ if self.output_loop is not None:
297
+ self.output_loop.cancel()
298
+
299
+ def _set_errored(self, e: BaseException):
300
+ logger.exception(repr(e))
301
+ if self._errored_with is None:
302
+ self._errored_with = e
303
+
304
+ @staticmethod
305
+ async def _send_get_data_rpc_request(request: RPCStartupRequest,
306
+ expected_type: Any,
307
+ error_message: str,
308
+ socket: Socket) -> Any:
309
+ """Send an RPC request that is expecting data back."""
310
+
311
+ # Ping RPCServer with a request.
312
+ await socket.send_multipart((pickle.dumps(request), ), copy=False)
313
+
314
+ # Make sure the server responds in time.
315
+ if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
316
+ raise TimeoutError("RPCServer didn't reply within "
317
+ f"{VLLM_RPC_TIMEOUT} ms")
318
+
319
+ # Await the data from the Server.
320
+ frame = await socket.recv(copy=False)
321
+ data = pickle.loads(frame.buffer)
322
+
323
+ if isinstance(data, BaseException):
324
+ raise data
325
+ elif not isinstance(data, expected_type):
326
+ raise ValueError(error_message)
327
+
328
+ return data
329
+
330
+ @staticmethod
331
+ async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
332
+ socket: Socket):
333
+ """Send one-way RPC request to trigger an action."""
334
+
335
+ if socket.closed:
336
+ raise MQClientClosedError()
337
+
338
+ await socket.send_multipart((pickle.dumps(request), ))
339
+
340
+ async def _await_ack(self, error_message: str, socket: Socket):
341
+ """Await acknowledgement that a request succeeded."""
342
+
343
+ if socket.closed:
344
+ raise MQClientClosedError()
345
+
346
+ if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
347
+ raise TimeoutError("MQLLMEngine didn't reply within "
348
+ f"{VLLM_RPC_TIMEOUT}ms")
349
+
350
+ await self._check_success(error_message, socket)
351
+
352
+ @staticmethod
353
+ async def _check_success(error_message: str, socket: Socket):
354
+ """Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""
355
+
356
+ if socket.closed:
357
+ raise MQClientClosedError()
358
+
359
+ frame = await socket.recv(copy=False)
360
+ response = pickle.loads(frame.buffer)
361
+
362
+ # Raise error if unsuccessful
363
+ if isinstance(response, BaseException):
364
+ raise response
365
+ elif (not isinstance(response, str)
366
+ or response != VLLM_RPC_SUCCESS_STR):
367
+ raise ValueError(error_message)
368
+
369
+ async def get_input_preprocessor(self) -> InputPreprocessor:
370
+ return self.input_preprocessor
371
+
372
+ async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
373
+ return await self.tokenizer.get_lora_tokenizer_async(lora_request)
374
+
375
+ async def get_decoding_config(self) -> DecodingConfig:
376
+ return self.decoding_config
377
+
378
+ async def get_model_config(self) -> ModelConfig:
379
+ return self.model_config
380
+
381
+ async def is_tracing_enabled(self) -> bool:
382
+ return self.tracing_flag
383
+
384
+ async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
385
+ """Wait for the RPCServer to start up."""
386
+
387
+ return await self._send_get_data_rpc_request(
388
+ request=RPCStartupRequest.IS_SERVER_READY,
389
+ expected_type=RPCStartupResponse,
390
+ error_message="Unable to start RPC Server",
391
+ socket=socket)
392
+
393
+ async def abort(self, request_id: str):
394
+ """Send an ABORT_REQUEST signal to the RPC Server"""
395
+
396
+ with suppress(MQClientClosedError):
397
+ await self._send_one_way_rpc_request(
398
+ request=RPCAbortRequest(request_id), socket=self.input_socket)
399
+
400
+ async def do_log_stats(
401
+ self,
402
+ scheduler_outputs: Optional[SchedulerOutputs] = None,
403
+ model_output: Optional[List[SamplerOutput]] = None,
404
+ ) -> None:
405
+ """
406
+ Ignore do_log_stats (handled on MQLLMEngine polling)
407
+ """
408
+ pass
409
+
410
+ async def check_health(self):
411
+ """
412
+ The check health loop probes the health status of the
413
+ Engine's health every N seconds and sets _errored_with
414
+ if the engine is unhealthy.
415
+ """
416
+ if self._errored_with is not None:
417
+ raise self._errored_with
418
+
419
+ @property
420
+ def is_running(self) -> bool:
421
+ return not self.errored
422
+
423
+ @property
424
+ def is_stopped(self) -> bool:
425
+ return self.errored
426
+
427
+ @property
428
+ def errored(self) -> bool:
429
+ return self._errored_with is not None
430
+
431
+ @property
432
+ def dead_error(self) -> BaseException:
433
+ return ENGINE_DEAD_ERROR(self._errored_with)
434
+
435
+ @overload
436
+ def generate(
437
+ self,
438
+ prompt: PromptType,
439
+ sampling_params: SamplingParams,
440
+ request_id: str,
441
+ lora_request: Optional[LoRARequest] = None,
442
+ trace_headers: Optional[Mapping[str, str]] = None,
443
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
444
+ priority: int = 0,
445
+ ) -> AsyncGenerator[RequestOutput, None]:
446
+ ...
447
+
448
+ @overload
449
+ @deprecated("'inputs' will be renamed to 'prompt")
450
+ def generate(
451
+ self,
452
+ *,
453
+ inputs: PromptType,
454
+ sampling_params: SamplingParams,
455
+ request_id: str,
456
+ lora_request: Optional[LoRARequest] = None,
457
+ trace_headers: Optional[Mapping[str, str]] = None,
458
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
459
+ priority: int = 0,
460
+ ) -> AsyncGenerator[RequestOutput, None]:
461
+ ...
462
+
463
+ @deprecate_kwargs(
464
+ "inputs",
465
+ additional_message="Please use the 'prompt' parameter instead.",
466
+ )
467
+ def generate(
468
+ self,
469
+ prompt: Optional[PromptType] = None,
470
+ sampling_params: Optional[SamplingParams] = None,
471
+ request_id: Optional[str] = None,
472
+ lora_request: Optional[LoRARequest] = None,
473
+ trace_headers: Optional[Mapping[str, str]] = None,
474
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
475
+ priority: int = 0,
476
+ *,
477
+ inputs: Optional[PromptType] = None # DEPRECATED
478
+ ) -> AsyncGenerator[RequestOutput, None]:
479
+ """Generate outputs for a request.
480
+
481
+ Generate outputs for a request. This method is a coroutine. It adds the
482
+ request into the waiting queue of the LLMEngine and streams the outputs
483
+ from the LLMEngine to the caller.
484
+
485
+ Args:
486
+ prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
487
+ for more details about the format of each input.
488
+ sampling_params: The sampling parameters of the request.
489
+ request_id: The unique id of the request.
490
+ lora_request: LoRA request to use for generation, if any.
491
+ trace_headers: OpenTelemetry trace headers.
492
+ prompt_adapter_request: Prompt Adapter request to use
493
+ for generation, if any.
494
+ priority: Priority of the request (lower means earlier handling).
495
+ Any priority other than 0 will lead to an error if the
496
+ scheduling policy is not "priority".
497
+ """
498
+ if inputs is not None:
499
+ prompt = inputs
500
+ assert (prompt is not None and sampling_params is not None
501
+ and request_id is not None)
502
+
503
+ return self._process_request(prompt, sampling_params, request_id,
504
+ lora_request, trace_headers,
505
+ prompt_adapter_request, priority)
506
+
507
+ @overload
508
+ def encode(
509
+ self,
510
+ prompt: PromptType,
511
+ pooling_params: PoolingParams,
512
+ request_id: str,
513
+ lora_request: Optional[LoRARequest] = None,
514
+ trace_headers: Optional[Mapping[str, str]] = None,
515
+ priority: int = 0,
516
+ ) -> AsyncGenerator[PoolingRequestOutput, None]:
517
+ ...
518
+
519
+ @overload
520
+ @deprecated("'inputs' will be renamed to 'prompt")
521
+ def encode(
522
+ self,
523
+ *,
524
+ inputs: PromptType,
525
+ pooling_params: PoolingParams,
526
+ request_id: str,
527
+ lora_request: Optional[LoRARequest] = None,
528
+ trace_headers: Optional[Mapping[str, str]] = None,
529
+ priority: int = 0,
530
+ ) -> AsyncGenerator[PoolingRequestOutput, None]:
531
+ ...
532
+
533
+ @deprecate_kwargs(
534
+ "inputs",
535
+ additional_message="Please use the 'prompt' parameter instead.",
536
+ )
537
+ def encode(
538
+ self,
539
+ prompt: Optional[PromptType] = None,
540
+ pooling_params: Optional[PoolingParams] = None,
541
+ request_id: Optional[str] = None,
542
+ lora_request: Optional[LoRARequest] = None,
543
+ trace_headers: Optional[Mapping[str, str]] = None,
544
+ priority: int = 0,
545
+ *,
546
+ inputs: Optional[PromptType] = None # DEPRECATED
547
+ ) -> AsyncGenerator[PoolingRequestOutput, None]:
548
+ """Generate outputs for a request from a pooling model.
549
+
550
+ Generate outputs for a request. This method is a coroutine. It adds the
551
+ request into the waiting queue of the LLMEngine and streams the outputs
552
+ from the LLMEngine to the caller.
553
+
554
+ Args:
555
+ prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
556
+ for more details about the format of each input.
557
+ pooling_params: The pooling parameters of the request.
558
+ request_id: The unique id of the request.
559
+ lora_request: LoRA request to use for generation, if any.
560
+ trace_headers: OpenTelemetry trace headers.
561
+
562
+ Yields:
563
+ The output `PoolingRequestOutput` objects from the LLMEngine
564
+ for the request.
565
+ """
566
+ if inputs is not None:
567
+ prompt = inputs
568
+ assert (prompt is not None and pooling_params is not None
569
+ and request_id is not None)
570
+
571
+ return cast(
572
+ AsyncGenerator[PoolingRequestOutput, None],
573
+ self._process_request(prompt,
574
+ pooling_params,
575
+ request_id,
576
+ lora_request,
577
+ trace_headers,
578
+ priority=priority))
579
+
580
+ async def _process_request(
581
+ self,
582
+ prompt: PromptType,
583
+ params: Union[SamplingParams, PoolingParams],
584
+ request_id: str,
585
+ lora_request: Optional[LoRARequest] = None,
586
+ trace_headers: Optional[Mapping[str, str]] = None,
587
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
588
+ priority: int = 0,
589
+ ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
590
+ PoolingRequestOutput, None]]:
591
+ """Send an RPCGenerateRequest to the RPCServer and stream responses."""
592
+
593
+ # If already dead, error out.
594
+ if self._errored_with is not None:
595
+ raise ENGINE_DEAD_ERROR(self._errored_with)
596
+
597
+ # Ensure the request id is unique among running requests
598
+ if request_id in self.output_queues:
599
+ raise ValueError(f"Request {request_id} already exists")
600
+
601
+ # Constructing guided decoding logits processors is expensive, so we do
602
+ # it here to avoid contending with cpu resources and the GIL on the
603
+ # backend process.
604
+ if isinstance(params, SamplingParams) and \
605
+ params.guided_decoding is not None:
606
+ params = await \
607
+ build_guided_decoding_logits_processor_async(
608
+ sampling_params=params,
609
+ tokenizer=await self.get_tokenizer(lora_request),
610
+ default_guided_backend=(self.decoding_config.guided_decoding_backend
611
+ if self.decoding_config
612
+ else DecodingConfig.guided_decoding_backend),
613
+ model_config=self.model_config
614
+ )
615
+
616
+ # 1) Create output queue for this requests.
617
+ queue: asyncio.Queue[Union[RequestOutput,
618
+ BaseException]] = asyncio.Queue()
619
+ self.output_queues[request_id] = queue
620
+
621
+ try:
622
+ # 2) Detach logits processors so that they can be pickled
623
+ # separately (may require cloudpickle which is slower)
624
+ if isinstance(params, SamplingParams) and params.logits_processors:
625
+ # Defensive shallow copy
626
+ params = copy.copy(params)
627
+ logits_processors = params.logits_processors
628
+ params.logits_processors = None
629
+ lp_bytes = cloudpickle.dumps(logits_processors)
630
+ else:
631
+ lp_bytes = None
632
+
633
+ request_bytes = pickle.dumps(
634
+ RPCProcessRequest(
635
+ prompt=prompt,
636
+ params=params,
637
+ request_id=request_id,
638
+ lora_request=lora_request,
639
+ trace_headers=trace_headers,
640
+ prompt_adapter_request=prompt_adapter_request,
641
+ priority=priority,
642
+ ))
643
+
644
+ # 3) Send the RPCGenerateRequest to the MQLLMEngine.
645
+ parts = (request_bytes,
646
+ lp_bytes) if lp_bytes else (request_bytes, )
647
+ await self.input_socket.send_multipart(parts, copy=False)
648
+
649
+ # 4) Stream the RequestOutputs from the output queue. Note
650
+ # that the output_loop pushes RequestOutput objects to this
651
+ # queue after pulling them from the zmq socket.
652
+ finished = False
653
+ try:
654
+ while not finished:
655
+ request_output = await queue.get()
656
+
657
+ if isinstance(request_output, BaseException):
658
+ raise request_output
659
+
660
+ finished = request_output.finished
661
+ yield request_output
662
+ finally:
663
+ # Request was canceled by the client.
664
+ if not finished and not self.errored:
665
+ await self.abort(request_id)
666
+ finally:
667
+ self.output_queues.pop(request_id)
668
+
669
+ async def start_profile(self) -> None:
670
+ """Start profiling the engine"""
671
+
672
+ await self._send_one_way_rpc_request(
673
+ request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket)
674
+
675
+ async def stop_profile(self) -> None:
676
+ """Stop profiling the engine"""
677
+
678
+ await self._send_one_way_rpc_request(
679
+ request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
680
+
681
+ async def reset_prefix_cache(self) -> None:
682
+ """Reset the prefix cache"""
683
+
684
+ await self._send_one_way_rpc_request(
685
+ request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
686
+ socket=self.input_socket)
687
+
688
+ async def add_lora(self, lora_request: LoRARequest) -> None:
689
+ """Load a new LoRA adapter into the engine for future requests."""
690
+ # Uses the same I/O as generate requests
691
+ request = RPCLoadAdapterRequest(lora_request)
692
+
693
+ # Create output queue for this requests.
694
+ queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
695
+ self.output_queues[request.request_id] = queue
696
+
697
+ # Send the request
698
+ request_bytes = pickle.dumps(request)
699
+ await self.input_socket.send_multipart((request_bytes, ), copy=False)
700
+
701
+ # Wait for the response
702
+ request_output = await queue.get()
703
+ self.output_queues.pop(request.request_id)
704
+
705
+ # Raise on error, otherwise happily return None
706
+ if isinstance(request_output, BaseException):
707
+ raise request_output
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/engine.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import pickle
4
+ import signal
5
+ from contextlib import contextmanager
6
+ from typing import Iterator, List, Optional, Union
7
+
8
+ import cloudpickle
9
+ import zmq
10
+
11
+ from vllm import AsyncEngineArgs, SamplingParams
12
+ from vllm.engine.llm_engine import LLMEngine
13
+ # yapf conflicts with isort for this block
14
+ # yapf: disable
15
+ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
16
+ IPC_HEALTH_EXT, IPC_INPUT_EXT,
17
+ IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
18
+ VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
19
+ RPCAdapterLoadedResponse, RPCError,
20
+ RPCLoadAdapterRequest,
21
+ RPCProcessRequest,
22
+ RPCResetPrefixCacheRequest,
23
+ RPCStartupRequest, RPCStartupResponse,
24
+ RPCUProfileRequest)
25
+ # yapf: enable
26
+ from vllm.logger import init_logger
27
+ from vllm.outputs import RequestOutput
28
+ from vllm.usage.usage_lib import UsageContext
29
+
30
+ logger = init_logger(__name__)
31
+
32
+ POLLING_TIMEOUT_MS = 10000
33
+ HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
34
+
35
+
36
+ class MQLLMEngine:
37
+ """A multiprocessing wrapper for :class:`LLMEngine`.
38
+
39
+ This class is used to wrap the :class:`LLMEngine` class to enable use
40
+ in concurrnet manner. It runs a background loop and uses zeromq to
41
+ receive new requests and stream outputs incrementally via ipc.
42
+
43
+ The :class:`LLMEngine` generate or encode process is kicked off when a new
44
+ RPCProcessRequest is received by the input_socket.
45
+
46
+ The self.engine_loop checks the input_socket for new requests,
47
+ adds them to the LLMEngine if there are any, calls the internal
48
+ :class:`LLMEngine.step()`, and sends the RequestOutputs back over
49
+ the output_socket.
50
+
51
+ If use_async_sockets is set, the logic associated with reading new
52
+ requests from the socket and sending data to the socket is passed
53
+ as a callback to the llm_engine, which calls the logic asynchronously
54
+ such that the IPC can be overlapped with the GPU.
55
+
56
+ Args:
57
+ ipc_path: Base path for zeromq interprocess messaging
58
+ use_async_sockets: Whether to make send/recv async with GPU
59
+ log_requests: Whether to log the requests.
60
+ *args: Arguments for :class:`LLMEngine`.
61
+ **kwargs: Arguments for :class:`LLMEngine`.
62
+ """
63
+
64
+ def __init__(self,
65
+ ipc_path: str,
66
+ use_async_sockets: bool,
67
+ *args,
68
+ log_requests: bool = True,
69
+ **kwargs) -> None:
70
+ # For MQLLMEngine, we can use cached outputs, since each new request
71
+ # output is immediately pickled and send over the socket, which frees
72
+ # the python object to be reused again.
73
+ kwargs['use_cached_outputs'] = True
74
+
75
+ self.engine = LLMEngine(*args, **kwargs)
76
+ self.log_requests = log_requests
77
+
78
+ self.use_async_sockets = use_async_sockets
79
+ if self.use_async_sockets:
80
+ self.engine.process_request_outputs_callback = \
81
+ self._async_socket_engine_callback
82
+
83
+ self.ctx = zmq.Context() # type: ignore[attr-defined]
84
+
85
+ # Receive input from the client.
86
+ self.input_socket = self.ctx.socket(zmq.constants.PULL)
87
+ self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")
88
+
89
+ # Send output stream back to client.
90
+ self.output_socket = self.ctx.socket(zmq.constants.PUSH)
91
+ self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
92
+
93
+ # Send heartbeats back to client.
94
+ self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
95
+ self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
96
+
97
+ # IPC path for the data socket.
98
+ self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
99
+
100
+ # Error state.
101
+ self._errored_with: Optional[BaseException] = None
102
+
103
+ @property
104
+ def dead_error(self) -> BaseException:
105
+ if self._errored_with is not None:
106
+ return ENGINE_DEAD_ERROR(self._errored_with)
107
+ else:
108
+ return ENGINE_DEAD_ERROR()
109
+
110
+ @classmethod
111
+ def from_engine_args(cls, engine_args: AsyncEngineArgs,
112
+ usage_context: UsageContext, ipc_path: str):
113
+ """Creates an MQLLMEngine from the engine arguments."""
114
+ # Setup plugins for each process
115
+ from vllm.plugins import load_general_plugins
116
+ load_general_plugins()
117
+
118
+ engine_config = engine_args.create_engine_config(usage_context)
119
+ executor_class = LLMEngine._get_executor_cls(engine_config)
120
+
121
+ use_async_sockets = engine_config.model_config.use_async_output_proc
122
+
123
+ return cls(ipc_path=ipc_path,
124
+ use_async_sockets=use_async_sockets,
125
+ vllm_config=engine_config,
126
+ executor_class=executor_class,
127
+ log_requests=not engine_args.disable_log_requests,
128
+ log_stats=not engine_args.disable_log_stats,
129
+ usage_context=usage_context)
130
+
131
+ def start(self):
132
+ try:
133
+ try:
134
+ logger.debug("Starting Startup Loop.")
135
+ self.run_startup_loop()
136
+ logger.debug("Starting Engine Loop.")
137
+ self.run_engine_loop()
138
+ except Exception as e:
139
+ logger.exception(repr(e))
140
+ except KeyboardInterrupt:
141
+ logger.debug("Shutting down MQLLMEngine.")
142
+ finally:
143
+ logger.debug("MQLLMEngine is shut down.")
144
+ self.cleanup()
145
+
146
+ def cleanup(self):
147
+ """Cleanup zeromq state on shutdown."""
148
+ # Closes all sockets and destroys context.
149
+ self.ctx.destroy(linger=0)
150
+ del self.engine
151
+
152
+ @contextmanager
153
+ def make_data_socket(
154
+ self) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
155
+ socket = self.ctx.socket(zmq.constants.ROUTER)
156
+ try:
157
+ socket.bind(self.data_ipc_path)
158
+ yield socket
159
+ finally:
160
+ socket.close(linger=0)
161
+
162
+ def run_startup_loop(self) -> None:
163
+ """Startup loop for sending data from Engine -> Client."""
164
+
165
+ with self.make_data_socket() as socket:
166
+ response: Union[RPCStartupResponse, BaseException]
167
+ try:
168
+ identity, message = socket.recv_multipart(copy=False)
169
+ request: RPCStartupRequest = pickle.loads(message.buffer)
170
+
171
+ # Handle the query from the Client.
172
+ if request == RPCStartupRequest.IS_SERVER_READY:
173
+ tracing_enabled = self.engine.is_tracing_enabled()
174
+ response = RPCStartupResponse(
175
+ tracing_enabled=tracing_enabled)
176
+
177
+ except Exception as e:
178
+ response = e
179
+
180
+ socket.send_multipart((identity, pickle.dumps(response)),
181
+ copy=False)
182
+
183
+ def run_engine_loop(self):
184
+ """Core busy loop of the LLMEngine."""
185
+
186
+ while True:
187
+ if not self.engine.has_unfinished_requests():
188
+ # Poll until there is work to do.
189
+ while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
190
+ # When there's no work, check on engine health and send
191
+ # health status back to client
192
+ self._health_check()
193
+ self.engine.do_log_stats()
194
+ logger.debug("Waiting for new requests in engine loop.")
195
+
196
+ # Handle any input from the client.
197
+ self.handle_new_input()
198
+
199
+ # Engine step.
200
+ request_outputs = self.engine_step()
201
+
202
+ # Send request outputs (if async, done in engine_step callback).
203
+ if not self.use_async_sockets:
204
+ self._send_outputs(request_outputs)
205
+
206
+ def engine_step(self) -> List[RequestOutput]:
207
+ """Engine step wrapper with error handling."""
208
+ try:
209
+ return self.engine.step()
210
+ except SystemExit:
211
+ raise
212
+ except BaseException as e:
213
+ self._set_errored(e)
214
+ rpc_err = RPCError(request_id=None,
215
+ is_engine_errored=True,
216
+ exception=e)
217
+ self._send_outputs(rpc_err)
218
+ raise e
219
+
220
+ def handle_new_input(self):
221
+ """Handle new input from the socket"""
222
+ try:
223
+ while self.input_socket.poll(timeout=0) != 0:
224
+ frames = self.input_socket.recv_multipart(copy=False)
225
+ request = pickle.loads(frames[0].buffer)
226
+
227
+ if isinstance(request, RPCProcessRequest):
228
+ if len(frames) > 1:
229
+ # Use cloudpickle for logits processors
230
+ assert isinstance(request.params, SamplingParams)
231
+ lprocs = cloudpickle.loads(frames[1].buffer)
232
+ request.params.logits_processors = lprocs
233
+ self._handle_process_request(request)
234
+ elif isinstance(request, RPCAbortRequest):
235
+ self._handle_abort_request(request)
236
+ elif isinstance(request, RPCUProfileRequest):
237
+ if request == RPCUProfileRequest.START_PROFILE:
238
+ self.start_profile()
239
+ else:
240
+ self.stop_profile()
241
+ elif isinstance(request, RPCLoadAdapterRequest):
242
+ self._handle_load_adapter_request(request)
243
+ elif isinstance(request, RPCResetPrefixCacheRequest):
244
+ self.reset_prefix_cache()
245
+ else:
246
+ raise ValueError("Unknown RPCRequest Type: "
247
+ f"{type(request)}")
248
+
249
+ except Exception as e:
250
+ self._set_errored(e)
251
+ self._send_unhealthy(e)
252
+ raise e
253
+
254
+ def _handle_process_request(self, request: RPCProcessRequest):
255
+ """Handle RPCProcessRequest by adding it to the LLMEngine."""
256
+ request_id = request.request_id
257
+
258
+ if self._errored_with is not None:
259
+ rpc_err = RPCError(request_id=request_id,
260
+ is_engine_errored=True,
261
+ exception=ENGINE_DEAD_ERROR(self._errored_with))
262
+ self._send_outputs(rpc_err)
263
+
264
+ try:
265
+ self.engine.add_request(
266
+ request_id=request_id,
267
+ prompt=request.prompt,
268
+ params=request.params,
269
+ lora_request=request.lora_request,
270
+ trace_headers=request.trace_headers,
271
+ prompt_adapter_request=request.prompt_adapter_request,
272
+ priority=request.priority)
273
+
274
+ if self.log_requests:
275
+ logger.info("Added request %s.", request.request_id)
276
+
277
+ except Exception as e:
278
+ # We do not set self._errored = True here, since the error
279
+ # is due to an issue adding this request to the engine,
280
+ # rather than an issue with the engine itself.
281
+ is_errored = self._errored_with is not None
282
+ rpc_err = RPCError(request_id=request_id,
283
+ is_engine_errored=is_errored,
284
+ exception=e)
285
+ self._send_outputs(rpc_err)
286
+
287
+ # Remove request from the engine.
288
+ self.engine.abort_request(request_id)
289
+
290
+ def _handle_abort_request(self, request: RPCAbortRequest):
291
+ self.engine.abort_request(request.request_id)
292
+ if self.log_requests:
293
+ logger.info("Aborted request %s.", request.request_id)
294
+
295
+ def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
296
+ try:
297
+ self.engine.add_lora(request.lora_request)
298
+ except BaseException as e:
299
+ # Send back an error if the adater fails to load
300
+ rpc_err = RPCError(request_id=request.request_id,
301
+ is_engine_errored=False,
302
+ exception=e)
303
+ self._send_outputs(rpc_err)
304
+ return
305
+ # Otherwise, send back the successful load message
306
+ self._send_outputs(
307
+ RPCAdapterLoadedResponse(request_id=request.request_id))
308
+
309
+ def _health_check(self):
310
+ # Send unhealthy if engine has already errored
311
+ if self._errored_with is not None:
312
+ self._send_unhealthy(self._errored_with)
313
+ try:
314
+ self.engine.check_health()
315
+ self._send_healthy()
316
+ except Exception as e:
317
+ self._set_errored(e)
318
+ self._send_unhealthy(e)
319
+
320
+ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
321
+ """Send outputs back to the engine client. These can be:
322
+ - Exceptions
323
+ - A list of generation outputs
324
+ - A response from loading a lora adapter
325
+ """
326
+ if outputs:
327
+ try:
328
+ from ray.exceptions import RayTaskError
329
+
330
+ # RayTaskError might not pickelable here. We need to unpack the
331
+ # underlying exception as the real exception in the output.
332
+ if (isinstance(outputs, RPCError)
333
+ and isinstance(outputs.exception, RayTaskError)):
334
+ outputs.exception = outputs.exception.cause
335
+ except ImportError:
336
+ pass
337
+
338
+ output_bytes = pickle.dumps(outputs)
339
+ self.output_socket.send_multipart((output_bytes, ), copy=False)
340
+
341
+ def _send_healthy(self):
342
+ """Send HEALTHY message to RPCClient."""
343
+ if not self.heartbeat_socket.closed:
344
+ self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
345
+
346
+ def _send_unhealthy(self, error: BaseException):
347
+ """Send UNHEALTHY message to RPCClient."""
348
+ if not self.heartbeat_socket.closed:
349
+ error_bytes = pickle.dumps(error)
350
+ self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
351
+
352
+ def _async_socket_engine_callback(self,
353
+ request_outputs: REQUEST_OUTPUTS_T):
354
+ """Callback used by engine to make socket handling async with GPU."""
355
+ self._send_outputs(request_outputs)
356
+ self.handle_new_input()
357
+
358
+ def _set_errored(self, e: BaseException):
359
+ """Log and set errored status if this is the first issue."""
360
+ if self._errored_with is None:
361
+ self._errored_with = e
362
+
363
+ def start_profile(self) -> None:
364
+ self.engine.start_profile()
365
+
366
+ def stop_profile(self) -> None:
367
+ self.engine.stop_profile()
368
+
369
+ def reset_prefix_cache(self) -> bool:
370
+ return self.engine.reset_prefix_cache()
371
+
372
+
373
+ def signal_handler(*_) -> None:
374
+ raise KeyboardInterrupt("MQLLMEngine terminated")
375
+
376
+
377
+ def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
378
+ ipc_path: str, engine_alive):
379
+ try:
380
+ engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
381
+ usage_context=usage_context,
382
+ ipc_path=ipc_path)
383
+
384
+ signal.signal(signal.SIGTERM, signal_handler)
385
+
386
+ engine.start()
387
+
388
+ except BaseException as e:
389
+ logger.exception(e)
390
+ engine_alive.value = False
391
+ raise e
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (201 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/interfaces.cpython-311.pyc ADDED
Binary file (3.78 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/multi_step.cpython-311.pyc ADDED
Binary file (10.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/single_step.cpython-311.pyc ADDED
Binary file (7.06 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/stop_checker.cpython-311.pyc ADDED
Binary file (5.06 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/util.cpython-311.pyc ADDED
Binary file (1.66 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/interfaces.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Callable, List
5
+
6
+ from vllm.config import SchedulerConfig
7
+ from vllm.core.scheduler import Scheduler
8
+ from vllm.engine.output_processor.stop_checker import StopChecker
9
+ from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
10
+ from vllm.transformers_utils.detokenizer import Detokenizer
11
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
12
+ from vllm.utils import Counter
13
+
14
+
15
+ class SequenceGroupOutputProcessor(ABC):
16
+ """Interface for logic that processes new token ids in sequence groups,
17
+ managing detokenization, stop checking, and freeing/forking sequences with
18
+ the scheduler.
19
+
20
+ This is highly coupled with the LLMEngine and should be seen as an extension
21
+ of it. The logic is separated to simplify the LLMEngine class and allow
22
+ separate implementations for single-step decoding (which supports beam
23
+ search sequence forking) and multi-step decoding (which does not support
24
+ beam search, but does support speculative decoding).
25
+ """
26
+
27
+ @staticmethod
28
+ def create_output_processor(
29
+ scheduler_config: SchedulerConfig,
30
+ detokenizer: Detokenizer,
31
+ scheduler: List[Scheduler],
32
+ seq_counter: Counter,
33
+ get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
34
+ stop_checker: "StopChecker",
35
+ ):
36
+ """Create an output processor.
37
+
38
+ This returns a single-step output processor if num_lookahead_slots is
39
+ zero, else returns a multi-step output processor.
40
+ """
41
+ if scheduler_config.num_lookahead_slots == 0:
42
+ # Importing here to avoid cycle.
43
+ from vllm.engine.output_processor.single_step import (
44
+ SingleStepOutputProcessor)
45
+ return SingleStepOutputProcessor(scheduler_config, detokenizer,
46
+ scheduler, seq_counter,
47
+ stop_checker)
48
+ else:
49
+ # Importing here to avoid cycle.
50
+ from vllm.engine.output_processor.multi_step import (
51
+ MultiStepOutputProcessor)
52
+ return MultiStepOutputProcessor(
53
+ detokenizer,
54
+ scheduler,
55
+ seq_counter,
56
+ get_tokenizer_for_seq,
57
+ stop_checker,
58
+ )
59
+
60
+ @abstractmethod
61
+ def process_outputs(self, sequence_group: SequenceGroup,
62
+ outputs: List[SequenceGroupOutput],
63
+ is_async: bool) -> None:
64
+ """Process new token ids for the sequence group. Handles logic such as
65
+ detokenization, stop checking, and freeing/forking sequences in the
66
+ scheduler.
67
+ """
68
+ pass
69
+
70
+ @abstractmethod
71
+ def process_prompt_logprob(self, seq_group: SequenceGroup,
72
+ outputs: List[SequenceGroupOutput]) -> None:
73
+ """Update prompt logprobs received from outputs to seq_group."""
74
+ pass
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/multi_step.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import functools
4
+ from typing import Callable, List, cast
5
+
6
+ from vllm.core.scheduler import Scheduler
7
+ from vllm.engine.output_processor.interfaces import (
8
+ SequenceGroupOutputProcessor)
9
+ from vllm.engine.output_processor.single_step import (
10
+ single_step_process_prompt_logprob)
11
+ from vllm.engine.output_processor.stop_checker import StopChecker
12
+ from vllm.logger import init_logger
13
+ from vllm.sampling_params import SamplingParams
14
+ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
15
+ CompletionSequenceGroupOutput, Sequence,
16
+ SequenceGroup, SequenceGroupOutput, SequenceOutput,
17
+ SequenceStatus)
18
+ from vllm.transformers_utils.detokenizer import Detokenizer
19
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
20
+ from vllm.utils import Counter
21
+
22
+ logger = init_logger(__name__)
23
+
24
+
25
+ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
26
+ """SequenceGroupOutputProcessor which handles logic related to
27
+ detokenization and stopping conditions. It specializes to "multi-step
28
+ decoding", where vLLM's worker may generate multiple tokens per invocation.
29
+ This is currently mutually exclusive with advanced sampling techniques like
30
+ beam search, which motivates the separation of this logic from the single
31
+ step output processor.
32
+
33
+ This class is responsible for things such as correctly appending all new
34
+ token ids to their sequence, detokenizing new token ids, truncating new
35
+ output tokens after an eos token, and correctly handling the case where the
36
+ number of new output tokens per sequence differs in a single batch.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ detokenizer: Detokenizer,
42
+ scheduler: List[Scheduler],
43
+ seq_counter: Counter,
44
+ get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
45
+ stop_checker: StopChecker,
46
+ ):
47
+ self.detokenizer = detokenizer
48
+ self.scheduler = scheduler
49
+ self.seq_counter = seq_counter
50
+ self.get_tokenizer_for_seq = get_tokenizer_for_seq
51
+ self.stop_checker = stop_checker
52
+
53
+ def process_prompt_logprob(self, seq_group: SequenceGroup,
54
+ outputs: List[SequenceGroupOutput]) -> None:
55
+ """Process prompt logprobs associated with each step of a multi-step-
56
+ scheduled computation.
57
+
58
+ Args:
59
+ seq_group: the outputs are associated with this :class:`SequenceGroup`
60
+ outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
61
+ """
62
+ for output in outputs:
63
+ # Concatenate single-step prompt logprob processing results.
64
+ assert isinstance(output, CompletionSequenceGroupOutput)
65
+ single_step_process_prompt_logprob(self, seq_group, output)
66
+
67
+ @staticmethod
68
+ @functools.lru_cache
69
+ def _log_prompt_logprob_unsupported_warning_once():
70
+ # Reminder: Please update docs/source/features/compatibility_matrix.md
71
+ # If the feature combo become valid
72
+ logger.warning(
73
+ "Prompt logprob is not supported by multi step workers. "
74
+ "(e.g., speculative decode uses multi step workers).")
75
+
76
+ def process_outputs(self,
77
+ sequence_group: SequenceGroup,
78
+ outputs: List[SequenceGroupOutput],
79
+ is_async: bool = False) -> None:
80
+ """Append new tokens in the outputs to sequences in the sequence group.
81
+
82
+ This only supports sequence groups of size 1. It supports greater than
83
+ one new token per sequence.
84
+
85
+ This applies logic like stop condition checking and detokenization.
86
+ It also handles cases where there are tokens emitted after
87
+ the EOS token.
88
+
89
+ is_async - Indicates whether this postprocessor runs in
90
+ parallel with the GPU forward pass and is processing
91
+ tokens from the previous step. If this is true, then
92
+ no tokens need to be appended since it is already done
93
+ externally (before the next schedule() call)
94
+ """
95
+ # Sequences can be in RUNNING or FINISHED_ABORTED state
96
+ # once scheduled, as a sequence is moved to FINSIHED_ABORTED
97
+ # if a client disconnects from the api server.
98
+ seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
99
+ if seqs is None:
100
+ seqs = sequence_group.get_seqs(
101
+ status=SequenceStatus.FINISHED_ABORTED)
102
+
103
+ assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
104
+ assert len(seqs) == 1, (
105
+ "Beam search not supported in multi-step decoding.")
106
+ seq = seqs[0]
107
+ seq_id = seq.seq_id
108
+ # This method is defined in the more generic
109
+ # SequenceGroupOutputProcessor, but here we assume that the outputs are
110
+ # of a more specific type.
111
+ assert all([
112
+ isinstance(output, CompletionSequenceGroupOutput)
113
+ for output in outputs
114
+ ])
115
+ compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
116
+ assert all([
117
+ seq_id == output.samples[0].parent_seq_id
118
+ for output in compl_outputs
119
+ ])
120
+
121
+ if is_async:
122
+ # Async case: We process tokens one by one. Here, we know the token
123
+ # was already appended, so we only need to do the rest of the
124
+ # postprocessor: Detokenization + stopping logic
125
+ self._process_decode_and_stop(seq, sequence_group.sampling_params)
126
+ else:
127
+ # Standard multi-step case
128
+
129
+ # Since there's only one sequence per sequence group,
130
+ # we can take the first sample.
131
+ samples = [output.samples[0] for output in compl_outputs]
132
+
133
+ # entries in sample tokens may be invalid (eg. due to spec decode
134
+ # rejecting tokens).
135
+ valid_samples = [
136
+ sample for sample in samples
137
+ if sample.output_token != VLLM_INVALID_TOKEN_ID
138
+ ]
139
+
140
+ # When both spec-decode and pre-fill chunking are enabled, we
141
+ # don't have guaranteed samples here (e.g. all -1s).
142
+ if valid_samples:
143
+ self._process_seq_outputs(seq, valid_samples,
144
+ sequence_group.sampling_params)
145
+
146
+ def _process_decode_and_stop(self, seq: Sequence,
147
+ sampling_params: SamplingParams) -> None:
148
+ new_char_count = 0
149
+ if sampling_params.detokenize and self.detokenizer:
150
+ new_char_count = self.detokenizer.decode_sequence_inplace(
151
+ seq, sampling_params)
152
+
153
+ # TODO(sang): Support lora.
154
+ self.stop_checker.maybe_stop_sequence(
155
+ seq,
156
+ new_char_count=new_char_count,
157
+ sampling_params=sampling_params,
158
+ )
159
+
160
+ def _process_seq_outputs(self, seq: Sequence,
161
+ valid_samples: List[SequenceOutput],
162
+ sampling_params: SamplingParams) -> None:
163
+ output_token_ids = [sample.output_token for sample in valid_samples]
164
+ output_logprobs = [sample.logprobs for sample in valid_samples]
165
+
166
+ # Truncate to max_tokens if necessary.
167
+ remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
168
+ len(output_token_ids))
169
+ if remaining_tokens < 0:
170
+ output_token_ids = output_token_ids[:remaining_tokens]
171
+
172
+ # Truncate any tokens after EOS. This is required as spec decode
173
+ # generates a fixed number of tokens without evaluating stopping
174
+ # conditions within the block. This can cause an eos token to be
175
+ # unintentionally ignored.
176
+ if not sampling_params.ignore_eos:
177
+ eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
178
+ # Avoiding .index calls as exception throwing in the happy path
179
+ # is expensive.
180
+ for i in range(len(output_token_ids)):
181
+ if output_token_ids[i] == eos_token_id:
182
+ output_token_ids = output_token_ids[:i + 1]
183
+ break
184
+
185
+ is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
186
+ # Incrementally append tokens to the sequence, as if we had only one new
187
+ # token.
188
+ for output_token_id, output_logprob in zip(output_token_ids,
189
+ output_logprobs):
190
+ seq.append_token_id(
191
+ token_id=output_token_id,
192
+ logprobs=output_logprob,
193
+ )
194
+
195
+ if is_prefill_sampled_token:
196
+ is_prefill_sampled_token = False
197
+ else:
198
+ # Update num_computed_tokens iff the sampled token is not from
199
+ # a prefill step.
200
+ seq.data.update_num_computed_tokens(1)
201
+
202
+ self._process_decode_and_stop(seq, sampling_params)
203
+
204
+ if seq.is_finished():
205
+ break
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/single_step.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import List
4
+
5
+ from vllm.config import SchedulerConfig
6
+ from vllm.core.scheduler import Scheduler
7
+ from vllm.engine.output_processor.interfaces import (
8
+ SequenceGroupOutputProcessor)
9
+ from vllm.engine.output_processor.stop_checker import StopChecker
10
+ from vllm.logger import init_logger
11
+ from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup,
12
+ SequenceGroupOutput)
13
+ from vllm.transformers_utils.detokenizer import Detokenizer
14
+ from vllm.utils import Counter
15
+
16
+ logger = init_logger(__name__)
17
+
18
+
19
+ def single_step_process_prompt_logprob(
20
+ sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
21
+ output: CompletionSequenceGroupOutput) -> None:
22
+ """Process prompt logprobs associated with the :class:`SequenceGroupOutput`
23
+ for a given step.
24
+
25
+ Do nothing if the output has no prompt logprobs.
26
+
27
+ Account for the fact that transformers do not compute first-token logprobs.
28
+
29
+ Args:
30
+ sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
31
+ seq_group: the output is associated with this :class:`SequenceGroup`
32
+ output: the :class:`SequenceGroupOutput` for a single scheduler step
33
+ """
34
+ prompt_logprobs = output.prompt_logprobs
35
+
36
+ # If this is the first (or only) "chunk" of the prefill, we need
37
+ # to prepend None to the list of prompt logprobs. The reason for this
38
+ # is that for N prompt tokens, the Sampler will generate N-1 total
39
+ # prompt logprobs during prefill since the token at idx 0 will not
40
+ # have a logprob associated with it.
41
+ if prompt_logprobs is not None:
42
+ if not seq_group.prompt_logprobs:
43
+ prompt_logprobs = [None] + prompt_logprobs
44
+ seq_group.prompt_logprobs = []
45
+
46
+ assert hasattr(sg_output_proc, 'detokenizer')
47
+ if (seq_group.sampling_params.detokenize
48
+ and sg_output_proc.detokenizer):
49
+ sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
50
+ seq_group,
51
+ prompt_logprobs,
52
+ position_offset=len(seq_group.prompt_logprobs))
53
+
54
+ seq_group.prompt_logprobs.extend(prompt_logprobs)
55
+
56
+
57
+ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
58
+ """SequenceGroupOutputProcessor which handles "output processing" logic,
59
+ which happens after the model returns generated token ids and before
60
+ scheduling of the next batch. Output processing logic includes
61
+ detokenization, and determining if a sequence is finished (e.g. via max len
62
+ or eos token).
63
+
64
+ The SingleStepOutputProcessor is specialized to the case where the model
65
+ emits at most a single token per invocation, which precludes configurations
66
+ such as speculative decoding or multi-step decoding. This enables beam
67
+ search sampling, which requires forking/finishing/freeing sequences in a way
68
+ that is currently difficult to schedule multiple steps ahead of time.
69
+ """
70
+
71
+ def __init__(self, scheduler_config: SchedulerConfig,
72
+ detokenizer: Detokenizer, scheduler: List[Scheduler],
73
+ seq_counter: Counter, stop_checker: StopChecker):
74
+ self.scheduler_config = scheduler_config
75
+ self.detokenizer = detokenizer
76
+ self.scheduler = scheduler
77
+ self.seq_counter = seq_counter
78
+ self.stop_checker = stop_checker
79
+
80
+ def process_outputs(self, sequence_group: SequenceGroup,
81
+ outputs: List[SequenceGroupOutput],
82
+ is_async: bool) -> None:
83
+ """Append all new tokens to sequences in the sequence group. Fork any
84
+ surviving beam candidates; free any unsurviving ones.
85
+
86
+ Invokes detokenizer to detokenize new tokens, and also marks sequences
87
+ as finished if they meet stop conditions.
88
+
89
+ is_async - Indicates whether this postprocessor runs in
90
+ parallel with the GPU forward pass and is processing
91
+ tokens from the previous step. If this is true, then
92
+ no tokens need to be appended since it is already done
93
+ externally (before the next schedule() call)
94
+ """
95
+ assert (len(outputs) == 1
96
+ ), f"{type(self)} does not support multiple outputs per step"
97
+ return self._process_sequence_group_outputs(sequence_group, outputs[0],
98
+ is_async)
99
+
100
+ def process_prompt_logprob(self, seq_group: SequenceGroup,
101
+ outputs: List[SequenceGroupOutput]) -> None:
102
+ """Process prompt logprobs associated with one step of a single-step-
103
+ scheduled computation.
104
+
105
+ Args:
106
+ seq_group: the output is associated with this :class:`SequenceGroup`
107
+ outputs: the :class:`SequenceGroupOutput` for a single scheduler step
108
+ """
109
+ assert len(outputs) == 1, "Single step should only have 1 output."
110
+ output = outputs[0]
111
+ assert isinstance(output, CompletionSequenceGroupOutput)
112
+ single_step_process_prompt_logprob(self, seq_group, output)
113
+
114
+ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
115
+ outputs: SequenceGroupOutput,
116
+ is_async: bool) -> None:
117
+ sampling_params = seq_group.sampling_params
118
+
119
+ sample = outputs.samples[0]
120
+ seq = seq_group.first_seq
121
+ if not is_async:
122
+ seq.append_token_id(sample.output_token, sample.logprobs)
123
+ if sampling_params.detokenize and self.detokenizer:
124
+ new_char_count = self.detokenizer.decode_sequence_inplace(
125
+ seq, sampling_params)
126
+ else:
127
+ new_char_count = 0
128
+ self.stop_checker.maybe_stop_sequence(
129
+ seq,
130
+ new_char_count,
131
+ sampling_params,
132
+ lora_req=seq_group.lora_request,
133
+ )
134
+ if seq.is_finished():
135
+ for scheduler in self.scheduler:
136
+ scheduler.free_seq(seq)
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/stop_checker.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Callable, List, Optional, Tuple
4
+
5
+ from vllm.lora.request import LoRARequest
6
+ from vllm.sampling_params import SamplingParams
7
+ from vllm.sequence import Sequence, SequenceStatus
8
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
9
+
10
+
11
class StopChecker:
    """Per-step finish checks for a sequence, factored out of LLMEngine.

    Checks, in order: minimum-token floor, EOS emission, explicit stop
    token ids, stop strings, the model context-length cap, and the
    ``max_tokens`` cap.
    """

    def __init__(self, max_model_len: int,
                 get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
        # Kept private: go through `_get_max_model_len`, which honors
        # long-context LoRA overrides.
        self._max_model_len = max_model_len
        self.get_tokenizer_for_seq = get_tokenizer_for_seq

    def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
        """Effective max model length, preferring a long-LoRA override."""
        if lora_req and lora_req.long_lora_max_len:
            return lora_req.long_lora_max_len
        return self._max_model_len

    def maybe_stop_sequence(
        self,
        seq: Sequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Mark ``seq`` as finished if any stop condition holds.

        ``new_char_count`` is the number of characters the newest token
        added to the sequence's output text; it is used to strip the stop
        token/string text when ``include_stop_str_in_output`` is False.
        """
        # Honor the min_tokens floor before any stop-string/token check.
        if seq.get_output_len() < sampling_params.min_tokens:
            return

        strip_stop_text = not sampling_params.include_stop_str_in_output

        # EOS emitted?
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id() == seq.eos_token_id):
            # Hide the EOS token's text unless explicitly requested, to
            # avoid unintended exposure of the EOS token.
            if new_char_count and strip_stop_text:
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Explicit stop token? (assumes a single token produced per step)
        last_token_id = seq.get_last_token_id()
        if last_token_id in (sampling_params.stop_token_ids or ()):
            if new_char_count and strip_stop_text:
                # Drop the stop token's text from the output.
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Stop string matched?
        stop = self.check_stop_strings(
            seq.output_text, new_char_count, sampling_params.stop,
            sampling_params.include_stop_str_in_output)
        if stop is not None:
            stop_str, truncate_to = stop
            if truncate_to != -1:
                seq.output_text = seq.output_text[:truncate_to]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Context-length cap (prompt + generated tokens).
        if seq.get_len() > self._get_max_model_len(lora_req):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Output-length cap.
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

    @staticmethod
    def check_stop_strings(
        output_text: str,
        new_char_count: int,
        stop: List[str],
        include_in_output: bool,
    ) -> Optional[Tuple[str, int]]:
        """Return ``(matched_stop_string, offset)`` or None.

        ``offset`` is the length ``output_text`` should be truncated to,
        or -1 when no truncation is required. Only the window that could
        contain a match involving the newest ``new_char_count`` characters
        is searched.
        """
        if not new_char_count or not stop:
            return None

        for candidate in stop:
            candidate_len = len(candidate)
            # Negative start index: skip text that was already searched on
            # previous steps.
            match_at = output_text.find(candidate,
                                        -new_char_count - candidate_len)
            if match_at == -1:
                continue

            if include_in_output:
                # Keep the stop string: truncate just past its end.
                match_at += candidate_len
                if match_at >= len(output_text):
                    # Match ends at the end of the text: nothing to cut.
                    return candidate, -1

            # Truncate to the beginning (or, above, the end) of the match.
            return candidate, match_at
        return None
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/util.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import List
4
+ from typing import Sequence as GenericSequence
5
+ from typing import cast
6
+
7
+ from vllm.model_executor.layers.sampler import SamplerOutput
8
+ from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput
9
+
10
+
11
+ def create_output_by_sequence_group(
12
+ outputs: GenericSequence[SamplerOutput],
13
+ num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
14
+ """Helper method which transforms a 2d list organized by
15
+ [step][sequence group] into [sequence group][step].
16
+ """
17
+ output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [
18
+ [] for _ in range(num_seq_groups)
19
+ ]
20
+ for step in outputs:
21
+ sequence_group_output: CompletionSequenceGroupOutput
22
+ for i, sequence_group_output in enumerate(step):
23
+ output_by_sequence_group[i].append(sequence_group_output)
24
+
25
+ # Cast to the more generic type that CompletionSequenceGroupOutput
26
+ # inherits from.
27
+ return cast(List[List[SequenceGroupOutput]], output_by_sequence_group)
.venv/lib/python3.11/site-packages/vllm/engine/protocol.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import asyncio
4
+ from abc import ABC, abstractmethod
5
+ from typing import AsyncGenerator, List, Mapping, Optional
6
+
7
+ from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
8
+ from vllm.config import DecodingConfig, ModelConfig
9
+ from vllm.core.scheduler import SchedulerOutputs
10
+ from vllm.inputs.data import PromptType, TokensPrompt
11
+ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
12
+ from vllm.inputs.preprocess import InputPreprocessor
13
+ from vllm.logger import init_logger
14
+ from vllm.lora.request import LoRARequest
15
+ from vllm.model_executor.layers.sampler import SamplerOutput
16
+ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
17
+ from vllm.pooling_params import PoolingParams
18
+ from vllm.prompt_adapter.request import PromptAdapterRequest
19
+ from vllm.sampling_params import BeamSearchParams, SamplingParams
20
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
21
+ from vllm.utils import collect_from_async_generator, random_uuid
22
+
23
+ logger = init_logger(__name__)
24
+
25
+
26
+ class EngineClient(ABC):
27
+ """Protocol class for Clients to Engine"""
28
+
29
+ @property
30
+ @abstractmethod
31
+ def is_running(self) -> bool:
32
+ ...
33
+
34
+ @property
35
+ @abstractmethod
36
+ def is_stopped(self) -> bool:
37
+ ...
38
+
39
+ @property
40
+ @abstractmethod
41
+ def errored(self) -> bool:
42
+ ...
43
+
44
+ @property
45
+ @abstractmethod
46
+ def dead_error(self) -> BaseException:
47
+ ...
48
+
49
+ @abstractmethod
50
+ def generate(
51
+ self,
52
+ prompt: PromptType,
53
+ sampling_params: SamplingParams,
54
+ request_id: str,
55
+ lora_request: Optional[LoRARequest] = None,
56
+ trace_headers: Optional[Mapping[str, str]] = None,
57
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
58
+ priority: int = 0,
59
+ ) -> AsyncGenerator[RequestOutput, None]:
60
+ """Generate outputs for a request."""
61
+ ...
62
+
63
+ async def beam_search(
64
+ self,
65
+ prompt: PromptType,
66
+ request_id: str,
67
+ params: BeamSearchParams,
68
+ ) -> AsyncGenerator[RequestOutput, None]:
69
+
70
+ beam_width = params.beam_width
71
+ max_tokens = params.max_tokens
72
+ ignore_eos = params.ignore_eos
73
+ temperature = params.temperature
74
+ length_penalty = params.length_penalty
75
+ include_stop_str_in_output = params.include_stop_str_in_output
76
+
77
+ preprocessor = await self.get_input_preprocessor()
78
+ tokenizer_group = preprocessor.get_tokenizer_group()
79
+ tokenizer = await tokenizer_group.get_lora_tokenizer_async()
80
+
81
+ if is_explicit_encoder_decoder_prompt(prompt):
82
+ raise NotImplementedError
83
+ else:
84
+ processed_inputs = preprocessor._prompt_to_llm_inputs(
85
+ prompt,
86
+ request_id=request_id,
87
+ )
88
+
89
+ prompt_token_ids = processed_inputs["prompt_token_ids"]
90
+ prompt_text = processed_inputs.get("prompt")
91
+ multi_modal_data = processed_inputs.get("multi_modal_data")
92
+ mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
93
+
94
+ tokenized_length = len(prompt_token_ids)
95
+
96
+ sort_beams_key = create_sort_beams_key_function(
97
+ tokenizer.eos_token_id, length_penalty)
98
+
99
+ beam_search_params = SamplingParams(
100
+ logprobs=2 * beam_width,
101
+ max_tokens=1,
102
+ temperature=temperature,
103
+ )
104
+ all_beams = [
105
+ BeamSearchSequence(tokens=prompt_token_ids,
106
+ cum_logprob=0,
107
+ logprobs=[],
108
+ multi_modal_data=multi_modal_data,
109
+ mm_processor_kwargs=mm_processor_kwargs)
110
+ ]
111
+ completed = []
112
+
113
+ for _ in range(max_tokens):
114
+ prompts_batch = [
115
+ TokensPrompt(prompt_token_ids=beam.tokens,
116
+ multi_modal_data=beam.multi_modal_data,
117
+ mm_processor_kwargs=beam.mm_processor_kwargs)
118
+ for beam in all_beams
119
+ ]
120
+
121
+ tasks = []
122
+
123
+ request_id = f"beam_search-{random_uuid()}"
124
+ for i, individual_prompt in enumerate(prompts_batch):
125
+ request_id_item = f"{request_id}-{i}"
126
+ task = asyncio.create_task(
127
+ collect_from_async_generator(
128
+ self.generate(individual_prompt, beam_search_params,
129
+ request_id_item)))
130
+ tasks.append(task)
131
+
132
+ output = await asyncio.gather(*tasks)
133
+
134
+ output = [x[0] for x in output]
135
+
136
+ new_beams = []
137
+ for i, current_beam in enumerate(all_beams):
138
+ result = output[i]
139
+
140
+ if result.outputs[0].logprobs is not None:
141
+ logprobs = result.outputs[0].logprobs[0]
142
+ for token_id, logprob_obj in logprobs.items():
143
+ if token_id == tokenizer.eos_token_id and \
144
+ not ignore_eos:
145
+ completed.append(
146
+ BeamSearchSequence(
147
+ tokens=current_beam.tokens +
148
+ [token_id] if include_stop_str_in_output
149
+ else current_beam.tokens,
150
+ logprobs=current_beam.logprobs +
151
+ [logprobs],
152
+ cum_logprob=current_beam.cum_logprob +
153
+ logprob_obj.logprob,
154
+ finish_reason="stop",
155
+ stop_reason=tokenizer.eos_token_id))
156
+ else:
157
+ new_beams.append(
158
+ BeamSearchSequence(
159
+ tokens=current_beam.tokens + [token_id],
160
+ logprobs=current_beam.logprobs +
161
+ [logprobs],
162
+ cum_logprob=current_beam.cum_logprob +
163
+ logprob_obj.logprob,
164
+ multi_modal_data=current_beam.
165
+ multi_modal_data,
166
+ mm_processor_kwargs=current_beam.
167
+ mm_processor_kwargs))
168
+
169
+ sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
170
+ all_beams = sorted_beams[:beam_width]
171
+
172
+ completed.extend(all_beams)
173
+ sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
174
+ best_beams = sorted_completed[:beam_width]
175
+
176
+ for beam in best_beams:
177
+ if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos):
178
+ # Skip the eos token in the text.
179
+ tokens = beam.tokens[tokenized_length:-1]
180
+ else:
181
+ tokens = beam.tokens[tokenized_length:]
182
+ beam.text = tokenizer.decode(tokens)
183
+
184
+ beam_search_output = RequestOutput(
185
+ request_id=request_id,
186
+ prompt=prompt_text,
187
+ outputs=[
188
+ CompletionOutput(text=beam.text,
189
+ cumulative_logprob=beam.cum_logprob,
190
+ token_ids=beam.tokens[tokenized_length:],
191
+ index=i,
192
+ logprobs=beam.logprobs,
193
+ finish_reason=beam.finish_reason if
194
+ beam.finish_reason is not None else "length",
195
+ stop_reason=beam.stop_reason)
196
+ for (i, beam) in enumerate(best_beams)
197
+ ],
198
+ finished=True,
199
+ prompt_token_ids=prompt_token_ids,
200
+ prompt_logprobs=None)
201
+
202
+ yield beam_search_output
203
+
204
+ @abstractmethod
205
+ def encode(
206
+ self,
207
+ prompt: PromptType,
208
+ pooling_params: PoolingParams,
209
+ request_id: str,
210
+ lora_request: Optional[LoRARequest] = None,
211
+ trace_headers: Optional[Mapping[str, str]] = None,
212
+ priority: int = 0,
213
+ ) -> AsyncGenerator[PoolingRequestOutput, None]:
214
+ """Generate outputs for a request from a pooling model."""
215
+ ...
216
+
217
+ @abstractmethod
218
+ async def abort(self, request_id: str) -> None:
219
+ """Abort a request.
220
+
221
+ Args:
222
+ request_id: The unique id of the request.
223
+ """
224
+ ...
225
+
226
+ @abstractmethod
227
+ async def get_model_config(self) -> ModelConfig:
228
+ """Get the model configuration of the vLLM engine."""
229
+ ...
230
+
231
+ @abstractmethod
232
+ async def get_decoding_config(self) -> DecodingConfig:
233
+ """Get the decoding configuration of the vLLM engine."""
234
+ ...
235
+
236
+ @abstractmethod
237
+ async def get_input_preprocessor(self) -> InputPreprocessor:
238
+ """Get the input processor of the vLLM engine."""
239
+ ...
240
+
241
+ @abstractmethod
242
+ async def get_tokenizer(
243
+ self,
244
+ lora_request: Optional[LoRARequest] = None,
245
+ ) -> AnyTokenizer:
246
+ """Get the appropriate tokenizer for the request"""
247
+ ...
248
+
249
+ @abstractmethod
250
+ async def is_tracing_enabled(self) -> bool:
251
+ ...
252
+
253
+ @abstractmethod
254
+ async def do_log_stats(
255
+ self,
256
+ scheduler_outputs: Optional[SchedulerOutputs] = None,
257
+ model_output: Optional[List[SamplerOutput]] = None,
258
+ ) -> None:
259
+ ...
260
+
261
+ @abstractmethod
262
+ async def check_health(self) -> None:
263
+ """Raise if unhealthy"""
264
+ ...
265
+
266
+ @abstractmethod
267
+ async def start_profile(self) -> None:
268
+ """Start profiling the engine"""
269
+ ...
270
+
271
+ @abstractmethod
272
+ async def stop_profile(self) -> None:
273
+ """Start profiling the engine"""
274
+ ...
275
+
276
+ @abstractmethod
277
+ async def reset_prefix_cache(self) -> None:
278
+ """Reset the prefix cache"""
279
+ ...
280
+
281
+ @abstractmethod
282
+ async def add_lora(self, lora_request: LoRARequest) -> None:
283
+ """Load a new LoRA adapter into the engine for future requests."""
284
+ ...
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
4
+ from vllm.transformers_utils.configs.cohere2 import Cohere2Config
5
+ from vllm.transformers_utils.configs.dbrx import DbrxConfig
6
+ from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
7
+ from vllm.transformers_utils.configs.eagle import EAGLEConfig
8
+ from vllm.transformers_utils.configs.exaone import ExaoneConfig
9
+ # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
10
+ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
11
+ # `FalconConfig` class from the official HuggingFace transformers library.
12
+ from vllm.transformers_utils.configs.falcon import RWConfig
13
+ from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig
14
+ from vllm.transformers_utils.configs.internvl import InternVLChatConfig
15
+ from vllm.transformers_utils.configs.jais import JAISConfig
16
+ from vllm.transformers_utils.configs.medusa import MedusaConfig
17
+ from vllm.transformers_utils.configs.mllama import MllamaConfig
18
+ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
19
+ from vllm.transformers_utils.configs.mpt import MPTConfig
20
+ from vllm.transformers_utils.configs.nemotron import NemotronConfig
21
+ from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
22
+ from vllm.transformers_utils.configs.olmo2 import Olmo2Config
23
+ from vllm.transformers_utils.configs.solar import SolarConfig
24
+ from vllm.transformers_utils.configs.telechat2 import Telechat2Config
25
+ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
26
+
27
+ __all__ = [
28
+ "ChatGLMConfig",
29
+ "Cohere2Config",
30
+ "DbrxConfig",
31
+ "DeepseekVLV2Config",
32
+ "MPTConfig",
33
+ "RWConfig",
34
+ "H2OVLChatConfig",
35
+ "InternVLChatConfig",
36
+ "JAISConfig",
37
+ "MedusaConfig",
38
+ "EAGLEConfig",
39
+ "ExaoneConfig",
40
+ "MllamaConfig",
41
+ "MLPSpeculatorConfig",
42
+ "NemotronConfig",
43
+ "NVLM_D_Config",
44
+ "Olmo2Config",
45
+ "SolarConfig",
46
+ "Telechat2Config",
47
+ "UltravoxConfig",
48
+ ]
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.19 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/arctic.cpython-311.pyc ADDED
Binary file (10 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/chatglm.cpython-311.pyc ADDED
Binary file (2.47 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/cohere2.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/dbrx.cpython-311.pyc ADDED
Binary file (12.5 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/deepseek_vl2.cpython-311.pyc ADDED
Binary file (8.29 kB). View file