Add files using upload-large-folder tool
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/vllm/compilation/backends.py +874 -0
- .venv/lib/python3.11/site-packages/vllm/compilation/counter.py +33 -0
- .venv/lib/python3.11/site-packages/vllm/compilation/decorators.py +249 -0
- .venv/lib/python3.11/site-packages/vllm/compilation/fix_functionalization.py +182 -0
- .venv/lib/python3.11/site-packages/vllm/compilation/fusion.py +617 -0
- .venv/lib/python3.11/site-packages/vllm/compilation/fx_utils.py +44 -0
- .venv/lib/python3.11/site-packages/vllm/compilation/monitor.py +38 -0
- .venv/lib/python3.11/site-packages/vllm/compilation/pass_manager.py +79 -0
- .venv/lib/python3.11/site-packages/vllm/compilation/wrapper.py +129 -0
- .venv/lib/python3.11/site-packages/vllm/engine/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/arg_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/async_llm_engine.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/async_timeout.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/llm_engine.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/metrics.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/metrics_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/__pycache__/protocol.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/arg_utils.py +1360 -0
- .venv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py +1198 -0
- .venv/lib/python3.11/site-packages/vllm/engine/async_timeout.py +191 -0
- .venv/lib/python3.11/site-packages/vllm/engine/llm_engine.py +2025 -0
- .venv/lib/python3.11/site-packages/vllm/engine/metrics.py +681 -0
- .venv/lib/python3.11/site-packages/vllm/engine/metrics_types.py +102 -0
- .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__init__.py +159 -0
- .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/client.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/engine.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/client.py +707 -0
- .venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/engine.py +391 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/interfaces.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/multi_step.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/single_step.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/stop_checker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/util.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/interfaces.py +74 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/multi_step.py +205 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/single_step.py +136 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/stop_checker.py +130 -0
- .venv/lib/python3.11/site-packages/vllm/engine/output_processor/util.py +27 -0
- .venv/lib/python3.11/site-packages/vllm/engine/protocol.py +284 -0
- .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__init__.py +48 -0
- .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/arctic.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/chatglm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/cohere2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/dbrx.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/deepseek_vl2.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/vllm/compilation/backends.py
ADDED
|
@@ -0,0 +1,874 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import ast
|
| 4 |
+
import copy
|
| 5 |
+
import dataclasses
|
| 6 |
+
import os
|
| 7 |
+
import pprint
|
| 8 |
+
import time
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
from contextlib import ExitStack
|
| 11 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
|
| 12 |
+
from unittest.mock import patch
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.fx as fx
|
| 16 |
+
|
| 17 |
+
import vllm.envs as envs
|
| 18 |
+
from vllm.config import CompilationConfig, VllmConfig
|
| 19 |
+
from vllm.logger import init_logger
|
| 20 |
+
from vllm.utils import weak_ref_tensors
|
| 21 |
+
|
| 22 |
+
from .counter import compilation_counter
|
| 23 |
+
from .inductor_pass import InductorPass
|
| 24 |
+
from .monitor import end_monitoring_torch_compile
|
| 25 |
+
from .pass_manager import PostGradPassManager
|
| 26 |
+
|
| 27 |
+
logger = init_logger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclasses.dataclass
class InductorArtifact:
    """One compiled-graph record stored in :class:`InductorHashCache`."""
    # Inductor cache key of the compiled graph; used to look the graph up
    # again via FxGraphCache on later runs.
    hash_str: str = ""
    # Path of the generated source file backing the compiled callable;
    # filled in lazily once the compiled graph is available (may be "" for
    # entries written by older versions of the cache format).
    file_path: str = ""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class InductorHashCache:
    """
    Maps (runtime_shape, graph_index) -> InductorArtifact so that compiled
    graphs can be looked up directly from Inductor's code cache on later
    runs.

    Disk format: a Python list of tuples, each tuple is
    (runtime_shape, graph_index, hash_str, file_path)
    We use a list of tuples for readability.

    In-memory format: a defaultdict of dicts, keyed first by runtime_shape
    and then by graph_index.

    The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
    we don't use json here because json doesn't support int as key.

    TODO: better off-the-shelf solution to serialize the data?
    """

    def __init__(self, cache_dir: str, disabled: bool = False):
        self.cache: Dict[Optional[int],
                         Dict[int, InductorArtifact]] = defaultdict(dict)
        self.disabled = disabled
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir,
                                            "inductor_hash_cache.py")
        if disabled:
            return
        # Point Inductor and Triton at sub-directories of cache_dir, so the
        # whole cache directory can be copied to another machine and reused.
        for env_var, subdir in (("TORCHINDUCTOR_CACHE_DIR", "inductor_cache"),
                                ("TRITON_CACHE_DIR", "triton_cache")):
            subdir_path = os.path.join(cache_dir, subdir)
            os.makedirs(subdir_path, exist_ok=True)
            os.environ[env_var] = subdir_path
        if os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                self.deserialize(f.read())

    def deserialize(self, data: str):
        # `ast.literal_eval` safely parses Python literals;
        # never use eval() on data read from disk.
        for entry in ast.literal_eval(data):
            shape, index, hash_str = entry[0], entry[1], entry[2]
            # Older cache files have no file_path column; the path is
            # refreshed the next time the artifact is produced.
            path = entry[3] if len(entry) != 3 else ""
            self.cache[shape][index] = InductorArtifact(hash_str=hash_str,
                                                        file_path=path)

    def serialize(self) -> str:
        rows = [(shape, index, artifact.hash_str, artifact.file_path)
                for shape, per_graph in self.cache.items()
                for index, artifact in per_graph.items()]
        return pprint.PrettyPrinter(indent=4).pformat(rows)

    def save_to_file(self):
        if self.disabled:
            return
        with open(self.cache_file_path, "w") as f:
            f.write(self.serialize())

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        if self.disabled:
            return False
        shape, index = key
        per_graph = self.cache.get(shape)
        return per_graph is not None and index in per_graph

    def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
        if self.disabled:
            raise KeyError("cannot read from disabled cache")
        shape, index = key
        return self.cache[shape][index]

    def __setitem__(self, key: Tuple[Optional[int], int],
                    value: InductorArtifact):
        # Writing into a disabled cache is harmless: nothing reaches disk.
        shape, index = key
        self.cache[shape][index] = value
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class AlwaysHitShapeEnv:
    """A dummy shape environment whose guard checks always succeed.

    Why this class exists: a normal `torch.compile` run pairs one Dynamo
    bytecode compilation with one Inductor compilation, and Inductor reads
    dynamic-shape information from the surrounding Dynamo context.  Here,
    Dynamo runs once while Inductor is invoked repeatedly (per shape plus a
    general shape) *outside* that context, so no real shape environment is
    available and Inductor's code-cache lookup would fail.

    Substituting this always-hitting stand-in makes the cache lookup
    succeed so each shape can be compiled on demand.  The method set below
    was discovered by trial and error until lookups worked.
    """

    def __init__(self) -> None:
        self.guards: List[Any] = []

    def evaluate_guards_expression(self, *args, **kwargs):
        # Every guard check passes unconditionally.
        return True

    def get_pruned_guards(self, *args, **kwargs):
        # Nothing to prune: there are never any guards.
        return []

    def produce_guards_expression(self, *args, **kwargs):
        # An empty guard expression, i.e. no constraints at all.
        return ""
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def wrap_inductor(graph: fx.GraphModule,
                  example_inputs,
                  additional_inductor_config,
                  compilation_config: CompilationConfig,
                  vllm_backend: "VllmBackend",
                  graph_index: int = 0,
                  num_graphs: int = 1,
                  runtime_shape: Optional[int] = None,
                  use_inductor: bool = True) -> Any:
    """Compile one piecewise FX graph with Inductor, or return it untouched
    when `use_inductor` is False.

    Consults/fills `vllm_backend.inductor_hash_cache` so recompiles can be
    served from Inductor's on-disk code cache.  `graph_index`/`num_graphs`
    drive timing bookkeeping: index 0 records the start time, and the last
    index logs the elapsed time and accumulates it into
    `compilation_config.compilation_time`.
    """
    if graph_index == 0:
        # before compiling the first graph, record the start time
        global compilation_start_time
        compilation_start_time = time.time()

    if not use_inductor:
        return graph

    compilation_counter.num_inductor_compilations += 1

    from torch._inductor import config
    current_config = config.get_config_copy()
    from torch._inductor.compile_fx import compile_fx

    if additional_inductor_config is not None:
        current_config.update(additional_inductor_config)

    if isinstance(runtime_shape, int):
        # for a specific batchsize, tuning triton kernel parameters
        # can be beneficial
        current_config["max_autotune"] = True
        current_config["coordinate_descent_tuning"] = True

    # inductor can inplace modify the graph, so we need to copy it
    # see https://github.com/pytorch/pytorch/issues/138980
    graph = copy.deepcopy(graph)

    cache_data = vllm_backend.inductor_hash_cache
    if (runtime_shape, graph_index) in cache_data:
        # we compiled this graph before,
        # so we can directly lookup the compiled graph via hash
        inductor_artifact = cache_data[(runtime_shape, graph_index)]
        hash_str = inductor_artifact.hash_str
        if graph_index == 0:
            # adds some info logging for the first graph
            logger.info(
                "Directly lookup the graph for shape %s from the cache",
                str(runtime_shape))  # noqa
        logger.debug(
            "directly lookup the %s-th graph for shape %s via hash %s",
            graph_index, str(runtime_shape), hash_str)
        from torch._inductor.codecache import FxGraphCache
        # FxGraphCache normally needs a shape environment from the Dynamo
        # tracing context; we are outside it, so substitute the dummy
        # always-hitting one.
        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
            inductor_compiled_graph = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, False)
            assert inductor_compiled_graph is not None, (
                "Inductor cache lookup failed. Please remove"
                f"the cache file {cache_data.cache_file_path} and try again."  # noqa
            )
            # refresh the artifact's source-file path from the freshly
            # loaded compiled callable
            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list (Inductor calling convention)
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]
    else:
        # it's the first time we compile this graph.
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        inductor_artifact = InductorArtifact()
        from torch._inductor.codecache import (FxGraphCache,
                                               compiled_fx_graph_hash)
        original_load = FxGraphCache.load

        def hijack_load(*args, **kwargs):
            # capture the compiled graph's source-file path as a side
            # effect of the normal cache load
            inductor_compiled_graph = original_load(*args, **kwargs)
            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
            return inductor_compiled_graph

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            # capture the hash Inductor computes for this graph
            out = compiled_fx_graph_hash(*args, **kwargs)
            inductor_artifact.hash_str = out[0]
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            if not cache_data.disabled:
                # compilation cache is enabled, patch several functions

                # hijack to get the compiled graph itself
                stack.enter_context(
                    patch("torch._inductor.codecache.FxGraphCache.load",
                          hijack_load))

                # for hijacking the hash of the compiled graph
                stack.enter_context(
                    patch("torch._inductor.codecache.compiled_fx_graph_hash",
                          hijack_compiled_fx_graph_hash))

                # for providing a dummy shape environment
                stack.enter_context(
                    patch(
                        "torch._inductor.codecache.FxGraphCache._get_shape_env",
                        _get_shape_env))

                # for forcing the graph to be cached
                stack.enter_context(
                    patch(
                        "torch._inductor.codecache.FxGraphCache._check_can_cache",
                        _check_can_cache))

            compiled_graph = compile_fx(graph,
                                        example_inputs,
                                        config_patches=current_config)
        # store the inductor_artifact in the cache
        cache_data[(runtime_shape, graph_index)] = inductor_artifact
        if graph_index == 0:
            # adds some info logging for the first graph
            logger.info("Cache the graph of shape %s for later use",
                        str(runtime_shape))
        logger.debug(
            "store the %s-th graph for shape %s via hash %s from file %s",
            graph_index, str(runtime_shape), inductor_artifact.hash_str,
            inductor_artifact.file_path)
    # after compiling the last graph, record the end time
    if graph_index == num_graphs - 1:
        now = time.time()
        elapsed = now - compilation_start_time
        compilation_config.compilation_time += elapsed
        if runtime_shape is None:
            logger.info("Compiling a graph for general shape takes %.2f s",
                        elapsed)
        else:
            logger.info("Compiling a graph for shape %s takes %.2f s",
                        runtime_shape, elapsed)

    return compiled_graph
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@dataclasses.dataclass
class SplitItem:
    """One submodule produced by `split_graph`."""
    # attribute name of the submodule on the split GraphModule,
    # e.g. "submod_0"
    submod_name: str
    # integer suffix parsed from submod_name; used to order the pieces
    graph_id: int
    # True if this piece consists of one of the splitting ops themselves
    # (rather than the code between two splitting ops)
    is_splitting_graph: bool
    graph: fx.GraphModule
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
    """Split `graph` at every call_function node whose target is in `ops`.

    Each splitting op becomes a subgraph of its own; the nodes between two
    splitting ops share a subgraph.  Returns the stitched GraphModule plus
    the list of pieces sorted by graph id.
    """
    # split graph by ops: assign every node a subgraph id, giving each
    # splitting op a dedicated id of its own
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs = []
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in ops:
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.append(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph,
        None,
        lambda node: node_to_subgraph_id[node],
        keep_original_order=True)

    outputs = []

    names = [name for (name, module) in split_gm.named_modules()]

    for name in names:
        if "." in name or name == "":
            # recursive child module or the root module
            continue

        module = getattr(split_gm, name)

        # split_module names the pieces "submod_<id>"; recover the id
        graph_id = int(name.replace("submod_", ""))
        outputs.append(
            SplitItem(name, graph_id, (graph_id in split_op_graphs), module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# we share the global graph pool among all the backends
# (lazily created by VllmBackend.__init__ on first use)
global_graph_pool = None

# module-level timestamp: set by wrap_inductor when it compiles the first
# piecewise graph (graph_index == 0) and read again on the last one to
# report the total compilation time
compilation_start_time = 0.0
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str], vllm_config: VllmConfig,
                 graph_pool, vllm_backend: "VllmBackend"):
        super().__init__(module)
        from torch._guards import detect_fake_mode
        # reuse the ambient fake mode so fake tensors are consistent with
        # the surrounding compilation
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend

    def run(self, *args):
        # convert real tensors to fake ones so the graph executes without
        # touching real data; non-tensor args pass through unchanged
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode:
            return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument,
                                ...], kwargs: Dict[str, Any]) -> Any:
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of symbolic-shape args, forwarded to the piecewise
            # backend so it can read the runtime shape at call time
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            global compilation_start_time
            # compile once for the general (dynamic) shape
            compiled_graph_for_general_shape = wrap_inductor(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                self.vllm_backend,
                graph_index=index,
                num_graphs=len(self.compile_submod_names),
                runtime_shape=None,
                use_inductor=self.compilation_config.use_inductor)

            # replace the submodule with a PiecewiseBackend wrapper
            # (defined elsewhere in this module) that dispatches between
            # the general-shape graph and shape-specialized compilations
            self.module.__dict__[target] = PiecewiseBackend(
                submod, self.vllm_config, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape, self.vllm_backend)

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class VllmBackend:
|
| 461 |
+
"""The compilation backend for `torch.compile` with VLLM.
|
| 462 |
+
It is used for compilation level of `CompilationLevel.PIECEWISE`,
|
| 463 |
+
where we customize the compilation.
|
| 464 |
+
|
| 465 |
+
The major work of this backend is to split the graph into
|
| 466 |
+
piecewise graphs, and pass them to the piecewise backend.
|
| 467 |
+
|
| 468 |
+
This backend also adds the PostGradPassManager to Inductor config,
|
| 469 |
+
which handles the post-grad passes.
|
| 470 |
+
"""
|
| 471 |
+
|
| 472 |
+
    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    # CUDA graph memory pool shared by all backends (see global_graph_pool)
    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: List[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: List[int]
    input_buffers: List[torch.Tensor]
    # per-backend view of the on-disk compilation cache
    inductor_hash_cache: InductorHashCache
|
| 487 |
+
|
| 488 |
+
    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        """Set up shared state; actual compilation happens lazily in
        `__call__` when `torch.compile` invokes the backend."""
        # lazily create the process-wide CUDA graph memory pool
        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = torch.cuda.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here
|
| 512 |
+
|
| 513 |
+
def configure_post_pass(self):
    """Wire the PostGradPassManager into the Inductor configuration.

    Any pre-existing `post_grad_custom_post_pass` hook is folded into
    the pass manager, which then takes over that hook slot.
    """
    cfg = self.compilation_config
    self.post_grad_pass_manager.configure(cfg.pass_config)

    # Post-grad custom passes are run using the post_grad_custom_post_pass
    # hook. If a pass for that hook exists, add it to the pass manager.
    hook_key = "post_grad_custom_post_pass"
    inductor_cfg = cfg.inductor_compile_config
    if hook_key in inductor_cfg:
        existing_pass = inductor_cfg[hook_key]
        # Config should automatically wrap all inductor passes
        assert isinstance(existing_pass, InductorPass)
        self.post_grad_pass_manager.add(existing_pass)
    inductor_cfg[hook_key] = self.post_grad_pass_manager
|
| 526 |
+
|
| 527 |
+
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
    """Entry point invoked by Dynamo with the traced FX graph.

    Sets up the compilation cache directory, splits the graph into
    piecewise submodules, compiles them, and returns either the stitched
    module or (when cudagraph input copying is enabled) a wrapper that
    copies runtime inputs into static buffers before dispatch.
    """
    vllm_config = self.vllm_config
    if not self.compilation_config.cache_dir:
        # no provided cache dir, generate one based on the known factors
        # that affects the compilation. if none of the factors change,
        # the cache dir will be the same so that we can reuse the compiled
        # graph.

        # 1. factors come from the vllm_config (it mainly summarizes how the
        # model is created)
        config_hash = vllm_config.compute_hash()

        # 2. factors come from the code files that are traced by Dynamo (
        # it mainly summarizes how the model is used in forward pass)
        forward_code_files = list(
            sorted(self.compilation_config.traced_files))
        self.compilation_config.traced_files.clear()
        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            "\n".join(forward_code_files))
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            with open(filepath) as f:
                hash_content.append(f.read())
        import hashlib
        code_hash = hashlib.md5(
            "\n".join(hash_content).encode()).hexdigest()

        # combine the two hashes to generate the cache dir
        hash_key = hashlib.md5(
            f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
        cache_dir = os.path.join(
            envs.VLLM_CACHE_ROOT,
            "torch_compile_cache",
            hash_key,
        )
        self.compilation_config.cache_dir = cache_dir

    cache_dir = self.compilation_config.cache_dir
    os.makedirs(cache_dir, exist_ok=True)
    # each rank keeps its own sub-directory to avoid write races
    local_cache_dir = os.path.join(
        cache_dir, f"rank_{vllm_config.parallel_config.rank}")
    self.compilation_config.local_cache_dir = local_cache_dir

    disabled = envs.VLLM_DISABLE_COMPILE_CACHE
    self.inductor_hash_cache: InductorHashCache = InductorHashCache(
        local_cache_dir, disabled=disabled)
    if disabled:
        logger.info("vLLM's torch.compile cache is disabled.")
    else:
        logger.info("Using cache directory: %s for vLLM's torch.compile",
                    local_cache_dir)

    # when dynamo calls the backend, it means the bytecode
    # transform and analysis are done
    compilation_counter.num_graphs_seen += 1
    from .monitor import torch_compile_start_time
    dynamo_time = time.time() - torch_compile_start_time
    logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
    self.compilation_config.compilation_time += dynamo_time

    # we control the compilation process, each instance can only be
    # called once
    assert not self._called, "VllmBackend can only be called once"

    self.graph = graph
    self.configure_post_pass()

    self.split_gm, self.piecewise_graphs = split_graph(
        graph, self.compilation_config.splitting_ops)

    from torch._dynamo.utils import lazy_format_graph_code

    # depyf will hook lazy_format_graph_code and dump the graph
    # for debugging, no need to print the graph here
    lazy_format_graph_code("before split", self.graph)
    lazy_format_graph_code("after split", self.split_gm)

    compilation_counter.num_piecewise_graphs_seen += len(
        self.piecewise_graphs)
    # splitting-op subgraphs are kept eager; only the rest are compiled
    submod_names_to_compile = [
        item.submod_name for item in self.piecewise_graphs
        if not item.is_splitting_graph
    ]

    # propagate the split graph to the piecewise backend,
    # compile submodules with symbolic shapes
    PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
                                self.vllm_config, self.graph_pool,
                                self).run(*example_inputs)

    graph_path = os.path.join(local_cache_dir, "computation_graph.py")
    if not os.path.exists(graph_path):
        # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
        # use `print_readable` because it can include submodules
        src = "from __future__ import annotations\nimport torch\n" + \
            self.split_gm.print_readable(print_output=False)
        src = src.replace("<lambda>", "GraphModule")
        with open(graph_path, "w") as f:
            f.write(src)

        logger.debug("Computation graph saved to %s", graph_path)

    self._called = True

    if not self.compilation_config.use_cudagraph or \
        not self.compilation_config.cudagraph_copy_inputs:
        return self.split_gm

    # if we need to copy input buffers for cudagraph
    from torch._guards import detect_fake_mode
    fake_mode = detect_fake_mode()
    fake_args = [
        fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
        for t in example_inputs
    ]

    # index of tensors that have symbolic shapes (batch size)
    # for weights and static buffers, they will have concrete shapes.
    # symbolic shape only happens for input tensors.
    from torch.fx.experimental.symbolic_shapes import is_symbolic
    self.sym_tensor_indices = [
        i for i, x in enumerate(fake_args)
        if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
            any(is_symbolic(d) for d in x.size())
    ]

    # compiler managed cudagraph input buffers
    # we assume the first run with symbolic shapes
    # has the maximum size among all the tensors
    self.input_buffers = [
        example_inputs[x].clone() for x in self.sym_tensor_indices
    ]

    # this is the callable we return to Dynamo to run
    def copy_and_call(*args):
        list_args = list(args)
        for i, index in enumerate(self.sym_tensor_indices):
            runtime_tensor = list_args[index]
            runtime_shape = runtime_tensor.shape[0]
            static_tensor = self.input_buffers[i][:runtime_shape]

            # copy the tensor to the static buffer
            static_tensor.copy_(runtime_tensor)

            # replace the tensor in the list_args to the static buffer
            list_args[index] = static_tensor
        return self.split_gm(*list_args)

    return copy_and_call
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-shape compilation / cudagraph state tracked by PiecewiseBackend."""
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in cudagraph_capture_sizes

    compiled: bool = False
    # callable used for this shape; defaults to the general-shape graph
    runnable: Optional[Callable] = None
    num_finished_warmup: int = 0
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
class PiecewiseBackend:
    """Per-submodule backend handling shape-specific compilation and
    cudagraph capture/replay. See ``__init__`` for details."""

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        # first/last graph get special treatment (gc, weak refs, logging)
        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        self.cudagraph_capture_sizes: Set[int] = set(
            self.compilation_config.cudagraph_capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.cudagraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        """If this is the last graph and nothing is left to compile,
        persist the inductor hash cache and stop compile monitoring."""
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.inductor_hash_cache.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Dispatch to the right callable for the runtime batch size,
        compiling and/or capturing a cudagraph for it on demand."""
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        # the runtime (batch) size is read from the first symbolic arg
        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = wrap_inductor(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                self.vllm_backend,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape,
                use_inductor=self.compilation_config.use_inductor)

            # finished compilations for all required shapes
            # (check_for_ending_compilation re-checks this condition)
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a cudagraph for shape %s",
                             runtime_shape)

            # record input addresses so replay-time mismatches can be caught
            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output
|
.venv/lib/python3.11/site-packages/vllm/compilation/counter.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import dataclasses
|
| 5 |
+
from contextlib import contextmanager
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclasses.dataclass
class CompilationCounter:
    """Process-wide counters recording torch.compile activity.

    Tests use :meth:`expect` to assert exact counter deltas across a
    code region.
    """
    num_models_seen: int = 0
    num_graphs_seen: int = 0
    # including the splitting ops
    num_piecewise_graphs_seen: int = 0
    # not including the splitting ops
    num_piecewise_capturable_graphs_seen: int = 0
    num_inductor_compilations: int = 0
    num_cudagraph_caputured: int = 0

    def clone(self) -> "CompilationCounter":
        """Return an independent snapshot of the current counter values."""
        return copy.deepcopy(self)

    @contextmanager
    def expect(self, **kwargs):
        """Assert each named counter changes by exactly the given amount
        while the ``with`` body runs."""
        snapshot = self.clone()
        yield
        for name, expected_delta in kwargs.items():
            before = getattr(snapshot, name)
            after = getattr(self, name)
            assert after - before == expected_delta, (
                f"{name} not as expected, before it is {before}"
                f", after it is {after}, "
                f"expected diff is {expected_delta}")


compilation_counter = CompilationCounter()
|
.venv/lib/python3.11/site-packages/vllm/compilation/decorators.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import inspect
|
| 4 |
+
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
|
| 5 |
+
from unittest.mock import patch
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
| 10 |
+
|
| 11 |
+
from vllm.compilation.counter import compilation_counter
|
| 12 |
+
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
| 13 |
+
from vllm.config import CompilationLevel, VllmConfig
|
| 14 |
+
from vllm.logger import init_logger
|
| 15 |
+
from vllm.sequence import IntermediateTensors
|
| 16 |
+
from vllm.utils import supports_dynamo
|
| 17 |
+
|
| 18 |
+
from .monitor import start_monitoring_torch_compile
|
| 19 |
+
|
| 20 |
+
logger = init_logger(__name__)
|
| 21 |
+
|
| 22 |
+
_T = TypeVar("_T", bound=type[nn.Module])
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
) -> Callable[[_T], _T]:
    ...


@overload
def support_torch_compile(cls: _T) -> _T:
    ...


def support_torch_compile(
    cls: Optional[_T] = None,
    *,
    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers.

    if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors`, the first
        dimension of all the tensors in the intermediate tensors
        will be marked as dynamic.

    During runtime, when we actually mark dimensions of tensors,
    it depends on the value of arguments:

    - if it is a single integer (can be negative), the corresponding dimension
        of the argument will be marked as dynamic.
    - if it is `None`, ignored.
    - if it is `IntermediateTensors`, all the tensors in the intermediate
        tensors will be marked as dynamic.
    - otherwise, it will raise an error.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)
        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            # infer dynamic dims from the forward type annotations
            inferred_dynamic_arg_dims = {}
            for k, v in sig.parameters.items():
                if v.annotation in [
                        torch.Tensor, Optional[torch.Tensor],
                        IntermediateTensors, Optional[IntermediateTensors]
                ]:
                    inferred_dynamic_arg_dims[k] = 0

            logger.debug(("Inferred dynamic dimensions for "
                          "forward method of %s: %s"), cls,
                         list(inferred_dynamic_arg_dims.keys()))

        if len(inferred_dynamic_arg_dims) == 0:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly.")

        # validate user-provided names against the forward signature
        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}")
        return _support_torch_compile(cls, inferred_dynamic_arg_dims)

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    return cls_decorator_helper
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: Dict[str, Union[int, List[int]]],
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Rewrites ``cls.__init__`` and ``cls.__call__`` so that the first call
    triggers a Dynamo compilation (with the given dims marked dynamic) and
    later calls dispatch straight to the compiled code.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.vllm_config = vllm_config
        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = \
            vllm_config.compilation_config.level in [
            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
        ] or not supports_dynamo()
        if self.do_not_compile:
            return
        compilation_counter.num_models_seen += 1
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_level=vllm_config.compilation_config.level)

    cls.__init__ = __init__

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is not None:
                    dims = [dims] if isinstance(dims, int) else dims
                    if isinstance(arg, torch.Tensor):
                        # In case dims is specified with negative indexing
                        dims = [
                            arg.ndim + dim if dim < 0 else dim for dim in dims
                        ]
                        torch._dynamo.mark_dynamic(arg, dims)
                    elif isinstance(arg, IntermediateTensors):
                        for tensor in arg.tensors.values():
                            # In case dims is specified with negative indexing
                            dims = [
                                tensor.ndim + dim if dim < 0 else dim
                                for dim in dims
                            ]
                            torch._dynamo.mark_dynamic(tensor, dims)
                    else:
                        raise ValueError(
                            "Unsupported dynamic dimensions"
                            f" {dims} for argument {k} with type {type(arg)}.")
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s",
                         self.original_code_object)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuse the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(
                self.original_code_object)

            # collect all relevant files traced by Dynamo,
            # so that the compilation cache can trigger re-compilation
            # properly when any of these files change.

            # 1. the file containing the top-level forward function
            self.vllm_config.compilation_config.traced_files.add(
                self.original_code_object.co_filename)

            # 2. every time Dynamo sees a function call, it will inline
            # the function by calling InliningInstructionTranslator.inline_call
            # we hijack this function to know all the functions called
            # during Dynamo tracing, and their corresponding files
            inline_call = InliningInstructionTranslator.inline_call

            def patched_inline_call(parent, func, args, kwargs):
                code = func.get_code()
                self.vllm_config.compilation_config.traced_files.add(
                    code.co_filename)
                return inline_call(parent, func, args, kwargs)

            with patch.object(InliningInstructionTranslator, 'inline_call',
                              patched_inline_call):
                output = self.compiled_callable(*args, **kwargs)
            return output

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__
    return cls
|
.venv/lib/python3.11/site-packages/vllm/compilation/fix_functionalization.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import operator
|
| 4 |
+
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
| 8 |
+
|
| 9 |
+
from vllm.logger import init_logger
|
| 10 |
+
|
| 11 |
+
from .fx_utils import is_func
|
| 12 |
+
from .vllm_inductor_pass import VllmInductorPass
|
| 13 |
+
|
| 14 |
+
logger = init_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class FixFunctionalizationPass(VllmInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.

    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

    def __call__(self, graph: torch.fx.Graph) -> None:
        """
        Scan the graph for auto_functionalized nodes of known ops and
        replace each with a direct (in-place) call to the underlying op.
        Staged node removals are applied in one batch at the end.
        """
        self.begin()
        self.dump_graph(graph, "before_fix_functionalization")

        # Nodes are staged here and erased together after the scan, so we
        # never mutate the node list while iterating over it.
        self.nodes_to_remove: List[torch.fx.Node] = []
        count = 0
        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue  # Avoid deep if-elif nesting

            kwargs = node.kwargs
            # args[0] of an auto_functionalized node is the wrapped op overload
            at_target = node.args[0]

            if at_target == torch.ops._C.rotary_embedding.default:
                query = kwargs['query']
                # NOTE(review): assumes query is produced by two levels of
                # ops (e.g. slice of a view) on top of the matmul output --
                # verify against the graphs this pass actually sees.
                mm_node = query.args[0].args[0]

                # rotary_embedding is a special case: the two mutating inputs
                # are query and key, which are slices of mm_node.
                # While functionalized, results at[1] and at[2] are scattered
                # back into mm_node. After de-functionalization, we can just
                # use mm_node directly.
                for idx, user in self.getitem_users(node).items():
                    for user_of_getitem in user.users:
                        if is_func(user_of_getitem,
                                   torch.ops.aten.slice_scatter.default):
                            user_of_getitem.replace_all_uses_with(mm_node)
                            self._remove(user_of_getitem)
                    self._remove(user)

                self.insert_defunctionalized(graph, node)
                self._remove(node)

            # rms_norm replacements avoid the most copies for LLaMa.
            elif at_target == torch.ops._C.fused_add_rms_norm.default:
                mutated_args = {1: 'input', 2: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
                mutated_args = {1: 'result', 2: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default:  # noqa: E501
                mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target in [
                    torch.ops._C.rms_norm.default,
                    torch.ops._C.rms_norm_static_fp8_quant.default
            ]:
                mutated_args = {1: 'result'}
                self.defunctionalize(graph, node, mutated_args)

            elif at_target == torch.ops._C.silu_and_mul.default:
                mutated_args = {1: 'out'}
                # Because we have an 'out', need to specify args directly
                self.defunctionalize(graph,
                                     node,
                                     mutated_args,
                                     args=('out', 'input'))
            else:
                continue  # skip the count

            count += 1

        self.dump_graph(graph, "before_fix_functionalization_cleanup")

        # Remove the nodes all at once
        count_removed = len(self.nodes_to_remove)
        for node in self.nodes_to_remove:
            graph.erase_node(node)

        logger.debug("De-functionalized %s nodes, removed %s nodes", count,
                     count_removed)
        self.dump_graph(graph, "after_fix_functionalization")
        self.end_and_log()

    def _remove(self, node_or_nodes: Union[torch.fx.Node,
                                           Iterable[torch.fx.Node]]) -> None:
        """
        Stage a node (or nodes) for removal at the end of the pass.
        """
        if isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.append(node_or_nodes)
        else:
            self.nodes_to_remove.extend(node_or_nodes)

    def defunctionalize(self,
                        graph: torch.fx.Graph,
                        node: torch.fx.Node,
                        mutated_args: Dict[int, Union[torch.fx.Node, str]],
                        args: Optional[Tuple[Union[torch.fx.Node, str],
                                             ...]] = None) -> None:
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.
        """
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(self, node: torch.fx.Node,
                                        mutated_args: Dict[int,
                                                           Union[torch.fx.Node,
                                                                 str]]) -> None:
        """
        Replace all getitem users of the auto-functionalized node with the
        mutated arguments.
        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
        If the value of an arg is a string, `node.kwargs[arg]` is used.
        """
        for idx, user in self.getitem_users(node).items():
            arg = mutated_args[idx]
            # String values name a kwarg of the functionalized node.
            arg = node.kwargs[arg] if isinstance(arg, str) else arg
            user.replace_all_uses_with(arg)
            self._remove(user)

    def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.
        """
        users = {}
        for user in node.users:
            if is_func(user, operator.getitem):
                idx = user.args[1]
                users[idx] = user
        return users

    def insert_defunctionalized(self,
                                graph: torch.fx.Graph,
                                node: torch.fx.Node,
                                args: Optional[Tuple[Union[torch.fx.Node, str],
                                                     ...]] = None) -> None:
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
        If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(node, auto_functionalized), \
            f"node must be auto-functionalized, is {node} instead"

        # Create a new call to the original function
        with graph.inserting_before(node):
            function = node.args[0]
            if args is None:
                graph.call_function(function, kwargs=node.kwargs)
            else:
                # Args passed as strings refer to items in node.kwargs
                args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg
                             for arg in args)
                graph.call_function(function, args=args)
|
.venv/lib/python3.11/site-packages/vllm/compilation/fusion.py
ADDED
|
@@ -0,0 +1,617 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch._inductor.pattern_matcher as pm
|
| 7 |
+
# TODO(luka) use vllm.utils once #10836 landed
|
| 8 |
+
from compressed_tensors.quantization import FP8_DTYPE
|
| 9 |
+
from torch import fx
|
| 10 |
+
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
| 11 |
+
from torch._inductor.pattern_matcher import PatternMatcherPass
|
| 12 |
+
from torch._ops import OpOverload
|
| 13 |
+
|
| 14 |
+
from vllm.config import CompilationConfig
|
| 15 |
+
from vllm.logger import init_logger
|
| 16 |
+
|
| 17 |
+
from .fx_utils import find_getitem_maybe
|
| 18 |
+
from .multi_output_match import MultiOutputMatch
|
| 19 |
+
from .vllm_inductor_pass import VllmInductorPass
|
| 20 |
+
|
| 21 |
+
logger = init_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def empty_bf16(*args, **kwargs):
    """Allocate an uninitialized bfloat16 tensor on the CUDA device."""
    return torch.empty(*args, dtype=torch.bfloat16, device="cuda", **kwargs)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def empty_fp32(*args, **kwargs):
    """Allocate an uninitialized float32 tensor on the CUDA device."""
    return torch.empty(*args, dtype=torch.float32, device="cuda", **kwargs)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Custom-op overloads matched by the RMSNorm fusion patterns below.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class QuantKey(NamedTuple):
    """
    Named tuple for identifying the type of quantization.
    dtype: quantized data type
    static: static quantization if True, dynamic if False
    per_tensor: per-tensor quantization if True, per-token if False
    symmetric: symmetric if True, asymmetric if False
    """
    dtype: torch.dtype
    static: bool
    per_tensor: bool = True
    symmetric: bool = True

    def __str__(self):
        """Render a compact, human-readable description of this key."""
        mode = "static" if self.static else "dynamic"
        dtype_abbr = fx.graph.dtype_abbrs[self.dtype]
        granularity = "per_tensor" if self.per_tensor else "per_token"
        symmetry = "symmetric" if self.symmetric else "asymmetric"
        return f"QuantKey({mode},{dtype_abbr},{granularity},{symmetry})"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Commonly used quantization-scheme keys for fp8.
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)

# Maps each supported quantization scheme to its (unfused) custom quant op.
QUANT_OPS: Dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa
    kFp8DynamicTensorSym:
    torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa
    kFp8DynamicTokenSym:
    torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa
}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple for identifying the type of RMSNorm + quant fusion.
    quant: type of quantization
    fused_add: does the op also perform the residual add
    """
    quant: QuantKey
    fused_add: bool

    def __str__(self):
        """Render a compact, human-readable description of this key."""
        suffix = "" if self.fused_add else "out"
        return f"FusedQuantKey({self.quant}, with{suffix} residual)"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# Maps each (quant scheme, residual-add) combination to its fused
# RMSNorm+quant custom op. Note both dynamic per-token entries share the
# same op overload; the residual kwarg distinguishes them at call time.
FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(kFp8StaticTensorSym, False):
    torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa
    FusedRMSQuantKey(kFp8StaticTensorSym, True):
    torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa
    FusedRMSQuantKey(kFp8DynamicTokenSym, False):
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa
    FusedRMSQuantKey(kFp8DynamicTokenSym, True):
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa
}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class QuantMultiOutputMatch(MultiOutputMatch):
    """
    Base class for multi-output pattern matches that fuse an (in-place)
    quant op with a preceding op into a single in-place fused op.
    Subclasses supply the concrete ops; this class provides the shared
    node-insertion/rebinding utility.
    """

    def __init__(self, match: pm.Match, quant_op, fused_op):
        super().__init__(match)
        assert isinstance(quant_op, OpOverload)
        assert isinstance(fused_op, OpOverload)
        self.QUANT_OP = quant_op  # in-place quant op
        self.FUSED_OP = fused_op  # in-place fused quant op

    def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
                                                                      int]],
                          **kwargs) -> None:
        """
        This utility function inserts an auto-functionalized node for FUSED_OP.
        It also correctly sets its meta value and rebinds the users of the
        unfused nodes to use the fused node instead.

        :param fused_return_mapping: A dictionary, mapping from getitem indices
        of the fused node result to a tuple of the old node and a getitem index.
        :param kwargs: kwargs that get directly forwarded to the auto_fn node

        Example:
        If we want to replace this graph:
        _, x1, x2 = auto_fn(op1)
        _, y1, y2 = auto_fn(op2)

        with
        _, x1, y2, x2 = auto_fn(FUSED_OP)

        we would call:
        insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)}

        Note that the 0th element is None for auto-functionalized in-place ops.
        Hence, others appear 1-indexed.
        """
        fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
        indices = fused_return_mapping.keys()
        getitem_nodes = self.insert_getitems(fused_node, indices)

        # Prepare the meta value, use a list so it's mutable
        meta_val = [None] * (max(indices) + 1)

        # Iterate through elements of the tuple produced by fused_node
        for idx, getitem_node in zip(indices, getitem_nodes):
            old_node, old_idx = fused_return_mapping[idx]

            # If the old value was never used, the old_getitem might not exist
            old_getitem = find_getitem_maybe(old_node, old_idx)
            if old_getitem is not None:
                # Rebind the users of match getitem nodes to use the new nodes.
                # The old nodes will be removed by DCE at the end of the pass.
                old_getitem.replace_all_uses_with(getitem_node)
                getitem_node.meta["val"] = old_getitem.meta["val"]

            # Extract the appropriate meta value
            # It is present even if the getitem node does not exist
            meta_val[idx] = old_node.meta["val"][old_idx]

        # Fix the meta value on the new fused node
        fused_node.meta["val"] = tuple(meta_val)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class RMSNormQuantPattern:
    """
    Base class for RMSNorm + quantization fusion patterns.

    Resolves and stores the unfused quant op and the fused
    RMSNorm+quant op for the given fusion key, along with the epsilon
    used by the norm.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        self.quant_dtype = key.quant.dtype
        self.epsilon = epsilon

        assert key.quant in QUANT_OPS, (
            f"unsupported quantization scheme {key.quant}")
        self.QUANT_OP = QUANT_OPS[key.quant]

        assert key in FUSED_OPS, (
            f"unsupported fused rmsnorm+quant op for {key}")
        self.FUSED_OP = FUSED_OPS[key]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Fuses rms_norm followed by a static quant op into
    rms_norm_static_fp8_quant. Single-output pattern, so the default
    pattern-matcher replacement machinery is sufficient.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 symmetric=True):
        fused_key = FusedRMSQuantKey(fused_add=False,
                                     quant=QuantKey(dtype=quant_dtype,
                                                    static=True,
                                                    per_tensor=True,
                                                    symmetric=symmetric))
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with the given pass."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            at1 = auto_functionalized(RMS_OP,
                                      result=result_rms,
                                      input=input,
                                      weight=weight,
                                      epsilon=self.epsilon)
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at1[1],
                                      scale=scale)

            # result
            return at2[1]

        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon)

            # result
            return at[1]

        # Example tensors used only to trace the pattern/replacement;
        # concrete sizes are arbitrary.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
                                pm_pass)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Fuses fused_add_rms_norm followed by a static quant op into
    fused_add_rms_norm_static_fp8_quant. Multi-output (result, residual)
    pattern, so matches are recorded and processed manually via Match.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 symmetric=True):
        key = FusedRMSQuantKey(fused_add=True,
                               quant=QuantKey(dtype=quant_dtype,
                                              static=True,
                                              per_tensor=True,
                                              symmetric=symmetric))
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass,
                 record_match: Callable[[MultiOutputMatch], bool]) -> None:
        """
        Register the pattern; record_match is invoked from extra_check so
        multi-output matches can be processed manually later.
        """

        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            at = auto_functionalized(RMS_ADD_OP,
                                     input=input,
                                     residual=residual,
                                     weight=weight,
                                     epsilon=self.epsilon)
            at1 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at[1],
                                      scale=scale)

            # result, residual
            return at1[1], at[2]

        def replacement(result: torch.Tensor, input: torch.Tensor,
                        residual: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     residual=residual,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon)

            # result, residual
            return at[1], at[2]

        # Example tensors used only to trace the pattern/replacement.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
            extra_check=lambda m: record_match(
                self.Match(m, self.QUANT_OP, self.FUSED_OP)))

    class Match(QuantMultiOutputMatch):
        """Manual replacement for the multi-output match above."""

        def process(self) -> None:
            # Find the nodes in the match that we need to rebind
            rms_node = self.find_auto_fn(RMS_ADD_OP)
            quant_node = self.find_auto_fn(self.QUANT_OP)

            assert len(rms_node.users) == 2
            assert len(quant_node.users) == 1

            # First, insert a new auto_functionalized node for the fused op,
            # as well as getitem nodes to extract the result and residual.
            # The auto_fn node returns a tuple of (None, result, residual).
            #
            # The resulting graph looks like this:
            # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
            # result_node_new = at[1]
            # residual_node_new = at[2]
            with self.inserting_after_match():
                # Missing epsilon, scalars cannot be inputs to the pattern
                kwargs = self.match.kwargs.copy()

                # 0 is always None
                fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
                self.insert_fused_node(fused_return_mapping,
                                       epsilon=rms_node.kwargs["epsilon"],
                                       **kwargs)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Fuses rms_norm followed by a dynamic quant op into
    rms_norm_dynamic_per_token_quant. Multi-output (result, scale)
    pattern, so matches are recorded and processed manually via Match.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 per_tensor: bool,
                 symmetric=True):
        key = FusedRMSQuantKey(fused_add=False,
                               quant=QuantKey(dtype=quant_dtype,
                                              static=False,
                                              per_tensor=per_tensor,
                                              symmetric=symmetric))
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass,
                 record_match: Callable[[MultiOutputMatch], bool]) -> None:
        """
        Register the pattern; record_match is invoked from extra_check so
        multi-output matches can be processed manually later.
        """

        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            at1 = auto_functionalized(RMS_OP,
                                      result=result_rms,
                                      input=input,
                                      weight=weight,
                                      epsilon=self.epsilon)
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at1[1],
                                      scale=scale,
                                      scale_ub=None)

            # result, scale
            return at2[1], at2[2]

        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon,
                                     scale_ub=None,
                                     residual=None)

            # result, scale
            return at[1], at[2]

        # Example tensors used only to trace the pattern/replacement.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
            extra_check=lambda m: record_match(
                self.Match(m, self.QUANT_OP, self.FUSED_OP)))

    class Match(QuantMultiOutputMatch):
        """Manual replacement for the multi-output match above."""

        def process(self) -> None:
            # Find the nodes in the match that we need to rebind
            rms_node = self.find_auto_fn(RMS_OP)
            quant_node = self.find_auto_fn(self.QUANT_OP)

            assert len(rms_node.users) == 1
            assert len(quant_node.users) == 2

            # First, insert a new auto_functionalized node for the fused op,
            # as well as getitem nodes to extract the result and scale.
            # The auto_fn node returns a tuple of (None, result, scale).
            #
            # The resulting graph looks like this:
            # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
            # result_node_new = at[1]
            # scale_node_new = at[2]
            with self.inserting_after_match():
                # Missing epsilon, scalars cannot be inputs to the pattern
                kwargs = self.match.kwargs.copy()
                del kwargs["result_rms"]  # not used in the fused op

                fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)}
                self.insert_fused_node(
                    fused_return_mapping,
                    epsilon=rms_node.kwargs["epsilon"],
                    scale_ub=None,  # not used but required
                    residual=None,  # not used but required
                    **kwargs)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Fuses fused_add_rms_norm followed by a dynamic quant op into
    rms_norm_dynamic_per_token_quant (with residual). Multi-output
    (result, residual, scale) pattern, processed manually via Match.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 per_tensor: bool = True,
                 symmetric=True):
        key = FusedRMSQuantKey(fused_add=True,
                               quant=QuantKey(dtype=quant_dtype,
                                              static=False,
                                              per_tensor=per_tensor,
                                              symmetric=symmetric))
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass,
                 record_match: Callable[[MultiOutputMatch], bool]) -> None:
        """
        Register the pattern; record_match is invoked from extra_check so
        multi-output matches can be processed manually later.
        """

        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            at = auto_functionalized(RMS_ADD_OP,
                                     input=input,
                                     residual=residual,
                                     weight=weight,
                                     epsilon=self.epsilon)
            at1 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at[1],
                                      scale=scale,
                                      scale_ub=None)

            # result, residual, scale
            return at1[1], at[2], at1[2]

        def replacement(result: torch.Tensor, input: torch.Tensor,
                        residual: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon,
                                     scale_ub=None,
                                     residual=residual)

            # result, residual, scale
            return at[1], at[3], at[2]

        # Example tensors used only to trace the pattern/replacement.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
            extra_check=lambda m: record_match(
                self.Match(m, self.QUANT_OP, self.FUSED_OP)))

    class Match(QuantMultiOutputMatch):
        """Manual replacement for the multi-output match above."""

        def process(self) -> None:
            # Find the nodes in the match that we need to rebind
            rms_node = self.find_auto_fn(RMS_ADD_OP)
            quant_node = self.find_auto_fn(self.QUANT_OP)

            assert len(rms_node.users) == 2
            assert len(quant_node.users) == 2

            # First, insert a new auto_functionalized node for the fused op,
            # as well as getitem nodes to extract result, scale, and residual.
            # The auto_fn node returns a tuple (None, result, scale, residual).
            #
            # The resulting graph looks like this:
            # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
            # result_node_new = at[1]
            # scale_node_new = at[2]
            # residual_node_new = at[3]
            with self.inserting_after_match():
                # Missing epsilon, scalars cannot be inputs to the pattern
                kwargs = self.match.kwargs.copy()

                fused_return_mapping = {
                    1: (quant_node, 1),  # result
                    2: (quant_node, 2),  # scale
                    3: (rms_node, 2),  # residual
                }
                self.insert_fused_node(
                    fused_return_mapping,
                    epsilon=rms_node.kwargs["epsilon"],
                    scale_ub=None,  # not used but required
                    **kwargs)
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class FusionPass(VllmInductorPass):
    """
    This pass fuses a pre-defined set of custom ops into fused ops.
    It uses the torch pattern matcher to find the patterns and replace them.
    It also manually processes multi-output matches, as those are broken in
    the torch pattern matcher.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    # Singleton instance; created lazily by instance().
    _instance: 'Optional[FusionPass]' = None

    @classmethod
    def instance(cls, config: CompilationConfig.PassConfig):
        """
        Get the singleton instance of the FusionPass.
        If the instance exists, the config is updated but
        initialization is not repeated.
        """
        if cls._instance is None:
            cls._instance = FusionPass(config)
        else:
            cls._instance.config = config
        return cls._instance

    def __init__(self, config: CompilationConfig.PassConfig):
        # Direct construction is only allowed once; use instance() instead.
        assert self.__class__._instance is None, \
            "FusionPass singleton instance already exists"
        super().__init__(config)

        # Multi-output matches recorded by record_match(); processed
        # manually in process_matches().
        self.matches: List[MultiOutputMatch] = []
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="fusion_pass")

        # Patterns embed epsilon as a constant, so register one copy of
        # each pattern per supported epsilon value.
        for epsilon in [1e-5, 1e-6]:
            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon,
                                      FP8_DTYPE).register(self.patterns)

            # Matches for patterns below have 2 or more outputs,
            # so we need to process them manually (see process_matches)

            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns, self.record_match)

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
                                       per_tensor=False).register(
                                           self.patterns, self.record_match)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon,
                                               FP8_DTYPE,
                                               per_tensor=False).register(
                                                   self.patterns,
                                                   self.record_match)

        # WARNING: This is a hack to clear the pattern matcher cache
        # and allow multiple values of epsilon.
        torch._inductor.pattern_matcher._seen_patterns.clear()

    def record_match(self, match: MultiOutputMatch) -> bool:
        # Hijack the extra_check to record the match and
        # save it for post-processing.
        self.matches.append(match)

        # Return False to prevent automatic replacement.
        return False

    def process_matches(self, graph: fx.Graph):
        """
        Manually process multi-output matches and replace them with fused nodes.
        See MultiOutputMatch for more details.
        """
        for match in self.matches:
            match.process()

        # Finally, remove matched nodes
        graph.eliminate_dead_code()
        # All original match nodes must have been eliminated by DCE.
        assert all(node not in graph.nodes for match in self.matches
                   for node in match.match.nodes)

    def __call__(self, graph: fx.Graph):
        """Run the fusion pass: apply single-output pattern replacements,
        then manually process the recorded multi-output matches."""
        self.begin()
        self.dump_graph(graph, "before_fusion")

        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", count)
        self.dump_graph(graph, "after_pattern_match")

        # Manually process multi-output matches (and run DCE)
        self.process_matches(graph)
        logger.debug("Post-processed %s matches", len(self.matches))
        self.dump_graph(graph, "after_fusion")
        # Clear recorded matches so the singleton can run on other graphs.
        self.matches.clear()
        self.end_and_log()
|
.venv/lib/python3.11/site-packages/vllm/compilation/fx_utils.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import operator
|
| 4 |
+
from typing import Iterable, Optional
|
| 5 |
+
|
| 6 |
+
from torch import fx
|
| 7 |
+
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
| 8 |
+
from torch._ops import OpOverload
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def is_func(node: fx.Node, target) -> bool:
|
| 12 |
+
return node.op == "call_function" and node.target == target
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Returns the first auto_functionalized node with the given op (if it exists)
|
| 16 |
+
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
|
| 17 |
+
op: OpOverload) -> Optional[fx.Node]:
|
| 18 |
+
for node in nodes:
|
| 19 |
+
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
|
| 20 |
+
return node
|
| 21 |
+
return None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Returns the first auto_functionalized node with the given op
|
| 25 |
+
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
| 26 |
+
node = find_auto_fn_maybe(nodes, op)
|
| 27 |
+
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
| 28 |
+
return node
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Returns the getitem node that extracts the idx-th element from node
|
| 32 |
+
# (if it exists)
|
| 33 |
+
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
|
| 34 |
+
for user in node.users:
|
| 35 |
+
if is_func(user, operator.getitem) and user.args[1] == idx:
|
| 36 |
+
return user
|
| 37 |
+
return None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Returns the getitem node that extracts the idx-th element from node
|
| 41 |
+
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
|
| 42 |
+
ret = find_getitem_maybe(node, idx)
|
| 43 |
+
assert ret is not None, f"Could not find getitem {idx} in node {node}"
|
| 44 |
+
return ret
|
.venv/lib/python3.11/site-packages/vllm/compilation/monitor.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
| 7 |
+
from vllm.logger import init_logger
|
| 8 |
+
|
| 9 |
+
logger = init_logger(__name__)
|
| 10 |
+
|
| 11 |
+
context_manager = None
|
| 12 |
+
torch_compile_start_time: float = 0.0
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def start_monitoring_torch_compile(vllm_config: VllmConfig):
    """Record the torch.compile start time and, for PIECEWISE compilation
    with a debug dump path configured, enter a depyf debugging context.
    """
    # Wall-clock start time; read elsewhere to compute total compile time.
    global torch_compile_start_time
    torch_compile_start_time = time.time()

    compilation_config: CompilationConfig = vllm_config.compilation_config
    if compilation_config.level == CompilationLevel.PIECEWISE and \
        compilation_config.debug_dump_path:
        # Imported lazily: depyf is only needed when debug dumping is on.
        import depyf
        # One dump directory per rank so parallel workers don't clobber
        # each other's output.
        path = os.path.join(compilation_config.debug_dump_path,
                            f"rank_{vllm_config.parallel_config.rank}")
        global context_manager
        context_manager = depyf.prepare_debug(path)
        # Entered manually; exited later in end_monitoring_torch_compile().
        context_manager.__enter__()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def end_monitoring_torch_compile(vllm_config: VllmConfig):
|
| 31 |
+
compilation_config: CompilationConfig = vllm_config.compilation_config
|
| 32 |
+
if compilation_config.level == CompilationLevel.PIECEWISE:
|
| 33 |
+
logger.info("torch.compile takes %.2f s in total",
|
| 34 |
+
compilation_config.compilation_time)
|
| 35 |
+
global context_manager
|
| 36 |
+
if context_manager is not None:
|
| 37 |
+
context_manager.__exit__(None, None, None)
|
| 38 |
+
context_manager = None
|
.venv/lib/python3.11/site-packages/vllm/compilation/pass_manager.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List
|
| 4 |
+
|
| 5 |
+
from torch import fx as fx
|
| 6 |
+
|
| 7 |
+
from vllm.config import CompilationConfig
|
| 8 |
+
from vllm.logger import init_logger
|
| 9 |
+
|
| 10 |
+
from .fix_functionalization import FixFunctionalizationPass
|
| 11 |
+
from .fusion import FusionPass
|
| 12 |
+
from .inductor_pass import InductorPass
|
| 13 |
+
from .reshapes import RedundantReshapesPass
|
| 14 |
+
|
| 15 |
+
logger = init_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class PostGradPassManager:
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.
    It also supports pickling, which is used by the Inductor code cache.
    TODO(torch==2.6), use CustomGraphPass
    (torch._inductor.custom_graph_pass.CustomGraphPass)

    The order of the post-grad post-passes is:
    1. passes (constructor parameter)
    2. default passes (RedundantReshapesPass, FusionPass)
    3. config["post_grad_custom_post_pass"] (if it exists)
    4. fix_functionalization
    This way, all passes operate on a functionalized graph.
    """

    def __init__(self):
        # Passes in execution order; populated by configure() and add().
        self.passes: List[InductorPass] = []

    def __call__(self, graph: fx.Graph):
        """Run every configured pass on *graph* in order.

        NOTE: configure() must be called before the manager is invoked —
        it creates self.fix_functionalization and self.pass_config.
        """
        for pass_ in self.passes:
            pass_(graph)

        # always run fix_functionalization last
        self.fix_functionalization(graph)

    def configure(self, pass_config: CompilationConfig.PassConfig):
        # Install the default passes enabled by the config, then create the
        # (always-run) functionalization fixup pass.
        self.pass_config = pass_config
        if pass_config.enable_reshape:
            self.passes += [RedundantReshapesPass(pass_config)]

        if pass_config.enable_fusion:
            self.passes += [FusionPass.instance(pass_config)]

        self.fix_functionalization = FixFunctionalizationPass(pass_config)

    def add(self, pass_: InductorPass):
        # Append a custom pass; it runs after the defaults (but before
        # fix_functionalization, which always runs last).
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def __getstate__(self) -> Dict[str, Any]:
        """
        Custom pickling for the pass manager, as some passes cannot be pickled.
        Pickling occurs because the pass manager is set as the value of
        `config["post_grad_custom_post_pass"]` in the Inductor config.
        The config is pickled to act as a key in the Inductor code cache.
        Any other passes in the config are pickled as well.

        TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
        """
        # "pass_config" maps to a uuid string; "passes" to a list of uuids.
        state = {"pass_config": self.pass_config.uuid(), "passes": []}
        for pass_ in self.passes:
            state["passes"].append(pass_.uuid())
        state["passes"].append(self.fix_functionalization.uuid())
        return state

    def __setstate__(self, state):
        """
        Do not allow unpickling of the pass manager.
        If this is needed in the future, it should properly pickle the passes.
        """
        raise ValueError("Cannot unpickle PostGradPassManager")
|
.venv/lib/python3.11/site-packages/vllm/compilation/wrapper.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from abc import abstractmethod
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
+
from types import CodeType
|
| 8 |
+
from typing import Callable, List, Optional
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
import vllm.envs as envs
|
| 13 |
+
from vllm.config import CompilationLevel, get_current_vllm_config
|
| 14 |
+
from vllm.logger import init_logger
|
| 15 |
+
|
| 16 |
+
logger = init_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TorchCompileWrapperWithCustomDispatcher:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
    It can use `self.compiled_codes` to access the compiled bytecode,
    and `with self.dispatch_to_code(index):` to dispatch to
    the compiled code.
    3. Implement the `__init__` method to determine how to call
    `torch.compile` over the forward method.
    """

    def __init__(self,
                 compiled_callable: Optional[Callable] = None,
                 compilation_level: int = 0):

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        if compiled_callable is None:
            # default compilation settings
            # compiling the forward method
            backend = vllm_config.compilation_config.init_backend(vllm_config)

            compiled_callable = torch.compile(
                self.forward,
                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                backend=backend)

        self.compiled_callable = compiled_callable
        self.original_code_object = self.__class__.forward.__code__
        self.compiled_codes: List[CodeType] = []
        # Capture Dynamo-generated bytecode so dispatch_to_code can run it
        # directly, bypassing guard evaluation.
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # read the env var to determine whether to use the custom dispatcher
        # subclasses can use this to switch between the custom dispatcher
        # and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = \
            compilation_level >= CompilationLevel.DYNAMO_ONCE

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.
        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution."""
        if old_code is not self.original_code_object:
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the call stack to find Dynamo's _compile frame, which
        # holds the user frame being compiled in its locals.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # Ignore compilations triggered by other instances of this class.
        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)
        local_cache_dir = self.vllm_config.compilation_config.local_cache_dir
        if isinstance(local_cache_dir, str):
            decompiled_file = os.path.join(local_cache_dir,
                                           "transformed_code.py")
            if not os.path.exists(decompiled_file):
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf
                    src = depyf.decompile(new_code)
                    with open(decompiled_file, "w") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s",
                                 decompiled_file)
                except Exception:
                    # Best-effort debugging aid only; a failed decompilation
                    # must not break compilation.
                    pass

        if self.vllm_config.compilation_config.use_cudagraph and \
            "update" in new_code.co_names:
            import depyf
            src = depyf.decompile(new_code)
            msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src  # noqa
            raise RuntimeError(msg)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa
        self.__class__.forward.__code__ = self.compiled_codes[index]
        try:
            yield
        finally:
            # Always restore the original code object, even if the managed
            # body raises; otherwise every future call to forward (on every
            # instance of this class) would permanently dispatch to the
            # compiled bytecode.
            self.__class__.forward.__code__ = self.original_code_object
|
.venv/lib/python3.11/site-packages/vllm/engine/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (184 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/arg_utils.cpython-311.pyc
ADDED
|
Binary file (59 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/async_llm_engine.cpython-311.pyc
ADDED
|
Binary file (56 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/async_timeout.cpython-311.pyc
ADDED
|
Binary file (8.92 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/llm_engine.cpython-311.pyc
ADDED
|
Binary file (84.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/metrics.cpython-311.pyc
ADDED
|
Binary file (34.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/metrics_types.cpython-311.pyc
ADDED
|
Binary file (5.35 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/__pycache__/protocol.cpython-311.pyc
ADDED
|
Binary file (12.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/arg_utils.py
ADDED
|
@@ -0,0 +1,1360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import dataclasses
|
| 5 |
+
import json
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
|
| 8 |
+
Tuple, Type, Union, cast, get_args)
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
import vllm.envs as envs
|
| 13 |
+
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
|
| 14 |
+
DecodingConfig, DeviceConfig, HfOverrides,
|
| 15 |
+
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
|
| 16 |
+
ModelConfig, ModelImpl, ObservabilityConfig,
|
| 17 |
+
ParallelConfig, PoolerConfig, PromptAdapterConfig,
|
| 18 |
+
SchedulerConfig, SpeculativeConfig, TaskOption,
|
| 19 |
+
TokenizerPoolConfig, VllmConfig)
|
| 20 |
+
from vllm.executor.executor_base import ExecutorBase
|
| 21 |
+
from vllm.logger import init_logger
|
| 22 |
+
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
| 23 |
+
from vllm.transformers_utils.utils import check_gguf_file
|
| 24 |
+
from vllm.usage.usage_lib import UsageContext
|
| 25 |
+
from vllm.utils import FlexibleArgumentParser, StoreBoolean
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
| 29 |
+
|
| 30 |
+
logger = init_logger(__name__)
|
| 31 |
+
|
| 32 |
+
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
|
| 33 |
+
|
| 34 |
+
DEVICE_OPTIONS = [
|
| 35 |
+
"auto",
|
| 36 |
+
"cuda",
|
| 37 |
+
"neuron",
|
| 38 |
+
"cpu",
|
| 39 |
+
"openvino",
|
| 40 |
+
"tpu",
|
| 41 |
+
"xpu",
|
| 42 |
+
"hpu",
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def nullable_str(val: str):
    """Map empty strings and the literal string "None" to None.

    Any other value is returned unchanged. Useful as an argparse `type=`
    converter for optional string arguments.
    """
    if val and val != "None":
        return val
    return None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse "KEY=VALUE[,KEY=VALUE,...]" into a str -> int dictionary.

    Keys and values are lowercased and whitespace-stripped. An empty
    string yields None.

    Args:
        val: String value to be parsed.

    Returns:
        Dictionary with parsed values, or None for an empty input.

    Raises:
        argparse.ArgumentTypeError: on a malformed item, an unparsable
            integer value, or conflicting values for a duplicated key.
    """
    if not val:
        return None

    parsed: Dict[str, int] = {}
    for entry in val.split(","):
        pieces = [piece.lower().strip() for piece in entry.split("=")]
        if len(pieces) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = pieces

        try:
            number = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Duplicated keys are tolerated only when they agree.
        if key in parsed and parsed[key] != number:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        parsed[key] = number

    return parsed
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Each field mirrors one CLI flag registered in ``add_cli_args``; the
    field's value is used as that flag's default.
    """
    # --- Model / tokenizer selection ---
    model: str = 'facebook/opt-125m'
    served_model_name: Optional[Union[str, List[str]]] = None
    tokenizer: Optional[str] = None  # falls back to `model` in __post_init__
    task: TaskOption = "auto"
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    allowed_local_media_path: str = ""
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    config_format: ConfigFormat = ConfigFormat.AUTO
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    seed: int = 0
    max_model_len: Optional[int] = None
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
    # number of P/D disaggregation (or other disaggregation) workers
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    # --- KV cache / scheduling ---
    block_size: Optional[int] = None
    enable_prefix_caching: Optional[bool] = None  # resolved in __post_init__
    disable_sliding_window: bool = False
    use_v2_block_manager: bool = True  # deprecated no-op per CLI help below
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: Optional[int] = None  # resolved in __post_init__
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    rope_scaling: Optional[Dict[str, Any]] = None
    rope_theta: Optional[float] = None
    hf_overrides: Optional[HfOverrides] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: Optional[bool] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
    # --- Multimodal ---
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
    disable_mm_preprocessor_cache: bool = False
    # --- LoRA / prompt adapters ---
    enable_lora: bool = False
    enable_lora_bias: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    long_lora_scaling_factors: Optional[Tuple[float]] = None
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    num_scheduler_steps: int = 1
    multi_step_stream_outputs: bool = True
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    model_loader_extra_config: Optional[dict] = None
    ignore_patterns: Optional[Union[str, List[str]]] = None
    preemption_mode: Optional[str] = None

    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: Optional[bool] = None

    guided_decoding_backend: str = 'xgrammar'
    logits_processor_pattern: Optional[str] = None
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    speculative_model_quantization: Optional[str] = None
    speculative_draft_tensor_parallel_size: Optional[int] = None
    num_speculative_tokens: Optional[int] = None
    speculative_disable_mqa_scorer: Optional[bool] = False
    speculative_max_model_len: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
    qlora_adapter_name_or_path: Optional[str] = None
    disable_logprobs_during_spec_decoding: Optional[bool] = None

    # --- Observability ---
    otlp_traces_endpoint: Optional[str] = None
    collect_detailed_traces: Optional[str] = None
    disable_async_output_proc: bool = False
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"

    # --- Overrides / advanced ---
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
    compilation_config: Optional[CompilationConfig] = None  # int/dict coerced in __post_init__
    worker_cls: str = "auto"

    kv_transfer_config: Optional[KVTransferConfig] = None

    generation_config: Optional[str] = None
    override_generation_config: Optional[Dict[str, Any]] = None
    enable_sleep_mode: bool = False
    model_impl: str = "auto"

    calculate_kv_scales: Optional[bool] = None
| 206 |
+
def __post_init__(self):
|
| 207 |
+
if not self.tokenizer:
|
| 208 |
+
self.tokenizer = self.model
|
| 209 |
+
|
| 210 |
+
# Override the default value of enable_prefix_caching if it's not set
|
| 211 |
+
# by user.
|
| 212 |
+
if self.enable_prefix_caching is None:
|
| 213 |
+
self.enable_prefix_caching = bool(envs.VLLM_USE_V1)
|
| 214 |
+
|
| 215 |
+
# Override max_num_seqs if it's not set by user.
|
| 216 |
+
if self.max_num_seqs is None:
|
| 217 |
+
self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024
|
| 218 |
+
|
| 219 |
+
# support `EngineArgs(compilation_config={...})`
|
| 220 |
+
# without having to manually construct a
|
| 221 |
+
# CompilationConfig object
|
| 222 |
+
if isinstance(self.compilation_config, (int, dict)):
|
| 223 |
+
self.compilation_config = CompilationConfig.from_cli(
|
| 224 |
+
str(self.compilation_config))
|
| 225 |
+
|
| 226 |
+
# Setup plugins
|
| 227 |
+
from vllm.plugins import load_general_plugins
|
| 228 |
+
load_general_plugins()
|
| 229 |
+
|
| 230 |
+
@staticmethod
|
| 231 |
+
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
| 232 |
+
"""Shared CLI arguments for vLLM engine."""
|
| 233 |
+
|
| 234 |
+
# Model arguments
|
| 235 |
+
parser.add_argument(
|
| 236 |
+
'--model',
|
| 237 |
+
type=str,
|
| 238 |
+
default=EngineArgs.model,
|
| 239 |
+
help='Name or path of the huggingface model to use.')
|
| 240 |
+
parser.add_argument(
|
| 241 |
+
'--task',
|
| 242 |
+
default=EngineArgs.task,
|
| 243 |
+
choices=get_args(TaskOption),
|
| 244 |
+
help='The task to use the model for. Each vLLM instance only '
|
| 245 |
+
'supports one task, even if the same model can be used for '
|
| 246 |
+
'multiple tasks. When the model only supports one task, ``"auto"`` '
|
| 247 |
+
'can be used to select it; otherwise, you must specify explicitly '
|
| 248 |
+
'which task to use.')
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
'--tokenizer',
|
| 251 |
+
type=nullable_str,
|
| 252 |
+
default=EngineArgs.tokenizer,
|
| 253 |
+
help='Name or path of the huggingface tokenizer to use. '
|
| 254 |
+
'If unspecified, model name or path will be used.')
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
'--skip-tokenizer-init',
|
| 257 |
+
action='store_true',
|
| 258 |
+
help='Skip initialization of tokenizer and detokenizer.')
|
| 259 |
+
parser.add_argument(
|
| 260 |
+
'--revision',
|
| 261 |
+
type=nullable_str,
|
| 262 |
+
default=None,
|
| 263 |
+
help='The specific model version to use. It can be a branch '
|
| 264 |
+
'name, a tag name, or a commit id. If unspecified, will use '
|
| 265 |
+
'the default version.')
|
| 266 |
+
parser.add_argument(
|
| 267 |
+
'--code-revision',
|
| 268 |
+
type=nullable_str,
|
| 269 |
+
default=None,
|
| 270 |
+
help='The specific revision to use for the model code on '
|
| 271 |
+
'Hugging Face Hub. It can be a branch name, a tag name, or a '
|
| 272 |
+
'commit id. If unspecified, will use the default version.')
|
| 273 |
+
parser.add_argument(
|
| 274 |
+
'--tokenizer-revision',
|
| 275 |
+
type=nullable_str,
|
| 276 |
+
default=None,
|
| 277 |
+
help='Revision of the huggingface tokenizer to use. '
|
| 278 |
+
'It can be a branch name, a tag name, or a commit id. '
|
| 279 |
+
'If unspecified, will use the default version.')
|
| 280 |
+
parser.add_argument(
|
| 281 |
+
'--tokenizer-mode',
|
| 282 |
+
type=str,
|
| 283 |
+
default=EngineArgs.tokenizer_mode,
|
| 284 |
+
choices=['auto', 'slow', 'mistral'],
|
| 285 |
+
help='The tokenizer mode.\n\n* "auto" will use the '
|
| 286 |
+
'fast tokenizer if available.\n* "slow" will '
|
| 287 |
+
'always use the slow tokenizer. \n* '
|
| 288 |
+
'"mistral" will always use the `mistral_common` tokenizer.')
|
| 289 |
+
parser.add_argument('--trust-remote-code',
|
| 290 |
+
action='store_true',
|
| 291 |
+
help='Trust remote code from huggingface.')
|
| 292 |
+
parser.add_argument(
|
| 293 |
+
'--allowed-local-media-path',
|
| 294 |
+
type=str,
|
| 295 |
+
help="Allowing API requests to read local images or videos "
|
| 296 |
+
"from directories specified by the server file system. "
|
| 297 |
+
"This is a security risk. "
|
| 298 |
+
"Should only be enabled in trusted environments.")
|
| 299 |
+
parser.add_argument('--download-dir',
|
| 300 |
+
type=nullable_str,
|
| 301 |
+
default=EngineArgs.download_dir,
|
| 302 |
+
help='Directory to download and load the weights, '
|
| 303 |
+
'default to the default cache dir of '
|
| 304 |
+
'huggingface.')
|
| 305 |
+
parser.add_argument(
|
| 306 |
+
'--load-format',
|
| 307 |
+
type=str,
|
| 308 |
+
default=EngineArgs.load_format,
|
| 309 |
+
choices=[f.value for f in LoadFormat],
|
| 310 |
+
help='The format of the model weights to load.\n\n'
|
| 311 |
+
'* "auto" will try to load the weights in the safetensors format '
|
| 312 |
+
'and fall back to the pytorch bin format if safetensors format '
|
| 313 |
+
'is not available.\n'
|
| 314 |
+
'* "pt" will load the weights in the pytorch bin format.\n'
|
| 315 |
+
'* "safetensors" will load the weights in the safetensors format.\n'
|
| 316 |
+
'* "npcache" will load the weights in pytorch format and store '
|
| 317 |
+
'a numpy cache to speed up the loading.\n'
|
| 318 |
+
'* "dummy" will initialize the weights with random values, '
|
| 319 |
+
'which is mainly for profiling.\n'
|
| 320 |
+
'* "tensorizer" will load the weights using tensorizer from '
|
| 321 |
+
'CoreWeave. See the Tensorize vLLM Model script in the Examples '
|
| 322 |
+
'section for more information.\n'
|
| 323 |
+
'* "runai_streamer" will load the Safetensors weights using Run:ai'
|
| 324 |
+
'Model Streamer \n'
|
| 325 |
+
'* "bitsandbytes" will load the weights using bitsandbytes '
|
| 326 |
+
'quantization.\n')
|
| 327 |
+
parser.add_argument(
|
| 328 |
+
'--config-format',
|
| 329 |
+
default=EngineArgs.config_format,
|
| 330 |
+
choices=[f.value for f in ConfigFormat],
|
| 331 |
+
help='The format of the model config to load.\n\n'
|
| 332 |
+
'* "auto" will try to load the config in hf format '
|
| 333 |
+
'if available else it will try to load in mistral format ')
|
| 334 |
+
parser.add_argument(
|
| 335 |
+
'--dtype',
|
| 336 |
+
type=str,
|
| 337 |
+
default=EngineArgs.dtype,
|
| 338 |
+
choices=[
|
| 339 |
+
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
|
| 340 |
+
],
|
| 341 |
+
help='Data type for model weights and activations.\n\n'
|
| 342 |
+
'* "auto" will use FP16 precision for FP32 and FP16 models, and '
|
| 343 |
+
'BF16 precision for BF16 models.\n'
|
| 344 |
+
'* "half" for FP16. Recommended for AWQ quantization.\n'
|
| 345 |
+
'* "float16" is the same as "half".\n'
|
| 346 |
+
'* "bfloat16" for a balance between precision and range.\n'
|
| 347 |
+
'* "float" is shorthand for FP32 precision.\n'
|
| 348 |
+
'* "float32" for FP32 precision.')
|
| 349 |
+
parser.add_argument(
|
| 350 |
+
'--kv-cache-dtype',
|
| 351 |
+
type=str,
|
| 352 |
+
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
|
| 353 |
+
default=EngineArgs.kv_cache_dtype,
|
| 354 |
+
help='Data type for kv cache storage. If "auto", will use model '
|
| 355 |
+
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
| 356 |
+
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
| 357 |
+
parser.add_argument('--max-model-len',
|
| 358 |
+
type=int,
|
| 359 |
+
default=EngineArgs.max_model_len,
|
| 360 |
+
help='Model context length. If unspecified, will '
|
| 361 |
+
'be automatically derived from the model config.')
|
| 362 |
+
parser.add_argument(
|
| 363 |
+
'--guided-decoding-backend',
|
| 364 |
+
type=str,
|
| 365 |
+
default='xgrammar',
|
| 366 |
+
choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
|
| 367 |
+
help='Which engine will be used for guided decoding'
|
| 368 |
+
' (JSON schema / regex etc) by default. Currently support '
|
| 369 |
+
'https://github.com/outlines-dev/outlines, '
|
| 370 |
+
'https://github.com/mlc-ai/xgrammar, and '
|
| 371 |
+
'https://github.com/noamgat/lm-format-enforcer.'
|
| 372 |
+
' Can be overridden per request via guided_decoding_backend'
|
| 373 |
+
' parameter.')
|
| 374 |
+
parser.add_argument(
|
| 375 |
+
'--logits-processor-pattern',
|
| 376 |
+
type=nullable_str,
|
| 377 |
+
default=None,
|
| 378 |
+
help='Optional regex pattern specifying valid logits processor '
|
| 379 |
+
'qualified names that can be passed with the `logits_processors` '
|
| 380 |
+
'extra completion argument. Defaults to None, which allows no '
|
| 381 |
+
'processors.')
|
| 382 |
+
parser.add_argument(
|
| 383 |
+
'--model-impl',
|
| 384 |
+
type=str,
|
| 385 |
+
default=EngineArgs.model_impl,
|
| 386 |
+
choices=[f.value for f in ModelImpl],
|
| 387 |
+
help='Which implementation of the model to use.\n\n'
|
| 388 |
+
'* "auto" will try to use the vLLM implementation if it exists '
|
| 389 |
+
'and fall back to the Transformers implementation if no vLLM '
|
| 390 |
+
'implementation is available.\n'
|
| 391 |
+
'* "vllm" will use the vLLM model implementation.\n'
|
| 392 |
+
'* "transformers" will use the Transformers model '
|
| 393 |
+
'implementation.\n')
|
| 394 |
+
# Parallel arguments
|
| 395 |
+
parser.add_argument(
|
| 396 |
+
'--distributed-executor-backend',
|
| 397 |
+
choices=['ray', 'mp', 'uni', 'external_launcher'],
|
| 398 |
+
default=EngineArgs.distributed_executor_backend,
|
| 399 |
+
help='Backend to use for distributed model '
|
| 400 |
+
'workers, either "ray" or "mp" (multiprocessing). If the product '
|
| 401 |
+
'of pipeline_parallel_size and tensor_parallel_size is less than '
|
| 402 |
+
'or equal to the number of GPUs available, "mp" will be used to '
|
| 403 |
+
'keep processing on a single host. Otherwise, this will default '
|
| 404 |
+
'to "ray" if Ray is installed and fail otherwise. Note that tpu '
|
| 405 |
+
'only supports Ray for distributed inference.')
|
| 406 |
+
|
| 407 |
+
parser.add_argument('--pipeline-parallel-size',
|
| 408 |
+
'-pp',
|
| 409 |
+
type=int,
|
| 410 |
+
default=EngineArgs.pipeline_parallel_size,
|
| 411 |
+
help='Number of pipeline stages.')
|
| 412 |
+
parser.add_argument('--tensor-parallel-size',
|
| 413 |
+
'-tp',
|
| 414 |
+
type=int,
|
| 415 |
+
default=EngineArgs.tensor_parallel_size,
|
| 416 |
+
help='Number of tensor parallel replicas.')
|
| 417 |
+
parser.add_argument(
|
| 418 |
+
'--max-parallel-loading-workers',
|
| 419 |
+
type=int,
|
| 420 |
+
default=EngineArgs.max_parallel_loading_workers,
|
| 421 |
+
help='Load model sequentially in multiple batches, '
|
| 422 |
+
'to avoid RAM OOM when using tensor '
|
| 423 |
+
'parallel and large models.')
|
| 424 |
+
parser.add_argument(
|
| 425 |
+
'--ray-workers-use-nsight',
|
| 426 |
+
action='store_true',
|
| 427 |
+
help='If specified, use nsight to profile Ray workers.')
|
| 428 |
+
# KV cache arguments
|
| 429 |
+
parser.add_argument('--block-size',
|
| 430 |
+
type=int,
|
| 431 |
+
default=EngineArgs.block_size,
|
| 432 |
+
choices=[8, 16, 32, 64, 128],
|
| 433 |
+
help='Token block size for contiguous chunks of '
|
| 434 |
+
'tokens. This is ignored on neuron devices and '
|
| 435 |
+
'set to ``--max-model-len``. On CUDA devices, '
|
| 436 |
+
'only block sizes up to 32 are supported. '
|
| 437 |
+
'On HPU devices, block size defaults to 128.')
|
| 438 |
+
|
| 439 |
+
parser.add_argument(
|
| 440 |
+
"--enable-prefix-caching",
|
| 441 |
+
action=argparse.BooleanOptionalAction,
|
| 442 |
+
default=EngineArgs.enable_prefix_caching,
|
| 443 |
+
help="Enables automatic prefix caching. "
|
| 444 |
+
"Use ``--no-enable-prefix-caching`` to disable explicitly.",
|
| 445 |
+
)
|
| 446 |
+
parser.add_argument('--disable-sliding-window',
|
| 447 |
+
action='store_true',
|
| 448 |
+
help='Disables sliding window, '
|
| 449 |
+
'capping to sliding window size.')
|
| 450 |
+
parser.add_argument('--use-v2-block-manager',
|
| 451 |
+
action='store_true',
|
| 452 |
+
default=True,
|
| 453 |
+
help='[DEPRECATED] block manager v1 has been '
|
| 454 |
+
'removed and SelfAttnBlockSpaceManager (i.e. '
|
| 455 |
+
'block manager v2) is now the default. '
|
| 456 |
+
'Setting this flag to True or False'
|
| 457 |
+
' has no effect on vLLM behavior.')
|
| 458 |
+
parser.add_argument(
|
| 459 |
+
'--num-lookahead-slots',
|
| 460 |
+
type=int,
|
| 461 |
+
default=EngineArgs.num_lookahead_slots,
|
| 462 |
+
help='Experimental scheduling config necessary for '
|
| 463 |
+
'speculative decoding. This will be replaced by '
|
| 464 |
+
'speculative config in the future; it is present '
|
| 465 |
+
'to enable correctness tests until then.')
|
| 466 |
+
|
| 467 |
+
parser.add_argument('--seed',
|
| 468 |
+
type=int,
|
| 469 |
+
default=EngineArgs.seed,
|
| 470 |
+
help='Random seed for operations.')
|
| 471 |
+
parser.add_argument('--swap-space',
|
| 472 |
+
type=float,
|
| 473 |
+
default=EngineArgs.swap_space,
|
| 474 |
+
help='CPU swap space size (GiB) per GPU.')
|
| 475 |
+
parser.add_argument(
|
| 476 |
+
'--cpu-offload-gb',
|
| 477 |
+
type=float,
|
| 478 |
+
default=0,
|
| 479 |
+
help='The space in GiB to offload to CPU, per GPU. '
|
| 480 |
+
'Default is 0, which means no offloading. Intuitively, '
|
| 481 |
+
'this argument can be seen as a virtual way to increase '
|
| 482 |
+
'the GPU memory size. For example, if you have one 24 GB '
|
| 483 |
+
'GPU and set this to 10, virtually you can think of it as '
|
| 484 |
+
'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
|
| 485 |
+
'which requires at least 26GB GPU memory. Note that this '
|
| 486 |
+
'requires fast CPU-GPU interconnect, as part of the model is '
|
| 487 |
+
'loaded from CPU memory to GPU memory on the fly in each '
|
| 488 |
+
'model forward pass.')
|
| 489 |
+
parser.add_argument(
|
| 490 |
+
'--gpu-memory-utilization',
|
| 491 |
+
type=float,
|
| 492 |
+
default=EngineArgs.gpu_memory_utilization,
|
| 493 |
+
help='The fraction of GPU memory to be used for the model '
|
| 494 |
+
'executor, which can range from 0 to 1. For example, a value of '
|
| 495 |
+
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
|
| 496 |
+
'will use the default value of 0.9. This is a per-instance '
|
| 497 |
+
'limit, and only applies to the current vLLM instance.'
|
| 498 |
+
'It does not matter if you have another vLLM instance running '
|
| 499 |
+
'on the same GPU. For example, if you have two vLLM instances '
|
| 500 |
+
'running on the same GPU, you can set the GPU memory utilization '
|
| 501 |
+
'to 0.5 for each instance.')
|
| 502 |
+
parser.add_argument(
|
| 503 |
+
'--num-gpu-blocks-override',
|
| 504 |
+
type=int,
|
| 505 |
+
default=None,
|
| 506 |
+
help='If specified, ignore GPU profiling result and use this number'
|
| 507 |
+
' of GPU blocks. Used for testing preemption.')
|
| 508 |
+
parser.add_argument('--max-num-batched-tokens',
|
| 509 |
+
type=int,
|
| 510 |
+
default=EngineArgs.max_num_batched_tokens,
|
| 511 |
+
help='Maximum number of batched tokens per '
|
| 512 |
+
'iteration.')
|
| 513 |
+
parser.add_argument('--max-num-seqs',
|
| 514 |
+
type=int,
|
| 515 |
+
default=EngineArgs.max_num_seqs,
|
| 516 |
+
help='Maximum number of sequences per iteration.')
|
| 517 |
+
parser.add_argument(
|
| 518 |
+
'--max-logprobs',
|
| 519 |
+
type=int,
|
| 520 |
+
default=EngineArgs.max_logprobs,
|
| 521 |
+
help=('Max number of log probs to return logprobs is specified in'
|
| 522 |
+
' SamplingParams.'))
|
| 523 |
+
parser.add_argument('--disable-log-stats',
|
| 524 |
+
action='store_true',
|
| 525 |
+
help='Disable logging statistics.')
|
| 526 |
+
# Quantization settings.
|
| 527 |
+
parser.add_argument('--quantization',
|
| 528 |
+
'-q',
|
| 529 |
+
type=nullable_str,
|
| 530 |
+
choices=[*QUANTIZATION_METHODS, None],
|
| 531 |
+
default=EngineArgs.quantization,
|
| 532 |
+
help='Method used to quantize the weights. If '
|
| 533 |
+
'None, we first check the `quantization_config` '
|
| 534 |
+
'attribute in the model config file. If that is '
|
| 535 |
+
'None, we assume the model weights are not '
|
| 536 |
+
'quantized and use `dtype` to determine the data '
|
| 537 |
+
'type of the weights.')
|
| 538 |
+
parser.add_argument(
|
| 539 |
+
'--rope-scaling',
|
| 540 |
+
default=None,
|
| 541 |
+
type=json.loads,
|
| 542 |
+
help='RoPE scaling configuration in JSON format. '
|
| 543 |
+
'For example, ``{"rope_type":"dynamic","factor":2.0}``')
|
| 544 |
+
parser.add_argument('--rope-theta',
|
| 545 |
+
default=None,
|
| 546 |
+
type=float,
|
| 547 |
+
help='RoPE theta. Use with `rope_scaling`. In '
|
| 548 |
+
'some cases, changing the RoPE theta improves the '
|
| 549 |
+
'performance of the scaled model.')
|
| 550 |
+
parser.add_argument('--hf-overrides',
|
| 551 |
+
type=json.loads,
|
| 552 |
+
default=EngineArgs.hf_overrides,
|
| 553 |
+
help='Extra arguments for the HuggingFace config. '
|
| 554 |
+
'This should be a JSON string that will be '
|
| 555 |
+
'parsed into a dictionary.')
|
| 556 |
+
parser.add_argument('--enforce-eager',
|
| 557 |
+
action='store_true',
|
| 558 |
+
help='Always use eager-mode PyTorch. If False, '
|
| 559 |
+
'will use eager mode and CUDA graph in hybrid '
|
| 560 |
+
'for maximal performance and flexibility.')
|
| 561 |
+
parser.add_argument('--max-seq-len-to-capture',
|
| 562 |
+
type=int,
|
| 563 |
+
default=EngineArgs.max_seq_len_to_capture,
|
| 564 |
+
help='Maximum sequence length covered by CUDA '
|
| 565 |
+
'graphs. When a sequence has context length '
|
| 566 |
+
'larger than this, we fall back to eager mode. '
|
| 567 |
+
'Additionally for encoder-decoder models, if the '
|
| 568 |
+
'sequence length of the encoder input is larger '
|
| 569 |
+
'than this, we fall back to the eager mode.')
|
| 570 |
+
parser.add_argument('--disable-custom-all-reduce',
|
| 571 |
+
action='store_true',
|
| 572 |
+
default=EngineArgs.disable_custom_all_reduce,
|
| 573 |
+
help='See ParallelConfig.')
|
| 574 |
+
parser.add_argument('--tokenizer-pool-size',
|
| 575 |
+
type=int,
|
| 576 |
+
default=EngineArgs.tokenizer_pool_size,
|
| 577 |
+
help='Size of tokenizer pool to use for '
|
| 578 |
+
'asynchronous tokenization. If 0, will '
|
| 579 |
+
'use synchronous tokenization.')
|
| 580 |
+
parser.add_argument('--tokenizer-pool-type',
|
| 581 |
+
type=str,
|
| 582 |
+
default=EngineArgs.tokenizer_pool_type,
|
| 583 |
+
help='Type of tokenizer pool to use for '
|
| 584 |
+
'asynchronous tokenization. Ignored '
|
| 585 |
+
'if tokenizer_pool_size is 0.')
|
| 586 |
+
parser.add_argument('--tokenizer-pool-extra-config',
|
| 587 |
+
type=nullable_str,
|
| 588 |
+
default=EngineArgs.tokenizer_pool_extra_config,
|
| 589 |
+
help='Extra config for tokenizer pool. '
|
| 590 |
+
'This should be a JSON string that will be '
|
| 591 |
+
'parsed into a dictionary. Ignored if '
|
| 592 |
+
'tokenizer_pool_size is 0.')
|
| 593 |
+
|
| 594 |
+
# Multimodal related configs
|
| 595 |
+
parser.add_argument(
|
| 596 |
+
'--limit-mm-per-prompt',
|
| 597 |
+
type=nullable_kvs,
|
| 598 |
+
default=EngineArgs.limit_mm_per_prompt,
|
| 599 |
+
# The default value is given in
|
| 600 |
+
# MultiModalRegistry.init_mm_limits_per_prompt
|
| 601 |
+
help=('For each multimodal plugin, limit how many '
|
| 602 |
+
'input instances to allow for each prompt. '
|
| 603 |
+
'Expects a comma-separated list of items, '
|
| 604 |
+
'e.g.: `image=16,video=2` allows a maximum of 16 '
|
| 605 |
+
'images and 2 videos per prompt. Defaults to 1 for '
|
| 606 |
+
'each modality.'))
|
| 607 |
+
parser.add_argument(
|
| 608 |
+
'--mm-processor-kwargs',
|
| 609 |
+
default=None,
|
| 610 |
+
type=json.loads,
|
| 611 |
+
help=('Overrides for the multimodal input mapping/processing, '
|
| 612 |
+
'e.g., image processor. For example: ``{"num_crops": 4}``.'))
|
| 613 |
+
parser.add_argument(
|
| 614 |
+
'--disable-mm-preprocessor-cache',
|
| 615 |
+
action='store_true',
|
| 616 |
+
help='If true, then disables caching of the multi-modal '
|
| 617 |
+
'preprocessor/mapper. (not recommended)')
|
| 618 |
+
|
| 619 |
+
# LoRA related configs
|
| 620 |
+
parser.add_argument('--enable-lora',
|
| 621 |
+
action='store_true',
|
| 622 |
+
help='If True, enable handling of LoRA adapters.')
|
| 623 |
+
parser.add_argument('--enable-lora-bias',
|
| 624 |
+
action='store_true',
|
| 625 |
+
help='If True, enable bias for LoRA adapters.')
|
| 626 |
+
parser.add_argument('--max-loras',
|
| 627 |
+
type=int,
|
| 628 |
+
default=EngineArgs.max_loras,
|
| 629 |
+
help='Max number of LoRAs in a single batch.')
|
| 630 |
+
parser.add_argument('--max-lora-rank',
|
| 631 |
+
type=int,
|
| 632 |
+
default=EngineArgs.max_lora_rank,
|
| 633 |
+
help='Max LoRA rank.')
|
| 634 |
+
parser.add_argument(
|
| 635 |
+
'--lora-extra-vocab-size',
|
| 636 |
+
type=int,
|
| 637 |
+
default=EngineArgs.lora_extra_vocab_size,
|
| 638 |
+
help=('Maximum size of extra vocabulary that can be '
|
| 639 |
+
'present in a LoRA adapter (added to the base '
|
| 640 |
+
'model vocabulary).'))
|
| 641 |
+
parser.add_argument(
|
| 642 |
+
'--lora-dtype',
|
| 643 |
+
type=str,
|
| 644 |
+
default=EngineArgs.lora_dtype,
|
| 645 |
+
choices=['auto', 'float16', 'bfloat16'],
|
| 646 |
+
help=('Data type for LoRA. If auto, will default to '
|
| 647 |
+
'base model dtype.'))
|
| 648 |
+
parser.add_argument(
|
| 649 |
+
'--long-lora-scaling-factors',
|
| 650 |
+
type=nullable_str,
|
| 651 |
+
default=EngineArgs.long_lora_scaling_factors,
|
| 652 |
+
help=('Specify multiple scaling factors (which can '
|
| 653 |
+
'be different from base model scaling factor '
|
| 654 |
+
'- see eg. Long LoRA) to allow for multiple '
|
| 655 |
+
'LoRA adapters trained with those scaling '
|
| 656 |
+
'factors to be used at the same time. If not '
|
| 657 |
+
'specified, only adapters trained with the '
|
| 658 |
+
'base model scaling factor are allowed.'))
|
| 659 |
+
parser.add_argument(
|
| 660 |
+
'--max-cpu-loras',
|
| 661 |
+
type=int,
|
| 662 |
+
default=EngineArgs.max_cpu_loras,
|
| 663 |
+
help=('Maximum number of LoRAs to store in CPU memory. '
|
| 664 |
+
'Must be >= than max_loras. '
|
| 665 |
+
'Defaults to max_loras.'))
|
| 666 |
+
parser.add_argument(
|
| 667 |
+
'--fully-sharded-loras',
|
| 668 |
+
action='store_true',
|
| 669 |
+
help=('By default, only half of the LoRA computation is '
|
| 670 |
+
'sharded with tensor parallelism. '
|
| 671 |
+
'Enabling this will use the fully sharded layers. '
|
| 672 |
+
'At high sequence length, max rank or '
|
| 673 |
+
'tensor parallel size, this is likely faster.'))
|
| 674 |
+
parser.add_argument('--enable-prompt-adapter',
|
| 675 |
+
action='store_true',
|
| 676 |
+
help='If True, enable handling of PromptAdapters.')
|
| 677 |
+
parser.add_argument('--max-prompt-adapters',
|
| 678 |
+
type=int,
|
| 679 |
+
default=EngineArgs.max_prompt_adapters,
|
| 680 |
+
help='Max number of PromptAdapters in a batch.')
|
| 681 |
+
parser.add_argument('--max-prompt-adapter-token',
|
| 682 |
+
type=int,
|
| 683 |
+
default=EngineArgs.max_prompt_adapter_token,
|
| 684 |
+
help='Max number of PromptAdapters tokens')
|
| 685 |
+
parser.add_argument("--device",
|
| 686 |
+
type=str,
|
| 687 |
+
default=EngineArgs.device,
|
| 688 |
+
choices=DEVICE_OPTIONS,
|
| 689 |
+
help='Device type for vLLM execution.')
|
| 690 |
+
parser.add_argument('--num-scheduler-steps',
|
| 691 |
+
type=int,
|
| 692 |
+
default=1,
|
| 693 |
+
help=('Maximum number of forward steps per '
|
| 694 |
+
'scheduler call.'))
|
| 695 |
+
|
| 696 |
+
parser.add_argument(
|
| 697 |
+
'--multi-step-stream-outputs',
|
| 698 |
+
action=StoreBoolean,
|
| 699 |
+
default=EngineArgs.multi_step_stream_outputs,
|
| 700 |
+
nargs="?",
|
| 701 |
+
const="True",
|
| 702 |
+
help='If False, then multi-step will stream outputs at the end '
|
| 703 |
+
'of all steps')
|
| 704 |
+
parser.add_argument(
|
| 705 |
+
'--scheduler-delay-factor',
|
| 706 |
+
type=float,
|
| 707 |
+
default=EngineArgs.scheduler_delay_factor,
|
| 708 |
+
help='Apply a delay (of delay factor multiplied by previous '
|
| 709 |
+
'prompt latency) before scheduling next prompt.')
|
| 710 |
+
parser.add_argument(
|
| 711 |
+
'--enable-chunked-prefill',
|
| 712 |
+
action=StoreBoolean,
|
| 713 |
+
default=EngineArgs.enable_chunked_prefill,
|
| 714 |
+
nargs="?",
|
| 715 |
+
const="True",
|
| 716 |
+
help='If set, the prefill requests can be chunked based on the '
|
| 717 |
+
'max_num_batched_tokens.')
|
| 718 |
+
|
| 719 |
+
parser.add_argument(
|
| 720 |
+
'--speculative-model',
|
| 721 |
+
type=nullable_str,
|
| 722 |
+
default=EngineArgs.speculative_model,
|
| 723 |
+
help=
|
| 724 |
+
'The name of the draft model to be used in speculative decoding.')
|
| 725 |
+
# Quantization settings for speculative model.
|
| 726 |
+
parser.add_argument(
|
| 727 |
+
'--speculative-model-quantization',
|
| 728 |
+
type=nullable_str,
|
| 729 |
+
choices=[*QUANTIZATION_METHODS, None],
|
| 730 |
+
default=EngineArgs.speculative_model_quantization,
|
| 731 |
+
help='Method used to quantize the weights of speculative model. '
|
| 732 |
+
'If None, we first check the `quantization_config` '
|
| 733 |
+
'attribute in the model config file. If that is '
|
| 734 |
+
'None, we assume the model weights are not '
|
| 735 |
+
'quantized and use `dtype` to determine the data '
|
| 736 |
+
'type of the weights.')
|
| 737 |
+
parser.add_argument(
|
| 738 |
+
'--num-speculative-tokens',
|
| 739 |
+
type=int,
|
| 740 |
+
default=EngineArgs.num_speculative_tokens,
|
| 741 |
+
help='The number of speculative tokens to sample from '
|
| 742 |
+
'the draft model in speculative decoding.')
|
| 743 |
+
parser.add_argument(
|
| 744 |
+
'--speculative-disable-mqa-scorer',
|
| 745 |
+
action='store_true',
|
| 746 |
+
help=
|
| 747 |
+
'If set to True, the MQA scorer will be disabled in speculative '
|
| 748 |
+
' and fall back to batch expansion')
|
| 749 |
+
parser.add_argument(
|
| 750 |
+
'--speculative-draft-tensor-parallel-size',
|
| 751 |
+
'-spec-draft-tp',
|
| 752 |
+
type=int,
|
| 753 |
+
default=EngineArgs.speculative_draft_tensor_parallel_size,
|
| 754 |
+
help='Number of tensor parallel replicas for '
|
| 755 |
+
'the draft model in speculative decoding.')
|
| 756 |
+
|
| 757 |
+
parser.add_argument(
|
| 758 |
+
'--speculative-max-model-len',
|
| 759 |
+
type=int,
|
| 760 |
+
default=EngineArgs.speculative_max_model_len,
|
| 761 |
+
help='The maximum sequence length supported by the '
|
| 762 |
+
'draft model. Sequences over this length will skip '
|
| 763 |
+
'speculation.')
|
| 764 |
+
|
| 765 |
+
parser.add_argument(
|
| 766 |
+
'--speculative-disable-by-batch-size',
|
| 767 |
+
type=int,
|
| 768 |
+
default=EngineArgs.speculative_disable_by_batch_size,
|
| 769 |
+
help='Disable speculative decoding for new incoming requests '
|
| 770 |
+
'if the number of enqueue requests is larger than this value.')
|
| 771 |
+
|
| 772 |
+
parser.add_argument(
|
| 773 |
+
'--ngram-prompt-lookup-max',
|
| 774 |
+
type=int,
|
| 775 |
+
default=EngineArgs.ngram_prompt_lookup_max,
|
| 776 |
+
help='Max size of window for ngram prompt lookup in speculative '
|
| 777 |
+
'decoding.')
|
| 778 |
+
|
| 779 |
+
parser.add_argument(
|
| 780 |
+
'--ngram-prompt-lookup-min',
|
| 781 |
+
type=int,
|
| 782 |
+
default=EngineArgs.ngram_prompt_lookup_min,
|
| 783 |
+
help='Min size of window for ngram prompt lookup in speculative '
|
| 784 |
+
'decoding.')
|
| 785 |
+
|
| 786 |
+
parser.add_argument(
|
| 787 |
+
'--spec-decoding-acceptance-method',
|
| 788 |
+
type=str,
|
| 789 |
+
default=EngineArgs.spec_decoding_acceptance_method,
|
| 790 |
+
choices=['rejection_sampler', 'typical_acceptance_sampler'],
|
| 791 |
+
help='Specify the acceptance method to use during draft token '
|
| 792 |
+
'verification in speculative decoding. Two types of acceptance '
|
| 793 |
+
'routines are supported: '
|
| 794 |
+
'1) RejectionSampler which does not allow changing the '
|
| 795 |
+
'acceptance rate of draft tokens, '
|
| 796 |
+
'2) TypicalAcceptanceSampler which is configurable, allowing for '
|
| 797 |
+
'a higher acceptance rate at the cost of lower quality, '
|
| 798 |
+
'and vice versa.')
|
| 799 |
+
|
| 800 |
+
parser.add_argument(
|
| 801 |
+
'--typical-acceptance-sampler-posterior-threshold',
|
| 802 |
+
type=float,
|
| 803 |
+
default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
|
| 804 |
+
help='Set the lower bound threshold for the posterior '
|
| 805 |
+
'probability of a token to be accepted. This threshold is '
|
| 806 |
+
'used by the TypicalAcceptanceSampler to make sampling decisions '
|
| 807 |
+
'during speculative decoding. Defaults to 0.09')
|
| 808 |
+
|
| 809 |
+
parser.add_argument(
|
| 810 |
+
'--typical-acceptance-sampler-posterior-alpha',
|
| 811 |
+
type=float,
|
| 812 |
+
default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
|
| 813 |
+
help='A scaling factor for the entropy-based threshold for token '
|
| 814 |
+
'acceptance in the TypicalAcceptanceSampler. Typically defaults '
|
| 815 |
+
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
|
| 816 |
+
'i.e. 0.3')
|
| 817 |
+
|
| 818 |
+
parser.add_argument(
|
| 819 |
+
'--disable-logprobs-during-spec-decoding',
|
| 820 |
+
action=StoreBoolean,
|
| 821 |
+
default=EngineArgs.disable_logprobs_during_spec_decoding,
|
| 822 |
+
nargs="?",
|
| 823 |
+
const="True",
|
| 824 |
+
help='If set to True, token log probabilities are not returned '
|
| 825 |
+
'during speculative decoding. If set to False, log probabilities '
|
| 826 |
+
'are returned according to the settings in SamplingParams. If '
|
| 827 |
+
'not specified, it defaults to True. Disabling log probabilities '
|
| 828 |
+
'during speculative decoding reduces latency by skipping logprob '
|
| 829 |
+
'calculation in proposal sampling, target sampling, and after '
|
| 830 |
+
'accepted tokens are determined.')
|
| 831 |
+
|
| 832 |
+
parser.add_argument('--model-loader-extra-config',
|
| 833 |
+
type=nullable_str,
|
| 834 |
+
default=EngineArgs.model_loader_extra_config,
|
| 835 |
+
help='Extra config for model loader. '
|
| 836 |
+
'This will be passed to the model loader '
|
| 837 |
+
'corresponding to the chosen load_format. '
|
| 838 |
+
'This should be a JSON string that will be '
|
| 839 |
+
'parsed into a dictionary.')
|
| 840 |
+
parser.add_argument(
|
| 841 |
+
'--ignore-patterns',
|
| 842 |
+
action="append",
|
| 843 |
+
type=str,
|
| 844 |
+
default=[],
|
| 845 |
+
help="The pattern(s) to ignore when loading the model."
|
| 846 |
+
"Default to `original/**/*` to avoid repeated loading of llama's "
|
| 847 |
+
"checkpoints.")
|
| 848 |
+
parser.add_argument(
|
| 849 |
+
'--preemption-mode',
|
| 850 |
+
type=str,
|
| 851 |
+
default=None,
|
| 852 |
+
help='If \'recompute\', the engine performs preemption by '
|
| 853 |
+
'recomputing; If \'swap\', the engine performs preemption by '
|
| 854 |
+
'block swapping.')
|
| 855 |
+
|
| 856 |
+
parser.add_argument(
|
| 857 |
+
"--served-model-name",
|
| 858 |
+
nargs="+",
|
| 859 |
+
type=str,
|
| 860 |
+
default=None,
|
| 861 |
+
help="The model name(s) used in the API. If multiple "
|
| 862 |
+
"names are provided, the server will respond to any "
|
| 863 |
+
"of the provided names. The model name in the model "
|
| 864 |
+
"field of a response will be the first name in this "
|
| 865 |
+
"list. If not specified, the model name will be the "
|
| 866 |
+
"same as the ``--model`` argument. Noted that this name(s) "
|
| 867 |
+
"will also be used in `model_name` tag content of "
|
| 868 |
+
"prometheus metrics, if multiple names provided, metrics "
|
| 869 |
+
"tag will take the first one.")
|
| 870 |
+
parser.add_argument('--qlora-adapter-name-or-path',
|
| 871 |
+
type=str,
|
| 872 |
+
default=None,
|
| 873 |
+
help='Name or path of the QLoRA adapter.')
|
| 874 |
+
|
| 875 |
+
parser.add_argument(
|
| 876 |
+
'--otlp-traces-endpoint',
|
| 877 |
+
type=str,
|
| 878 |
+
default=None,
|
| 879 |
+
help='Target URL to which OpenTelemetry traces will be sent.')
|
| 880 |
+
parser.add_argument(
|
| 881 |
+
'--collect-detailed-traces',
|
| 882 |
+
type=str,
|
| 883 |
+
default=None,
|
| 884 |
+
help="Valid choices are " +
|
| 885 |
+
",".join(ALLOWED_DETAILED_TRACE_MODULES) +
|
| 886 |
+
". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
|
| 887 |
+
" set. If set, it will collect detailed traces for the specified "
|
| 888 |
+
"modules. This involves use of possibly costly and or blocking "
|
| 889 |
+
"operations and hence might have a performance impact.")
|
| 890 |
+
|
| 891 |
+
parser.add_argument(
|
| 892 |
+
'--disable-async-output-proc',
|
| 893 |
+
action='store_true',
|
| 894 |
+
default=EngineArgs.disable_async_output_proc,
|
| 895 |
+
help="Disable async output processing. This may result in "
|
| 896 |
+
"lower performance.")
|
| 897 |
+
|
| 898 |
+
parser.add_argument(
|
| 899 |
+
'--scheduling-policy',
|
| 900 |
+
choices=['fcfs', 'priority'],
|
| 901 |
+
default="fcfs",
|
| 902 |
+
help='The scheduling policy to use. "fcfs" (first come first served'
|
| 903 |
+
', i.e. requests are handled in order of arrival; default) '
|
| 904 |
+
'or "priority" (requests are handled based on given '
|
| 905 |
+
'priority (lower value means earlier handling) and time of '
|
| 906 |
+
'arrival deciding any ties).')
|
| 907 |
+
|
| 908 |
+
parser.add_argument(
|
| 909 |
+
'--override-neuron-config',
|
| 910 |
+
type=json.loads,
|
| 911 |
+
default=None,
|
| 912 |
+
help="Override or set neuron device configuration. "
|
| 913 |
+
"e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
|
| 914 |
+
parser.add_argument(
|
| 915 |
+
'--override-pooler-config',
|
| 916 |
+
type=PoolerConfig.from_json,
|
| 917 |
+
default=None,
|
| 918 |
+
help="Override or set the pooling method for pooling models. "
|
| 919 |
+
"e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
|
| 920 |
+
|
| 921 |
+
parser.add_argument('--compilation-config',
|
| 922 |
+
'-O',
|
| 923 |
+
type=CompilationConfig.from_cli,
|
| 924 |
+
default=None,
|
| 925 |
+
help='torch.compile configuration for the model.'
|
| 926 |
+
'When it is a number (0, 1, 2, 3), it will be '
|
| 927 |
+
'interpreted as the optimization level.\n'
|
| 928 |
+
'NOTE: level 0 is the default level without '
|
| 929 |
+
'any optimization. level 1 and 2 are for internal '
|
| 930 |
+
'testing only. level 3 is the recommended level '
|
| 931 |
+
'for production.\n'
|
| 932 |
+
'To specify the full compilation config, '
|
| 933 |
+
'use a JSON string.\n'
|
| 934 |
+
'Following the convention of traditional '
|
| 935 |
+
'compilers, using -O without space is also '
|
| 936 |
+
'supported. -O3 is equivalent to -O 3.')
|
| 937 |
+
|
| 938 |
+
parser.add_argument('--kv-transfer-config',
|
| 939 |
+
type=KVTransferConfig.from_cli,
|
| 940 |
+
default=None,
|
| 941 |
+
help='The configurations for distributed KV cache '
|
| 942 |
+
'transfer. Should be a JSON string.')
|
| 943 |
+
|
| 944 |
+
parser.add_argument(
|
| 945 |
+
'--worker-cls',
|
| 946 |
+
type=str,
|
| 947 |
+
default="auto",
|
| 948 |
+
help='The worker class to use for distributed execution.')
|
| 949 |
+
parser.add_argument(
|
| 950 |
+
"--generation-config",
|
| 951 |
+
type=nullable_str,
|
| 952 |
+
default=None,
|
| 953 |
+
help="The folder path to the generation config. "
|
| 954 |
+
"Defaults to None, no generation config is loaded, vLLM defaults "
|
| 955 |
+
"will be used. If set to 'auto', the generation config will be "
|
| 956 |
+
"loaded from model path. If set to a folder path, the generation "
|
| 957 |
+
"config will be loaded from the specified folder path. If "
|
| 958 |
+
"`max_new_tokens` is specified in generation config, then "
|
| 959 |
+
"it sets a server-wide limit on the number of output tokens "
|
| 960 |
+
"for all requests.")
|
| 961 |
+
|
| 962 |
+
parser.add_argument(
|
| 963 |
+
"--override-generation-config",
|
| 964 |
+
type=json.loads,
|
| 965 |
+
default=None,
|
| 966 |
+
help="Overrides or sets generation config in JSON format. "
|
| 967 |
+
"e.g. ``{\"temperature\": 0.5}``. If used with "
|
| 968 |
+
"--generation-config=auto, the override parameters will be merged "
|
| 969 |
+
"with the default config from the model. If generation-config is "
|
| 970 |
+
"None, only the override parameters are used.")
|
| 971 |
+
|
| 972 |
+
parser.add_argument("--enable-sleep-mode",
|
| 973 |
+
action="store_true",
|
| 974 |
+
default=False,
|
| 975 |
+
help="Enable sleep mode for the engine. "
|
| 976 |
+
"(only cuda platform is supported)")
|
| 977 |
+
|
| 978 |
+
parser.add_argument(
|
| 979 |
+
'--calculate-kv-scales',
|
| 980 |
+
action='store_true',
|
| 981 |
+
help='This enables dynamic calculation of '
|
| 982 |
+
'k_scale and v_scale when kv-cache-dtype is fp8. '
|
| 983 |
+
'If calculate-kv-scales is false, the scales will '
|
| 984 |
+
'be loaded from the model checkpoint if available. '
|
| 985 |
+
'Otherwise, the scales will default to 1.0.')
|
| 986 |
+
|
| 987 |
+
return parser
|
| 988 |
+
|
| 989 |
+
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
    """Construct an instance of this dataclass from parsed CLI args.

    For every field declared on ``cls``, the attribute of the same name
    is read from ``args`` and forwarded to the constructor.
    """
    field_names = (field.name for field in dataclasses.fields(cls))
    kwargs = {name: getattr(args, name) for name in field_names}
    return cls(**kwargs)
|
| 996 |
+
|
| 997 |
+
def create_model_config(self) -> ModelConfig:
    """Build the ModelConfig corresponding to these engine arguments."""
    # __post_init__ guarantees the tokenizer is set, hence the cast.
    tokenizer = cast(str, self.tokenizer)
    model_kwargs = dict(
        model=self.model,
        task=self.task,
        tokenizer=tokenizer,
        tokenizer_mode=self.tokenizer_mode,
        trust_remote_code=self.trust_remote_code,
        allowed_local_media_path=self.allowed_local_media_path,
        dtype=self.dtype,
        seed=self.seed,
        revision=self.revision,
        code_revision=self.code_revision,
        rope_scaling=self.rope_scaling,
        rope_theta=self.rope_theta,
        hf_overrides=self.hf_overrides,
        tokenizer_revision=self.tokenizer_revision,
        max_model_len=self.max_model_len,
        quantization=self.quantization,
        enforce_eager=self.enforce_eager,
        max_seq_len_to_capture=self.max_seq_len_to_capture,
        max_logprobs=self.max_logprobs,
        disable_sliding_window=self.disable_sliding_window,
        skip_tokenizer_init=self.skip_tokenizer_init,
        served_model_name=self.served_model_name,
        limit_mm_per_prompt=self.limit_mm_per_prompt,
        # Async output processing is exposed to users as a "disable" flag.
        use_async_output_proc=not self.disable_async_output_proc,
        config_format=self.config_format,
        mm_processor_kwargs=self.mm_processor_kwargs,
        disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
        override_neuron_config=self.override_neuron_config,
        override_pooler_config=self.override_pooler_config,
        logits_processor_pattern=self.logits_processor_pattern,
        generation_config=self.generation_config,
        override_generation_config=self.override_generation_config,
        enable_sleep_mode=self.enable_sleep_mode,
        model_impl=self.model_impl,
    )
    return ModelConfig(**model_kwargs)
|
| 1035 |
+
|
| 1036 |
+
def create_load_config(self) -> LoadConfig:
    """Build the LoadConfig describing how model weights are loaded."""
    load_kwargs = {
        "load_format": self.load_format,
        "download_dir": self.download_dir,
        "model_loader_extra_config": self.model_loader_extra_config,
        "ignore_patterns": self.ignore_patterns,
    }
    return LoadConfig(**load_kwargs)
|
| 1043 |
+
|
| 1044 |
+
def create_engine_config(self,
                         usage_context: Optional[UsageContext] = None
                         ) -> VllmConfig:
    """Turn the parsed engine arguments into a complete VllmConfig.

    Builds every sub-config (model, cache, parallel, scheduler, ...),
    cross-validates mutually exclusive options, and may mutate ``self``
    along the way (e.g. disabling prefix caching or enabling chunked
    prefill) before assembling the final config object.

    Raises:
        ValueError: on inconsistent quantization/load-format combos,
            unsupported multi-step combinations, chunked prefill with
            pooling models, or invalid trace-module names.
    """
    # V1 applies its own argument overrides before anything else.
    if envs.VLLM_USE_V1:
        self._override_v1_engine_args(usage_context)

    # gguf file needs a specific model loader and doesn't use hf_repo
    if check_gguf_file(self.model):
        self.quantization = self.load_format = "gguf"

    # bitsandbytes quantization needs a specific model loader
    # so we make sure the quant method and the load format are consistent
    if (self.quantization == "bitsandbytes" or
       self.qlora_adapter_name_or_path is not None) and \
        self.load_format != "bitsandbytes":
        raise ValueError(
            "BitsAndBytes quantization and QLoRA adapter only support "
            f"'bitsandbytes' load format, but got {self.load_format}")

    # ...and the converse: the bitsandbytes loader requires the
    # bitsandbytes quantization method.
    if (self.load_format == "bitsandbytes" or
       self.qlora_adapter_name_or_path is not None) and \
        self.quantization != "bitsandbytes":
        raise ValueError(
            "BitsAndBytes load format and QLoRA adapter only support "
            f"'bitsandbytes' quantization, but got {self.quantization}")

    assert self.cpu_offload_gb >= 0, (
        "CPU offload space must be non-negative"
        f", but got {self.cpu_offload_gb}")

    device_config = DeviceConfig(device=self.device)
    model_config = self.create_model_config()

    # Prefix caching with multimodal models is only supported on V1;
    # on v0 it is silently disabled (with a warning).
    if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
            and self.enable_prefix_caching):
        logger.warning("--enable-prefix-caching is currently not "
                       "supported for multimodal models in v0 and "
                       "has been disabled.")
        self.enable_prefix_caching = False

    cache_config = CacheConfig(
        block_size=self.block_size,
        gpu_memory_utilization=self.gpu_memory_utilization,
        swap_space=self.swap_space,
        cache_dtype=self.kv_cache_dtype,
        is_attention_free=model_config.is_attention_free,
        num_gpu_blocks_override=self.num_gpu_blocks_override,
        sliding_window=model_config.get_sliding_window(),
        enable_prefix_caching=self.enable_prefix_caching,
        cpu_offload_gb=self.cpu_offload_gb,
        calculate_kv_scales=self.calculate_kv_scales,
    )
    parallel_config = ParallelConfig(
        pipeline_parallel_size=self.pipeline_parallel_size,
        tensor_parallel_size=self.tensor_parallel_size,
        max_parallel_loading_workers=self.max_parallel_loading_workers,
        disable_custom_all_reduce=self.disable_custom_all_reduce,
        tokenizer_pool_config=TokenizerPoolConfig.create_config(
            self.tokenizer_pool_size,
            self.tokenizer_pool_type,
            self.tokenizer_pool_extra_config,
        ),
        ray_workers_use_nsight=self.ray_workers_use_nsight,
        distributed_executor_backend=self.distributed_executor_backend,
        worker_cls=self.worker_cls,
    )

    max_model_len = model_config.max_model_len
    use_long_context = max_model_len > 32768
    if self.enable_chunked_prefill is None:
        # If not explicitly set, enable chunked prefill by default for
        # long context (> 32K) models. This is to avoid OOM errors in the
        # initial memory profiling phase.

        # For multimodal models, chunked prefill is disabled by default in
        # V0, but enabled by design in V1
        if model_config.is_multimodal_model:
            self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)

        elif use_long_context:
            is_gpu = device_config.device_type == "cuda"
            use_sliding_window = (model_config.get_sliding_window()
                                  is not None)
            use_spec_decode = self.speculative_model is not None
            from vllm.platforms import current_platform
            # Only auto-enable on CUDA GPUs when no feature that is
            # incompatible with chunked prefill is in use.
            if (is_gpu and not use_sliding_window and not use_spec_decode
                    and not self.enable_lora
                    and not self.enable_prompt_adapter
                    and model_config.runner_type != "pooling"
                    and not current_platform.is_rocm()):
                self.enable_chunked_prefill = True
                logger.warning(
                    "Chunked prefill is enabled by default for models with "
                    "max_model_len > 32K. Currently, chunked prefill might "
                    "not work with some features or models. If you "
                    "encounter any issues, please disable chunked prefill "
                    "by setting --enable-chunked-prefill=False.")
        # Still undecided after the branches above (short context, or
        # excluded by one of the feature checks): default to disabled.
        if self.enable_chunked_prefill is None:
            self.enable_chunked_prefill = False

    if not self.enable_chunked_prefill and use_long_context:
        logger.warning(
            "The model has a long context length (%s). This may cause OOM "
            "errors during the initial memory profiling phase, or result "
            "in low performance due to small KV cache space. Consider "
            "setting --max-model-len to a smaller value.", max_model_len)
    elif (self.enable_chunked_prefill
          and model_config.runner_type == "pooling"):
        msg = "Chunked prefill is not supported for pooling models"
        raise ValueError(msg)

    # May return None when no speculative decoding args are set.
    speculative_config = SpeculativeConfig.maybe_create_spec_config(
        target_model_config=model_config,
        target_parallel_config=parallel_config,
        target_dtype=self.dtype,
        speculative_model=self.speculative_model,
        speculative_model_quantization = \
            self.speculative_model_quantization,
        speculative_draft_tensor_parallel_size = \
            self.speculative_draft_tensor_parallel_size,
        num_speculative_tokens=self.num_speculative_tokens,
        speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
        speculative_disable_by_batch_size=self.
        speculative_disable_by_batch_size,
        speculative_max_model_len=self.speculative_max_model_len,
        enable_chunked_prefill=self.enable_chunked_prefill,
        disable_log_stats=self.disable_log_stats,
        ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
        draft_token_acceptance_method=\
            self.spec_decoding_acceptance_method,
        typical_acceptance_sampler_posterior_threshold=self.
        typical_acceptance_sampler_posterior_threshold,
        typical_acceptance_sampler_posterior_alpha=self.
        typical_acceptance_sampler_posterior_alpha,
        disable_logprobs=self.disable_logprobs_during_spec_decoding,
    )

    # Reminder: Please update docs/source/features/compatibility_matrix.md
    # If the feature combo become valid
    if self.num_scheduler_steps > 1:
        if speculative_config is not None:
            raise ValueError("Speculative decoding is not supported with "
                             "multi-step (--num-scheduler-steps > 1)")
        if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
            raise ValueError("Multi-Step Chunked-Prefill is not supported "
                             "for pipeline-parallel-size > 1")
        from vllm.platforms import current_platform
        if current_platform.is_cpu():
            logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                           "currently not supported for CPUs and has been "
                           "disabled.")
            self.num_scheduler_steps = 1

    # make sure num_lookahead_slots is set the higher value depending on
    # if we are using speculative decoding or multi-step
    num_lookahead_slots = max(self.num_lookahead_slots,
                              self.num_scheduler_steps - 1)
    num_lookahead_slots = num_lookahead_slots \
        if speculative_config is None \
        else speculative_config.num_lookahead_slots

    if not self.use_v2_block_manager:
        logger.warning(
            "[DEPRECATED] Block manager v1 has been removed, "
            "and setting --use-v2-block-manager to True or False has "
            "no effect on vLLM behavior. Please remove "
            "--use-v2-block-manager in your engine argument. "
            "If your use case is not supported by "
            "SelfAttnBlockSpaceManager (i.e. block manager v2),"
            " please file an issue with detailed information.")

    scheduler_config = SchedulerConfig(
        runner_type=model_config.runner_type,
        max_num_batched_tokens=self.max_num_batched_tokens,
        max_num_seqs=self.max_num_seqs,
        max_model_len=model_config.max_model_len,
        num_lookahead_slots=num_lookahead_slots,
        delay_factor=self.scheduler_delay_factor,
        enable_chunked_prefill=self.enable_chunked_prefill,
        is_multimodal_model=model_config.is_multimodal_model,
        preemption_mode=self.preemption_mode,
        num_scheduler_steps=self.num_scheduler_steps,
        multi_step_stream_outputs=self.multi_step_stream_outputs,
        send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                         and parallel_config.use_ray),
        policy=self.scheduling_policy)
    # LoRA config only exists when LoRA is enabled; max_cpu_loras <= 0
    # is normalized to None.
    lora_config = LoRAConfig(
        bias_enabled=self.enable_lora_bias,
        max_lora_rank=self.max_lora_rank,
        max_loras=self.max_loras,
        fully_sharded_loras=self.fully_sharded_loras,
        lora_extra_vocab_size=self.lora_extra_vocab_size,
        long_lora_scaling_factors=self.long_lora_scaling_factors,
        lora_dtype=self.lora_dtype,
        max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
        and self.max_cpu_loras > 0 else None) if self.enable_lora else None

    # Forward a non-empty QLoRA adapter path to the model loader via its
    # extra-config dict.
    if self.qlora_adapter_name_or_path is not None and \
        self.qlora_adapter_name_or_path != "":
        if self.model_loader_extra_config is None:
            self.model_loader_extra_config = {}
        self.model_loader_extra_config[
            "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

    load_config = self.create_load_config()

    prompt_adapter_config = PromptAdapterConfig(
        max_prompt_adapters=self.max_prompt_adapters,
        max_prompt_adapter_token=self.max_prompt_adapter_token) \
        if self.enable_prompt_adapter else None

    decoding_config = DecodingConfig(
        guided_decoding_backend=self.guided_decoding_backend)

    # Validate the comma-separated --collect-detailed-traces value
    # against the allowed module names.
    detailed_trace_modules = []
    if self.collect_detailed_traces is not None:
        detailed_trace_modules = self.collect_detailed_traces.split(",")
    for m in detailed_trace_modules:
        if m not in ALLOWED_DETAILED_TRACE_MODULES:
            raise ValueError(
                f"Invalid module {m} in collect_detailed_traces. "
                f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
    observability_config = ObservabilityConfig(
        otlp_traces_endpoint=self.otlp_traces_endpoint,
        collect_model_forward_time="model" in detailed_trace_modules
        or "all" in detailed_trace_modules,
        collect_model_execute_time="worker" in detailed_trace_modules
        or "all" in detailed_trace_modules,
    )

    config = VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        lora_config=lora_config,
        speculative_config=speculative_config,
        load_config=load_config,
        decoding_config=decoding_config,
        observability_config=observability_config,
        prompt_adapter_config=prompt_adapter_config,
        compilation_config=self.compilation_config,
        kv_transfer_config=self.kv_transfer_config,
    )

    # V1 gets a final chance to adjust the assembled config.
    if envs.VLLM_USE_V1:
        self._override_v1_engine_config(config)
    return config
|
| 1295 |
+
|
| 1296 |
+
def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
    """
    Override the EngineArgs's args based on the usage context for V1.
    """
    assert envs.VLLM_USE_V1, "V1 is not enabled"

    # Chunked prefill is always on under V1.
    self.enable_chunked_prefill = True

    # Pick hardware-dependent defaults: H100/H200 get larger budgets.
    from vllm.platforms import current_platform
    device_name = current_platform.get_device_name().lower()
    on_hopper = "h100" in device_name or "h200" in device_name
    default_max_num_batched_tokens = (
        {
            UsageContext.LLM_CLASS: 16384,
            UsageContext.OPENAI_API_SERVER: 8192,
        } if on_hopper else
        # TODO(woosuk): Tune the default values for other hardware.
        {
            UsageContext.LLM_CLASS: 8192,
            UsageContext.OPENAI_API_SERVER: 2048,
        })

    # Only fill in max_num_batched_tokens when the user left it unset
    # and a default exists for this usage context.
    if self.max_num_batched_tokens is None:
        default_value = default_max_num_batched_tokens.get(usage_context)
        if default_value is not None:
            self.max_num_batched_tokens = default_value
            logger.warning(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, usage_context.value)
|
| 1329 |
+
|
| 1330 |
+
def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
    """
    Override the EngineConfig's configs based on the usage context for V1.
    """
    # Currently a placeholder: beyond the V1 guard, no config overrides
    # are applied yet.
    assert envs.VLLM_USE_V1, "V1 is not enabled"
|
| 1335 |
+
|
| 1336 |
+
|
| 1337 |
+
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register async-engine CLI flags on ``parser``.

        Unless ``async_args_only`` is set, the full synchronous
        EngineArgs flag set is registered first.
        """
        target = parser if async_args_only \
            else EngineArgs.add_cli_args(parser)
        target.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return target
|
| 1351 |
+
|
| 1352 |
+
|
| 1353 |
+
# These functions are used by sphinx to build the documentation
|
| 1354 |
+
def _engine_args_parser():
    """Build the full EngineArgs CLI parser (used by Sphinx docs)."""
    parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(parser)
|
| 1356 |
+
|
| 1357 |
+
|
| 1358 |
+
def _async_engine_args_parser():
    """Build the async-only AsyncEngineArgs CLI parser (used by Sphinx docs)."""
    parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)
|
.venv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py
ADDED
|
@@ -0,0 +1,1198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import copy
|
| 5 |
+
import time
|
| 6 |
+
import weakref
|
| 7 |
+
from functools import partial
|
| 8 |
+
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
|
| 9 |
+
List, Mapping, Optional, Set, Tuple, Type, Union, overload)
|
| 10 |
+
from weakref import ReferenceType
|
| 11 |
+
|
| 12 |
+
from typing_extensions import deprecated
|
| 13 |
+
|
| 14 |
+
import vllm.envs as envs
|
| 15 |
+
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
| 16 |
+
ParallelConfig, SchedulerConfig, VllmConfig)
|
| 17 |
+
from vllm.core.scheduler import SchedulerOutputs
|
| 18 |
+
from vllm.engine.arg_utils import AsyncEngineArgs
|
| 19 |
+
from vllm.engine.async_timeout import asyncio_timeout
|
| 20 |
+
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
|
| 21 |
+
from vllm.engine.metrics_types import StatLoggerBase
|
| 22 |
+
from vllm.engine.protocol import EngineClient
|
| 23 |
+
from vllm.executor.executor_base import ExecutorBase
|
| 24 |
+
from vllm.inputs import PromptType
|
| 25 |
+
from vllm.inputs.preprocess import InputPreprocessor
|
| 26 |
+
from vllm.logger import init_logger
|
| 27 |
+
from vllm.lora.request import LoRARequest
|
| 28 |
+
from vllm.model_executor.guided_decoding import (
|
| 29 |
+
get_guided_decoding_logits_processor)
|
| 30 |
+
from vllm.model_executor.layers.sampler import SamplerOutput
|
| 31 |
+
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
| 32 |
+
from vllm.pooling_params import PoolingParams
|
| 33 |
+
from vllm.prompt_adapter.request import PromptAdapterRequest
|
| 34 |
+
from vllm.sampling_params import SamplingParams
|
| 35 |
+
from vllm.sequence import ExecuteModelRequest
|
| 36 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 37 |
+
from vllm.usage.usage_lib import UsageContext
|
| 38 |
+
from vllm.utils import deprecate_kwargs, weak_bind
|
| 39 |
+
|
| 40 |
+
logger = init_logger(__name__)
|
| 41 |
+
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class AsyncEngineDeadError(RuntimeError):
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _log_task_completion(task: asyncio.Task,
|
| 49 |
+
error_callback: Callable[[Exception], None]) -> None:
|
| 50 |
+
"""This function is only intended for the `engine.run_engine_loop()` task.
|
| 51 |
+
|
| 52 |
+
In particular, that task runs a `while True` loop that can only exit if
|
| 53 |
+
there is an exception.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
exception = None
|
| 57 |
+
try:
|
| 58 |
+
return_value = task.result()
|
| 59 |
+
raise AssertionError(
|
| 60 |
+
f"The engine background task should never finish without an "
|
| 61 |
+
f"exception. {return_value}")
|
| 62 |
+
except asyncio.exceptions.CancelledError:
|
| 63 |
+
# We assume that if the task is cancelled, we are gracefully shutting
|
| 64 |
+
# down. This should only happen on program exit.
|
| 65 |
+
logger.info("Engine is gracefully shutting down.")
|
| 66 |
+
except Exception as e:
|
| 67 |
+
exception = e
|
| 68 |
+
logger.error("Engine background task failed", exc_info=e)
|
| 69 |
+
error_callback(exception)
|
| 70 |
+
raise AsyncEngineDeadError(
|
| 71 |
+
"Task finished unexpectedly. This should never happen! "
|
| 72 |
+
"Please open an issue on Github. See stack trace above for the "
|
| 73 |
+
"actual cause.") from e
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
STOP_ITERATION = Exception() # Sentinel
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class AsyncStream:
|
| 80 |
+
"""A stream of RequestOutputs or PoolingRequestOutputs for a request
|
| 81 |
+
that can be iterated over asynchronously via an async generator."""
|
| 82 |
+
|
| 83 |
+
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
|
| 84 |
+
self.request_id = request_id
|
| 85 |
+
self._cancel = cancel
|
| 86 |
+
self._queue: asyncio.Queue = asyncio.Queue()
|
| 87 |
+
self._finished = False
|
| 88 |
+
|
| 89 |
+
def put(self, item: Union[RequestOutput, PoolingRequestOutput,
|
| 90 |
+
Exception]) -> None:
|
| 91 |
+
if not self._finished:
|
| 92 |
+
self._queue.put_nowait(item)
|
| 93 |
+
|
| 94 |
+
def finish(
|
| 95 |
+
self,
|
| 96 |
+
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
|
| 97 |
+
) -> None:
|
| 98 |
+
if not self._finished:
|
| 99 |
+
self._finished = True
|
| 100 |
+
self._queue.put_nowait(
|
| 101 |
+
exception if self._is_raisable(exception) else STOP_ITERATION)
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def finished(self) -> bool:
|
| 105 |
+
return self._finished
|
| 106 |
+
|
| 107 |
+
async def generator(
|
| 108 |
+
self
|
| 109 |
+
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
|
| 110 |
+
try:
|
| 111 |
+
while True:
|
| 112 |
+
result = await self._queue.get()
|
| 113 |
+
if self._is_raisable(result):
|
| 114 |
+
if result == STOP_ITERATION:
|
| 115 |
+
return
|
| 116 |
+
raise result
|
| 117 |
+
yield result
|
| 118 |
+
except GeneratorExit:
|
| 119 |
+
self._cancel(self.request_id)
|
| 120 |
+
raise asyncio.CancelledError from None
|
| 121 |
+
|
| 122 |
+
@staticmethod
|
| 123 |
+
def _is_raisable(value: Any):
|
| 124 |
+
return isinstance(value, BaseException) or \
|
| 125 |
+
(isinstance(value, type) and \
|
| 126 |
+
issubclass(value, BaseException))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class RequestTracker:
|
| 130 |
+
"""Synchronous abstraction for tracking requests."""
|
| 131 |
+
|
| 132 |
+
def __init__(self) -> None:
|
| 133 |
+
self._request_streams: Dict[str, AsyncStream] = {}
|
| 134 |
+
self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
|
| 135 |
+
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
|
| 136 |
+
dict]] = asyncio.Queue()
|
| 137 |
+
self.new_requests_event = asyncio.Event()
|
| 138 |
+
|
| 139 |
+
def __contains__(self, item):
|
| 140 |
+
return item in self._request_streams
|
| 141 |
+
|
| 142 |
+
def __len__(self) -> int:
|
| 143 |
+
return len(self._request_streams)
|
| 144 |
+
|
| 145 |
+
def propagate_exception(self,
|
| 146 |
+
exc: Exception,
|
| 147 |
+
request_id: Optional[str] = None) -> None:
|
| 148 |
+
"""Propagate an exception to request streams
|
| 149 |
+
(all if request_id is None)."""
|
| 150 |
+
if request_id is not None:
|
| 151 |
+
self.abort_request(request_id, exception=exc)
|
| 152 |
+
else:
|
| 153 |
+
# NB: tuple() used here because self.abort_request pops the stream
|
| 154 |
+
# out of self._request_streams, so we can't iterate on it directly
|
| 155 |
+
for rid in tuple(self._request_streams.keys()):
|
| 156 |
+
self.abort_request(rid, exception=exc)
|
| 157 |
+
|
| 158 |
+
def process_request_output(self,
|
| 159 |
+
request_output: Union[RequestOutput,
|
| 160 |
+
PoolingRequestOutput],
|
| 161 |
+
*,
|
| 162 |
+
verbose: bool = False) -> None:
|
| 163 |
+
"""Process a request output from the engine."""
|
| 164 |
+
request_id = request_output.request_id
|
| 165 |
+
finished = request_output.finished
|
| 166 |
+
|
| 167 |
+
if finished:
|
| 168 |
+
stream = self._request_streams.pop(request_id, None)
|
| 169 |
+
else:
|
| 170 |
+
stream = self._request_streams.get(request_id)
|
| 171 |
+
# Guard against a KeyError which can occur if the request was aborted
|
| 172 |
+
# while the output was generated
|
| 173 |
+
if stream is not None:
|
| 174 |
+
stream.put(request_output)
|
| 175 |
+
if finished:
|
| 176 |
+
stream.finish()
|
| 177 |
+
|
| 178 |
+
if verbose and finished:
|
| 179 |
+
logger.info("Finished request %s.", request_id)
|
| 180 |
+
|
| 181 |
+
def process_exception(self,
|
| 182 |
+
request_id: str,
|
| 183 |
+
exception: BaseException,
|
| 184 |
+
*,
|
| 185 |
+
verbose: bool = False) -> None:
|
| 186 |
+
"""Propagate an exception from the engine."""
|
| 187 |
+
if verbose:
|
| 188 |
+
logger.info("Finished request %s.", request_id)
|
| 189 |
+
self.abort_request(request_id, exception=exception)
|
| 190 |
+
|
| 191 |
+
def add_request(self,
|
| 192 |
+
request_id: str,
|
| 193 |
+
*,
|
| 194 |
+
verbose: bool = False,
|
| 195 |
+
**engine_add_request_kwargs) -> AsyncStream:
|
| 196 |
+
"""Add a request to be sent to the engine on the next background
|
| 197 |
+
loop iteration."""
|
| 198 |
+
if request_id in self._request_streams:
|
| 199 |
+
raise KeyError(f"Request {request_id} already exists.")
|
| 200 |
+
|
| 201 |
+
abort_request = partial(self.abort_request, verbose=verbose)
|
| 202 |
+
stream = AsyncStream(request_id, abort_request)
|
| 203 |
+
self._new_requests.put_nowait((stream, {
|
| 204 |
+
"request_id": request_id,
|
| 205 |
+
**engine_add_request_kwargs
|
| 206 |
+
}))
|
| 207 |
+
|
| 208 |
+
self.new_requests_event.set()
|
| 209 |
+
|
| 210 |
+
if verbose:
|
| 211 |
+
logger.info("Added request %s.", request_id)
|
| 212 |
+
|
| 213 |
+
return stream
|
| 214 |
+
|
| 215 |
+
def abort_request(self,
|
| 216 |
+
request_id: str,
|
| 217 |
+
*,
|
| 218 |
+
exception: Optional[Union[BaseException,
|
| 219 |
+
Type[BaseException]]] = None,
|
| 220 |
+
verbose: bool = False) -> None:
|
| 221 |
+
"""Abort a request during next background loop iteration."""
|
| 222 |
+
if verbose:
|
| 223 |
+
logger.info("Aborted request %s.", request_id)
|
| 224 |
+
|
| 225 |
+
self._aborted_requests.put_nowait(request_id)
|
| 226 |
+
|
| 227 |
+
stream = self._request_streams.pop(request_id, None)
|
| 228 |
+
if stream is not None:
|
| 229 |
+
stream.finish(exception=exception)
|
| 230 |
+
|
| 231 |
+
def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
|
| 232 |
+
"""Get the new requests and finished requests to be
|
| 233 |
+
sent to the engine."""
|
| 234 |
+
new_requests: List[Dict] = []
|
| 235 |
+
finished_requests: Set[str] = set()
|
| 236 |
+
|
| 237 |
+
while not self._aborted_requests.empty():
|
| 238 |
+
request_id = self._aborted_requests.get_nowait()
|
| 239 |
+
finished_requests.add(request_id)
|
| 240 |
+
|
| 241 |
+
while not self._new_requests.empty():
|
| 242 |
+
stream, new_request = self._new_requests.get_nowait()
|
| 243 |
+
request_id = stream.request_id
|
| 244 |
+
if request_id in finished_requests:
|
| 245 |
+
# The request has already been aborted.
|
| 246 |
+
stream.finish(asyncio.CancelledError)
|
| 247 |
+
finished_requests.discard(request_id)
|
| 248 |
+
else:
|
| 249 |
+
self._request_streams[request_id] = stream
|
| 250 |
+
new_requests.append(new_request)
|
| 251 |
+
|
| 252 |
+
return new_requests, finished_requests
|
| 253 |
+
|
| 254 |
+
async def wait_for_new_requests(self):
|
| 255 |
+
if not self.has_new_requests():
|
| 256 |
+
await self.new_requests_event.wait()
|
| 257 |
+
self.new_requests_event.clear()
|
| 258 |
+
|
| 259 |
+
def has_new_requests(self):
|
| 260 |
+
return not self._new_requests.empty()
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class _AsyncLLMEngine(LLMEngine):
|
| 264 |
+
"""Extension of LLMEngine to add async methods."""
|
| 265 |
+
|
| 266 |
+
def __init__(self, *args, **kwargs):
|
| 267 |
+
super().__init__(*args, **kwargs)
|
| 268 |
+
|
| 269 |
+
async def step_async(
|
| 270 |
+
self, virtual_engine: int
|
| 271 |
+
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
|
| 272 |
+
"""Performs one decoding iteration and returns newly generated results.
|
| 273 |
+
The workers are ran asynchronously if possible.
|
| 274 |
+
|
| 275 |
+
This function performs one decoding iteration of the engine. It first
|
| 276 |
+
schedules the sequences to be executed in the next iteration and the
|
| 277 |
+
token blocks to be swapped in/out/copy. Then, it executes the model
|
| 278 |
+
and updates the scheduler with the model outputs. Finally, it decodes
|
| 279 |
+
the sequences and returns the newly generated results.
|
| 280 |
+
"""
|
| 281 |
+
# these are cached outputs from previous iterations. None if on first
|
| 282 |
+
# iteration
|
| 283 |
+
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
|
| 284 |
+
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
|
| 285 |
+
scheduler_outputs = cached_outputs.scheduler_outputs
|
| 286 |
+
allow_async_output_proc = cached_outputs.allow_async_output_proc
|
| 287 |
+
|
| 288 |
+
ctx = self.scheduler_contexts[virtual_engine]
|
| 289 |
+
|
| 290 |
+
# Clear outputs for each new scheduler iteration
|
| 291 |
+
ctx.request_outputs.clear()
|
| 292 |
+
|
| 293 |
+
# skip the scheduler if there are any remaining steps in the seq groups.
|
| 294 |
+
# This ensures that the scheduler is only called again when the current
|
| 295 |
+
# batch has completed.
|
| 296 |
+
if not self._has_remaining_steps(seq_group_metadata_list):
|
| 297 |
+
|
| 298 |
+
# Schedule iteration
|
| 299 |
+
(seq_group_metadata_list, scheduler_outputs,
|
| 300 |
+
allow_async_output_proc
|
| 301 |
+
) = self.scheduler[virtual_engine].schedule()
|
| 302 |
+
|
| 303 |
+
ctx.seq_group_metadata_list = seq_group_metadata_list
|
| 304 |
+
ctx.scheduler_outputs = scheduler_outputs
|
| 305 |
+
|
| 306 |
+
finished_requests_ids = self.scheduler[
|
| 307 |
+
virtual_engine].get_and_reset_finished_requests_ids()
|
| 308 |
+
|
| 309 |
+
# Maybe switch from async mode to sync mode
|
| 310 |
+
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
| 311 |
+
self._process_model_outputs(ctx=ctx)
|
| 312 |
+
|
| 313 |
+
if (self.scheduler_config.is_multi_step
|
| 314 |
+
and scheduler_outputs.num_lookahead_slots > 0):
|
| 315 |
+
# cache the scheduler outputs for the next iteration if we have
|
| 316 |
+
# lookahead slots
|
| 317 |
+
self._cache_scheduler_outputs_for_multi_step(
|
| 318 |
+
virtual_engine, seq_group_metadata_list, scheduler_outputs,
|
| 319 |
+
allow_async_output_proc)
|
| 320 |
+
else:
|
| 321 |
+
finished_requests_ids = list()
|
| 322 |
+
|
| 323 |
+
assert seq_group_metadata_list is not None
|
| 324 |
+
assert scheduler_outputs is not None
|
| 325 |
+
|
| 326 |
+
if not scheduler_outputs.is_empty():
|
| 327 |
+
|
| 328 |
+
# Check if we have a cached last_output from the previous iteration.
|
| 329 |
+
# For supporting PP this is probably the best way to pass the
|
| 330 |
+
# sampled_token_ids, as a separate broadcast over all the PP stages
|
| 331 |
+
# will cause one virtual engine's microbatch to block the pipeline.
|
| 332 |
+
last_sampled_token_ids = \
|
| 333 |
+
self._get_last_sampled_token_ids(virtual_engine)
|
| 334 |
+
|
| 335 |
+
execute_model_req = ExecuteModelRequest(
|
| 336 |
+
seq_group_metadata_list=seq_group_metadata_list,
|
| 337 |
+
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
| 338 |
+
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
| 339 |
+
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
| 340 |
+
virtual_engine=virtual_engine,
|
| 341 |
+
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
|
| 342 |
+
running_queue_size=scheduler_outputs.running_queue_size,
|
| 343 |
+
finished_requests_ids=finished_requests_ids,
|
| 344 |
+
# We use ExecuteModelRequest to pass the last sampled_token_ids
|
| 345 |
+
# to each of the non-last PP stages for in-place prepare_input.
|
| 346 |
+
last_sampled_token_ids=last_sampled_token_ids)
|
| 347 |
+
|
| 348 |
+
if allow_async_output_proc:
|
| 349 |
+
execute_model_req.async_callback = self.async_callbacks[
|
| 350 |
+
virtual_engine]
|
| 351 |
+
|
| 352 |
+
# Execute the model.
|
| 353 |
+
outputs = await self.model_executor.execute_model_async(
|
| 354 |
+
execute_model_req)
|
| 355 |
+
|
| 356 |
+
# we need to do this here so that last step's sampled_token_ids can
|
| 357 |
+
# be passed to the next iteration for PP.
|
| 358 |
+
if self.scheduler_config.is_multi_step:
|
| 359 |
+
self._update_cached_scheduler_output(virtual_engine, outputs)
|
| 360 |
+
else:
|
| 361 |
+
if len(ctx.output_queue) > 0:
|
| 362 |
+
self._process_model_outputs(ctx=ctx)
|
| 363 |
+
outputs = []
|
| 364 |
+
|
| 365 |
+
# Finish the current step for all the sequence groups.
|
| 366 |
+
if self.scheduler_config.is_multi_step:
|
| 367 |
+
for seq_group in seq_group_metadata_list:
|
| 368 |
+
seq_group.finish_step()
|
| 369 |
+
|
| 370 |
+
if not self._has_remaining_steps(seq_group_metadata_list):
|
| 371 |
+
# Clear the cache if we have finished all the steps
|
| 372 |
+
if self.scheduler_config.is_multi_step:
|
| 373 |
+
self.cached_scheduler_outputs[
|
| 374 |
+
virtual_engine] = SchedulerOutputState()
|
| 375 |
+
|
| 376 |
+
# is_first_step_output is True only when the num_steps of all
|
| 377 |
+
# the sequences are 1. When the num_steps > 1,
|
| 378 |
+
# multi_step_model_runner does the first-step output append.
|
| 379 |
+
is_first_step_output: bool = False if not seq_group_metadata_list \
|
| 380 |
+
else seq_group_metadata_list[0].state.num_steps == 1
|
| 381 |
+
|
| 382 |
+
ctx.append_output(outputs=outputs,
|
| 383 |
+
seq_group_metadata_list=seq_group_metadata_list,
|
| 384 |
+
scheduler_outputs=scheduler_outputs,
|
| 385 |
+
is_async=allow_async_output_proc,
|
| 386 |
+
is_last_step=True,
|
| 387 |
+
is_first_step_output=is_first_step_output)
|
| 388 |
+
|
| 389 |
+
if outputs and allow_async_output_proc:
|
| 390 |
+
assert len(
|
| 391 |
+
outputs
|
| 392 |
+
) == 1, "Async postprocessor expects only a single output set"
|
| 393 |
+
self._advance_to_next_step(
|
| 394 |
+
outputs[0], seq_group_metadata_list,
|
| 395 |
+
scheduler_outputs.scheduled_seq_groups)
|
| 396 |
+
|
| 397 |
+
if not allow_async_output_proc:
|
| 398 |
+
self._process_model_outputs(ctx=ctx)
|
| 399 |
+
|
| 400 |
+
# Log stats.
|
| 401 |
+
self.do_log_stats(scheduler_outputs, outputs)
|
| 402 |
+
|
| 403 |
+
# Tracing
|
| 404 |
+
self.do_tracing(scheduler_outputs)
|
| 405 |
+
|
| 406 |
+
else:
|
| 407 |
+
# Multi-step case
|
| 408 |
+
return ctx.request_outputs
|
| 409 |
+
|
| 410 |
+
if not self.has_unfinished_requests():
|
| 411 |
+
# Drain async postprocessor (if exists)
|
| 412 |
+
if len(ctx.output_queue) > 0:
|
| 413 |
+
self._process_model_outputs(ctx=ctx)
|
| 414 |
+
assert len(ctx.output_queue) == 0
|
| 415 |
+
|
| 416 |
+
return ctx.request_outputs
|
| 417 |
+
|
| 418 |
+
async def stop_remote_worker_execution_loop_async(self) -> None:
|
| 419 |
+
"""Stop the remote worker execution loop."""
|
| 420 |
+
await self.model_executor.stop_remote_worker_execution_loop_async()
|
| 421 |
+
|
| 422 |
+
async def get_tokenizer_async(self,
|
| 423 |
+
lora_request: Optional[LoRARequest] = None
|
| 424 |
+
) -> AnyTokenizer:
|
| 425 |
+
return await (
|
| 426 |
+
self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
|
| 427 |
+
|
| 428 |
+
@overload
|
| 429 |
+
@deprecated("'inputs' will be renamed to 'prompt")
|
| 430 |
+
async def add_request_async(
|
| 431 |
+
self,
|
| 432 |
+
request_id: str,
|
| 433 |
+
*,
|
| 434 |
+
inputs: PromptType,
|
| 435 |
+
params: Union[SamplingParams, PoolingParams],
|
| 436 |
+
arrival_time: Optional[float] = None,
|
| 437 |
+
lora_request: Optional[LoRARequest] = None,
|
| 438 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 439 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 440 |
+
priority: int = 0,
|
| 441 |
+
) -> None:
|
| 442 |
+
...
|
| 443 |
+
|
| 444 |
+
@overload
|
| 445 |
+
async def add_request_async(
|
| 446 |
+
self,
|
| 447 |
+
request_id: str,
|
| 448 |
+
prompt: PromptType,
|
| 449 |
+
params: Union[SamplingParams, PoolingParams],
|
| 450 |
+
arrival_time: Optional[float] = None,
|
| 451 |
+
lora_request: Optional[LoRARequest] = None,
|
| 452 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 453 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 454 |
+
priority: int = 0,
|
| 455 |
+
) -> None:
|
| 456 |
+
...
|
| 457 |
+
|
| 458 |
+
@deprecate_kwargs(
|
| 459 |
+
"inputs",
|
| 460 |
+
additional_message="Please use the 'prompt' parameter instead.",
|
| 461 |
+
)
|
| 462 |
+
async def add_request_async(
|
| 463 |
+
self,
|
| 464 |
+
request_id: str,
|
| 465 |
+
prompt: Optional[PromptType] = None,
|
| 466 |
+
params: Optional[Union[SamplingParams, PoolingParams]] = None,
|
| 467 |
+
arrival_time: Optional[float] = None,
|
| 468 |
+
lora_request: Optional[LoRARequest] = None,
|
| 469 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 470 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 471 |
+
priority: int = 0,
|
| 472 |
+
*,
|
| 473 |
+
inputs: Optional[PromptType] = None, # DEPRECATED
|
| 474 |
+
) -> None:
|
| 475 |
+
"""Async version of :meth:`add_request`."""
|
| 476 |
+
if inputs is not None:
|
| 477 |
+
prompt = inputs
|
| 478 |
+
assert prompt is not None and params is not None
|
| 479 |
+
|
| 480 |
+
if lora_request is not None and not self.lora_config:
|
| 481 |
+
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
| 482 |
+
"not enabled!")
|
| 483 |
+
if priority != 0 and not self.scheduler_config.policy == "priority":
|
| 484 |
+
raise ValueError(f"Got priority {priority} but "
|
| 485 |
+
"Priority scheduling is not enabled.")
|
| 486 |
+
if arrival_time is None:
|
| 487 |
+
arrival_time = time.time()
|
| 488 |
+
|
| 489 |
+
if self.tokenizer is not None:
|
| 490 |
+
tokenizer = await self.get_tokenizer_async(lora_request)
|
| 491 |
+
self._validate_token_prompt(prompt, tokenizer=tokenizer)
|
| 492 |
+
|
| 493 |
+
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
|
| 494 |
+
prompt,
|
| 495 |
+
request_id=request_id,
|
| 496 |
+
lora_request=lora_request,
|
| 497 |
+
prompt_adapter_request=prompt_adapter_request,
|
| 498 |
+
)
|
| 499 |
+
processed_inputs = self.input_processor(preprocessed_inputs)
|
| 500 |
+
|
| 501 |
+
if isinstance(params, SamplingParams) and \
|
| 502 |
+
params.guided_decoding is not None:
|
| 503 |
+
# Guided decoding has an async implementation for building logits
|
| 504 |
+
# processors in a separate threadpool.
|
| 505 |
+
# We want to invoke that here instead of using the blocking
|
| 506 |
+
# implementation in the LLMEngine
|
| 507 |
+
params = await build_guided_decoding_logits_processor_async(
|
| 508 |
+
sampling_params=params,
|
| 509 |
+
tokenizer=await self.get_tokenizer_async(lora_request),
|
| 510 |
+
default_guided_backend=self.decoding_config.
|
| 511 |
+
guided_decoding_backend,
|
| 512 |
+
model_config=self.model_config)
|
| 513 |
+
|
| 514 |
+
self._add_processed_request(
|
| 515 |
+
request_id=request_id,
|
| 516 |
+
processed_inputs=processed_inputs,
|
| 517 |
+
params=params,
|
| 518 |
+
arrival_time=arrival_time,
|
| 519 |
+
lora_request=lora_request,
|
| 520 |
+
prompt_adapter_request=prompt_adapter_request,
|
| 521 |
+
trace_headers=trace_headers,
|
| 522 |
+
priority=priority,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
async def check_health_async(self) -> None:
|
| 526 |
+
if self.tokenizer:
|
| 527 |
+
self.tokenizer.check_health()
|
| 528 |
+
self.model_executor.check_health()
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
async def build_guided_decoding_logits_processor_async(
|
| 532 |
+
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
|
| 533 |
+
default_guided_backend: str,
|
| 534 |
+
model_config: ModelConfig) -> SamplingParams:
|
| 535 |
+
"""Constructs logits processors based on the guided_decoding,
|
| 536 |
+
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
|
| 537 |
+
those fields and adds the constructed logits processors to the
|
| 538 |
+
logits_processors field. Modifies sampling params in-place and returns
|
| 539 |
+
the modified sampling params."""
|
| 540 |
+
if sampling_params.guided_decoding is None:
|
| 541 |
+
return sampling_params
|
| 542 |
+
|
| 543 |
+
# Defensively copy sampling params since guided decoding logits
|
| 544 |
+
# processors can have different state for each request
|
| 545 |
+
sampling_params = copy.copy(sampling_params)
|
| 546 |
+
guided_decoding = sampling_params.guided_decoding
|
| 547 |
+
|
| 548 |
+
logger.debug("Building guided decoding logits processor. "
|
| 549 |
+
"Params: %s", guided_decoding)
|
| 550 |
+
|
| 551 |
+
guided_decoding.backend = guided_decoding.backend or default_guided_backend
|
| 552 |
+
|
| 553 |
+
processor = await get_guided_decoding_logits_processor(
|
| 554 |
+
guided_params=guided_decoding,
|
| 555 |
+
tokenizer=tokenizer,
|
| 556 |
+
model_config=model_config)
|
| 557 |
+
|
| 558 |
+
if processor:
|
| 559 |
+
if sampling_params.logits_processors is None:
|
| 560 |
+
sampling_params.logits_processors = []
|
| 561 |
+
sampling_params.logits_processors.append(processor)
|
| 562 |
+
|
| 563 |
+
# Unset guided decoding params after constructing the lp from them
|
| 564 |
+
sampling_params.guided_decoding = None
|
| 565 |
+
|
| 566 |
+
return sampling_params
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
class AsyncLLMEngine(EngineClient):
|
| 570 |
+
"""An asynchronous wrapper for :class:`LLMEngine`.
|
| 571 |
+
|
| 572 |
+
This class is used to wrap the :class:`LLMEngine` class to make it
|
| 573 |
+
asynchronous. It uses asyncio to create a background loop that keeps
|
| 574 |
+
processing incoming requests. The :class:`LLMEngine` is kicked by the
|
| 575 |
+
generate method when there are requests in the waiting queue. The generate
|
| 576 |
+
method yields the outputs from the :class:`LLMEngine` to the caller.
|
| 577 |
+
|
| 578 |
+
Args:
|
| 579 |
+
log_requests: Whether to log the requests.
|
| 580 |
+
start_engine_loop: If True, the background task to run the engine
|
| 581 |
+
will be automatically started in the generate call.
|
| 582 |
+
*args: Arguments for :class:`LLMEngine`.
|
| 583 |
+
**kwargs: Arguments for :class:`LLMEngine`.
|
| 584 |
+
"""
|
| 585 |
+
|
| 586 |
+
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
|
| 587 |
+
|
| 588 |
+
def __init__(self,
|
| 589 |
+
*args,
|
| 590 |
+
log_requests: bool = True,
|
| 591 |
+
start_engine_loop: bool = True,
|
| 592 |
+
**kwargs) -> None:
|
| 593 |
+
self.log_requests = log_requests
|
| 594 |
+
self.engine = self._engine_class(*args, **kwargs)
|
| 595 |
+
|
| 596 |
+
# This ensures quick processing of request outputs
|
| 597 |
+
# so the append to asyncio queues is not delayed,
|
| 598 |
+
# especially for multi-step.
|
| 599 |
+
self.use_process_request_outputs_callback = (
|
| 600 |
+
self.engine.model_config.use_async_output_proc)
|
| 601 |
+
|
| 602 |
+
if self.use_process_request_outputs_callback:
|
| 603 |
+
self.engine.process_request_outputs_callback = \
|
| 604 |
+
weak_bind(self.process_request_outputs)
|
| 605 |
+
|
| 606 |
+
self.background_loop: Optional[asyncio.Future] = None
|
| 607 |
+
# We need to keep a reference to unshielded
|
| 608 |
+
# task as well to prevent it from being garbage
|
| 609 |
+
# collected
|
| 610 |
+
self._background_loop_unshielded: Optional[asyncio.Task] = None
|
| 611 |
+
self.start_engine_loop = start_engine_loop
|
| 612 |
+
self._errored_with: Optional[BaseException] = None
|
| 613 |
+
|
| 614 |
+
# Lazy initialized fields
|
| 615 |
+
self._request_tracker: RequestTracker
|
| 616 |
+
|
| 617 |
+
def __del__(self):
|
| 618 |
+
if rt := getattr(self, "request_tracker", None):
|
| 619 |
+
# Wake up engine loop so that it will exit cleanly
|
| 620 |
+
rt.new_requests_event.set()
|
| 621 |
+
|
| 622 |
+
@classmethod
|
| 623 |
+
def _get_executor_cls(cls,
|
| 624 |
+
engine_config: VllmConfig) -> Type[ExecutorBase]:
|
| 625 |
+
return LLMEngine._get_executor_cls(engine_config)
|
| 626 |
+
|
| 627 |
+
@classmethod
def from_engine_args(
    cls,
    engine_args: AsyncEngineArgs,
    engine_config: Optional[VllmConfig] = None,
    start_engine_loop: bool = True,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
    """Creates an async LLM engine from the engine arguments.

    Args:
        engine_args: CLI/programmatic arguments describing the engine.
        engine_config: Pre-built config; when None it is derived from
            *engine_args*.
        start_engine_loop: Start the background loop lazily on first use.
        usage_context: Usage-reporting context tag.
        stat_loggers: Optional custom stat loggers, keyed by name.
    """
    # Create the engine configs.
    if engine_config is None:
        engine_config = engine_args.create_engine_config(usage_context)

    executor_class = cls._get_executor_cls(engine_config)

    # Create the async LLM engine.
    engine = cls(
        vllm_config=engine_config,
        executor_class=executor_class,
        log_requests=not engine_args.disable_log_requests,
        log_stats=not engine_args.disable_log_stats,
        start_engine_loop=start_engine_loop,
        usage_context=usage_context,
        stat_loggers=stat_loggers,
    )
    return engine
@property
def is_running(self) -> bool:
    """True while the background engine-loop task exists and is alive."""
    task = self._background_loop_unshielded
    if self.background_loop is None or task is None:
        return False
    return not task.done()
@property
def is_stopped(self) -> bool:
    """True once the engine has errored or its loop task has finished."""
    if self.errored:
        return True
    task = self._background_loop_unshielded
    return (self.background_loop is not None and task is not None
            and task.done())
@property
def errored(self) -> bool:
    """Whether a fatal exception has been recorded via set_errored()."""
    exc = self._errored_with
    return exc is not None
@property
def dead_error(self) -> BaseException:
    """Canonical exception handed to callers of a dead engine."""
    message = ("Background loop is not running. If it was running, "
               "inspect the output to find the stacktrace of the "
               "error that caused the background loop to stop "
               "(AsyncEngineDeadError).")
    return AsyncEngineDeadError(message)
def set_errored(self, exc: Exception) -> None:
    """Record *exc* as the fatal error; flips the ``errored`` property."""
    self._errored_with = exc
def _error_callback(self, exc: Exception) -> None:
    """Mark the engine dead, then fan *exc* out to every open stream."""
    self.set_errored(exc)
    self._request_tracker.propagate_exception(exc)
async def get_input_preprocessor(self) -> InputPreprocessor:
    """Return the wrapped engine's input preprocessor."""
    return self.engine.input_preprocessor

async def get_tokenizer(
    self,
    lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
    """Return the tokenizer, honoring a LoRA-specific one when given."""
    tokenizer = await self.engine.get_tokenizer_async(lora_request)
    return tokenizer
def start_background_loop(self) -> None:
    """Start the background loop.

    Raises:
        AsyncEngineDeadError: if a previous loop already errored.
        RuntimeError: if the loop is already running.
    """
    if self.errored:
        raise AsyncEngineDeadError(
            "Background loop has errored already.") from self._errored_with
    if self.is_running:
        raise RuntimeError("Background loop is already running.")
    # Initialize the RequestTracker here so it uses the right event loop.
    self._request_tracker = RequestTracker()

    # Pass a weakref so the loop task does not keep `self` alive;
    # run_engine_loop dereferences it only while actively stepping.
    self._background_loop_unshielded = asyncio.get_event_loop(
    ).create_task(self.run_engine_loop(weakref.ref(self)))
    self._background_loop_unshielded.add_done_callback(
        partial(_log_task_completion, error_callback=self._error_callback))
    # Shield so external cancellation of `background_loop` does not kill
    # the underlying task; we keep the unshielded handle ourselves.
    self.background_loop = asyncio.shield(self._background_loop_unshielded)
def shutdown_background_loop(self) -> None:
    """
    Shut down the background loop.

    This method needs to be called during cleanup to remove
    references to `self` and properly GC the resources held
    by the async LLM engine (e.g., the executors as well as
    their resources).
    """
    task = self._background_loop_unshielded
    if task is not None:
        task.cancel()
    self._background_loop_unshielded = None
    self.background_loop = None
async def engine_step(self, virtual_engine: int) -> bool:
    """Kick the engine to process the waiting requests.

    Returns True if there are in-progress requests."""

    # Drain the tracker first so this step sees the newest requests
    # and aborts before scheduling.
    new_requests, aborted_requests = (
        self._request_tracker.get_new_and_aborted_requests())

    for new_request in new_requests:
        # Add the request into the vLLM engine's waiting queue.
        try:
            await self.engine.add_request_async(**new_request)
        except ValueError as e:
            # TODO: use a vLLM specific error for failed validation
            # Report the failure on that request's own stream instead of
            # letting it kill the whole engine loop.
            self._request_tracker.process_exception(
                new_request["request_id"],
                e,
                verbose=self.log_requests,
            )

    if aborted_requests:
        await self._engine_abort(aborted_requests)

    request_outputs = await self.engine.step_async(virtual_engine)

    # Put the outputs into the corresponding streams.
    # If used as a callback, then already invoked inside
    # LLMEngine's _process_model_outputs
    if not self.use_process_request_outputs_callback:
        all_finished = self.process_request_outputs(request_outputs)
    else:
        # For callback case, we only need to detect when all
        # requests are finished
        all_finished = all(request_output.finished
                           for request_output in request_outputs)

    return not all_finished
def process_request_outputs(self, request_outputs) -> bool:
    """Forward each output to its stream; return True iff every request
    in *request_outputs* has finished (vacuously True when empty)."""
    finished_flags = []
    for output in request_outputs:
        self._request_tracker.process_request_output(
            output, verbose=self.log_requests)
        finished_flags.append(output.finished)
    return all(finished_flags)
async def _engine_abort(self, request_ids: Iterable[str]):
    """Ask the underlying engine to drop the given request ids."""
    self.engine.abort_request(request_ids)
@staticmethod
async def run_engine_loop(engine_ref: ReferenceType):
    """We use a weakref to the engine so that the running loop
    doesn't prevent the engine being garbage collected."""
    engine: Optional[AsyncLLMEngine] = engine_ref()
    if not engine:
        return

    pipeline_parallel_size = \
            engine.engine.parallel_config.pipeline_parallel_size
    # One slot per virtual engine (pipeline stage).
    has_requests_in_progress = [False] * pipeline_parallel_size
    while True:
        if not any(has_requests_in_progress):
            logger.debug("Waiting for new requests...")
            # Stop the execute model loop in parallel workers until there
            # are more requests to process. This avoids waiting
            # indefinitely in torch.distributed ops which may otherwise
            # timeout, and unblocks the RPC thread in the workers so that
            # they can process any other queued control plane messages,
            # such as add/remove lora adapters.
            await engine.engine.stop_remote_worker_execution_loop_async()
            request_tracker = engine._request_tracker
            # Allow engine to be garbage collected while
            # waiting for new requests
            del engine
            await asyncio.sleep(0)
            if engine_ref() is None:
                # Engine was GC'd while idle: exit the loop.
                return
            await request_tracker.wait_for_new_requests()
            engine = engine_ref()
            if not engine:
                return
            logger.debug("Got new requests!")
            requests_in_progress = [
                asyncio.create_task(engine.engine_step(ve))
                for ve in range(pipeline_parallel_size)
            ]
            has_requests_in_progress = [True] * pipeline_parallel_size

        # Abort if iteration takes too long due to unrecoverable errors
        # (eg. NCCL timeouts).
        try:
            async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                done, _ = await asyncio.wait(
                    requests_in_progress,
                    return_when=asyncio.FIRST_COMPLETED)
                # Yield control a few times so sibling step tasks can
                # settle before we inspect results.
                for _ in range(pipeline_parallel_size):
                    await asyncio.sleep(0)
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    has_unfinished_requests = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    if result or has_unfinished_requests:
                        # Keep stepping this virtual engine.
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                engine.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
        except asyncio.TimeoutError as exc:
            logger.error(
                "Engine iteration timed out. This should never happen!")
            engine.set_errored(exc)
            raise
        await asyncio.sleep(0)
# This method does not need to be async, but kept that way
# for backwards compatibility.
@overload
@deprecated("'inputs' will be renamed to 'prompt'")
def add_request(
    self,
    request_id: str,
    *,
    inputs: PromptType,
    params: Union[SamplingParams, PoolingParams],
    arrival_time: Optional[float] = None,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
        RequestOutput, PoolingRequestOutput], None]]:
    ...

@overload
def add_request(
    self,
    request_id: str,
    prompt: PromptType,
    params: Union[SamplingParams, PoolingParams],
    arrival_time: Optional[float] = None,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
        RequestOutput, PoolingRequestOutput], None]]:
    ...

@deprecate_kwargs(
    "inputs",
    additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request(
    self,
    request_id: str,
    prompt: Optional[PromptType] = None,
    params: Optional[Union[SamplingParams, PoolingParams]] = None,
    arrival_time: Optional[float] = None,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    priority: int = 0,
    *,
    inputs: Optional[PromptType] = None,  # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
    """Register a request with the tracker and return its output stream.

    Starts the background loop on first use (when ``start_engine_loop``
    is set); otherwise raises if the loop is not running.

    Raises:
        AsyncEngineDeadError: if the loop is down and auto-start is off.
        ValueError: if a non-zero priority is given without priority
            scheduling enabled.
    """
    # Fix: the deprecation message above previously read
    # "'inputs' will be renamed to 'prompt" (unbalanced quote).
    if inputs is not None:
        prompt = inputs
    assert prompt is not None and params is not None

    if not self.is_running:
        if self.start_engine_loop:
            self.start_background_loop()
        else:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

    if (priority != 0
            and not self.engine.scheduler_config.policy == "priority"):
        raise ValueError(f"Got priority {priority} but "
                         "Priority scheduling is not enabled.")

    stream = self._request_tracker.add_request(
        request_id,
        verbose=self.log_requests,
        prompt=prompt,
        params=params,
        arrival_time=arrival_time or time.time(),
        lora_request=lora_request,
        trace_headers=trace_headers,
        prompt_adapter_request=prompt_adapter_request,
        priority=priority,
    )

    return stream.generator()
async def generate(
    self,
    prompt: PromptType,
    sampling_params: SamplingParams,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
    """Generate outputs for a request.

    Generate outputs for a request. This method is a coroutine. It adds the
    request into the waiting queue of the LLMEngine and streams the outputs
    from the LLMEngine to the caller.

    Args:
        prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
            for more details about the format of each input.
        sampling_params: The sampling parameters of the request.
        request_id: The unique id of the request.
        lora_request: LoRA request to use for generation, if any.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request to use
            for generation, if any.
        priority: The priority of the request.
            Only applicable with priority scheduling.

    Yields:
        The output `RequestOutput` objects from the LLMEngine
        for the request.

    Details:
        - If the engine is not running, start the background loop,
          which iteratively invokes
          :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
          to process the waiting requests.
        - Add the request to the engine's `RequestTracker`.
          On the next background loop, this request will be sent to
          the underlying engine.
          Also, a corresponding `AsyncStream` will be created.
        - Wait for the request outputs from `AsyncStream` and yield them.

    Example:
        >>> # Please refer to entrypoints/api_server.py for
        >>> # the complete example.
        >>>
        >>> # initialize the engine and the example input
        >>> # note that engine_args here is AsyncEngineArgs instance
        >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
        >>> example_input = {
        >>>     "prompt": "What is LLM?",
        >>>     "stream": False, # assume the non-streaming case
        >>>     "temperature": 0.0,
        >>>     "request_id": 0,
        >>> }
        >>>
        >>> # start the generation
        >>> results_generator = engine.generate(
        >>>    example_input["prompt"],
        >>>    SamplingParams(temperature=example_input["temperature"]),
        >>>    example_input["request_id"])
        >>>
        >>> # get the results
        >>> final_output = None
        >>> async for request_output in results_generator:
        >>>     if await request.is_disconnected():
        >>>         # Abort the request if the client disconnects.
        >>>         await engine.abort(request_id)
        >>>         # Return or raise an error
        >>>         ...
        >>>     final_output = request_output
        >>>
        >>> # Process and return the final output
        >>> ...
    """
    # Stream outputs as they arrive; a client disconnect surfaces here as
    # CancelledError, in which case the request is aborted engine-side.
    try:
        async for output in await self.add_request(
            request_id,
            prompt,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        ):
            yield LLMEngine.validate_output(output, RequestOutput)
    except asyncio.CancelledError:
        await self.abort(request_id)
        raise
async def encode(
    self,
    prompt: PromptType,
    pooling_params: PoolingParams,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    priority: int = 0,
) -> AsyncGenerator[PoolingRequestOutput, None]:
    """Generate outputs for a request from a pooling model.

    Generate outputs for a request. This method is a coroutine. It adds the
    request into the waiting queue of the LLMEngine and streams the outputs
    from the LLMEngine to the caller.

    Args:
        prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
            for more details about the format of each input.
        pooling_params: The pooling parameters of the request.
        request_id: The unique id of the request.
        lora_request: LoRA request to use for generation, if any.
        trace_headers: OpenTelemetry trace headers.
        priority: The priority of the request.
            Only applicable with priority scheduling.

    Yields:
        The output `PoolingRequestOutput` objects from the LLMEngine
        for the request.

    Details:
        - If the engine is not running, start the background loop,
          which iteratively invokes
          :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
          to process the waiting requests.
        - Add the request to the engine's `RequestTracker`.
          On the next background loop, this request will be sent to
          the underlying engine.
          Also, a corresponding `AsyncStream` will be created.
        - Wait for the request outputs from `AsyncStream` and yield them.

    Example:
        >>> # Please refer to entrypoints/api_server.py for
        >>> # the complete example.
        >>>
        >>> # initialize the engine and the example input
        >>> # note that engine_args here is AsyncEngineArgs instance
        >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
        >>> example_input = {
        >>>     "input": "What is LLM?",
        >>>     "request_id": 0,
        >>> }
        >>>
        >>> # start the generation
        >>> results_generator = engine.encode(
        >>>    example_input["input"],
        >>>    PoolingParams(),
        >>>    example_input["request_id"])
        >>>
        >>> # get the results
        >>> final_output = None
        >>> async for request_output in results_generator:
        >>>     if await request.is_disconnected():
        >>>         # Abort the request if the client disconnects.
        >>>         await engine.abort(request_id)
        >>>         # Return or raise an error
        >>>         ...
        >>>     final_output = request_output
        >>>
        >>> # Process and return the final output
        >>> ...
    """
    # Same streaming/abort contract as `generate`, but validated against
    # PoolingRequestOutput.
    try:
        async for output in await self.add_request(
            request_id,
            prompt,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        ):
            yield LLMEngine.validate_output(output, PoolingRequestOutput)
    except asyncio.CancelledError:
        await self.abort(request_id)
        raise
async def abort(self, request_id: str) -> None:
    """Abort a request.

    Abort a submitted request. If the request is finished or not found,
    this method will be a no-op.

    Args:
        request_id: The unique id of the request.
    """
    if self.is_running:
        return self._abort(request_id)
    raise AsyncEngineDeadError(
        "Background loop is not running. If it was running, "
        "inspect the output to find the stacktrace of the "
        "error that caused the background loop to stop "
        "(AsyncEngineDeadError).")
def _abort(self, request_id: str) -> None:
    """Abort a request.

    Abort a submitted request. If the request is finished or not found,
    this method will be a no-op.

    Args:
        request_id: The unique id of the request.
    """
    self._request_tracker.abort_request(
        request_id,
        exception=asyncio.CancelledError,
        verbose=self.log_requests,
    )
async def get_model_config(self) -> ModelConfig:
    """Async accessor for the underlying engine's model config."""
    config = self.engine.get_model_config()
    return config

async def get_parallel_config(self) -> ParallelConfig:
    """Async accessor for the underlying engine's parallel config."""
    config = self.engine.get_parallel_config()
    return config

async def get_decoding_config(self) -> DecodingConfig:
    """Async accessor for the underlying engine's decoding config."""
    config = self.engine.get_decoding_config()
    return config

async def get_scheduler_config(self) -> SchedulerConfig:
    """Async accessor for the underlying engine's scheduler config."""
    config = self.engine.get_scheduler_config()
    return config

async def get_lora_config(self) -> LoRAConfig:
    """Async accessor for the underlying engine's LoRA config."""
    config = self.engine.get_lora_config()
    return config
async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None) -> None:
    """Trigger a stats-logging pass on the wrapped engine.

    NOTE(review): both parameters are accepted for interface
    compatibility but are ignored here — the call is forwarded
    without them.
    """
    self.engine.do_log_stats()
async def check_health(self) -> None:
    """Raises an error if engine is unhealthy."""
    started = time.perf_counter()
    logger.debug("Starting health check...")
    if self.is_stopped:
        raise AsyncEngineDeadError("Background loop is stopped.")

    await self.engine.check_health_async()
    logger.debug("Health check took %fs", time.perf_counter() - started)
async def is_tracing_enabled(self) -> bool:
    """Report whether OpenTelemetry tracing is active on the engine."""
    enabled = self.engine.is_tracing_enabled()
    return enabled

def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
    """Register *logger* under *logger_name* on the wrapped engine."""
    self.engine.add_logger(logger_name=logger_name, logger=logger)

def remove_logger(self, logger_name: str) -> None:
    """Unregister the stat logger named *logger_name*."""
    self.engine.remove_logger(logger_name=logger_name)

async def start_profile(self) -> None:
    """Ask the wrapped engine to begin profiling."""
    self.engine.start_profile()

async def stop_profile(self) -> None:
    """Ask the wrapped engine to stop profiling."""
    self.engine.stop_profile()

async def reset_prefix_cache(self) -> None:
    """Clear the wrapped engine's prefix cache."""
    self.engine.reset_prefix_cache()

async def add_lora(self, lora_request: LoRARequest) -> None:
    """Load a LoRA adapter into the wrapped engine."""
    self.engine.add_lora(lora_request)
| 1194 |
+
# TODO(v1): Remove this class proxy when V1 goes default.
if envs.VLLM_USE_V1:
    from vllm.v1.engine.async_llm import AsyncLLM

    # Rebind the public name so importers transparently get the V1
    # implementation whenever the env flag is set.
    AsyncLLMEngine = AsyncLLM  # type: ignore
.venv/lib/python3.11/site-packages/vllm/engine/async_timeout.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Workaround for https://github.com/python/cpython/issues/86296
|
| 4 |
+
#
|
| 5 |
+
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
|
| 6 |
+
# Licensed under the Apache License (Apache-2.0)
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import enum
|
| 10 |
+
import sys
|
| 11 |
+
import warnings
|
| 12 |
+
from types import TracebackType
|
| 13 |
+
from typing import Any, Optional, Type
|
| 14 |
+
|
| 15 |
+
if sys.version_info[:2] >= (3, 11):
    # Python 3.11+ ships a native implementation; use it directly.
    from asyncio import timeout as asyncio_timeout
else:

    def asyncio_timeout(delay: Optional[float]) -> "Timeout":
        """timeout context manager.
        Useful in cases when you want to apply timeout logic around block
        of code or in cases when asyncio.wait_for is not suitable. For example:
        >>> async with timeout(0.001):
        ...     async with aiohttp.get('https://github.com') as r:
        ...         await r.text()
        delay - value in seconds or None to disable timeout logic
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + delay if delay is not None else None
        return Timeout(deadline, loop)

    class _State(enum.Enum):
        # Lifecycle of a Timeout: INIT -> ENTER -> (TIMEOUT | EXIT).
        INIT = "INIT"
        ENTER = "ENTER"
        TIMEOUT = "TIMEOUT"
        EXIT = "EXIT"

    class Timeout:
        # Internal class, please don't instantiate it directly
        # Use timeout() and timeout_at() public factories instead.
        #
        # Implementation note: `async with timeout()` is preferred
        # over `with timeout()`.
        # While technically the Timeout class implementation
        # doesn't need to be async at all,
        # the `async with` statement explicitly points that
        # the context manager should be used from async function context.
        #
        # This design allows to avoid many silly misusages.
        #
        # TimeoutError is raised immediately when scheduled
        # if the deadline is passed.
        # The purpose is to time out as soon as possible
        # without waiting for the next await expression.

        __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")

        def __init__(self, deadline: Optional[float],
                     loop: asyncio.AbstractEventLoop) -> None:
            self._loop = loop
            self._state = _State.INIT

            self._timeout_handler = None  # type: Optional[asyncio.Handle]
            if deadline is None:
                self._deadline = None  # type: Optional[float]
            else:
                self.update(deadline)

        def __enter__(self) -> "Timeout":
            # Sync usage is deprecated; see implementation note above.
            warnings.warn(
                "with timeout() is deprecated, use async with timeout()",
                DeprecationWarning,
                stacklevel=2,
            )
            self._do_enter()
            return self

        def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
        ) -> Optional[bool]:
            self._do_exit(exc_type)
            return None

        async def __aenter__(self) -> "Timeout":
            self._do_enter()
            return self

        async def __aexit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
        ) -> Optional[bool]:
            self._do_exit(exc_type)
            return None

        @property
        def expired(self) -> bool:
            """Is timeout expired during execution?"""
            return self._state == _State.TIMEOUT

        @property
        def deadline(self) -> Optional[float]:
            # Absolute deadline in loop-time, or None when disabled.
            return self._deadline

        def reject(self) -> None:
            """Reject scheduled timeout if any."""
            # cancel is maybe better name but
            # task.cancel() raises CancelledError in asyncio world.
            if self._state not in (_State.INIT, _State.ENTER):
                raise RuntimeError(f"invalid state {self._state.value}")
            self._reject()

        def _reject(self) -> None:
            # Drop the scheduled callback without touching state.
            if self._timeout_handler is not None:
                self._timeout_handler.cancel()
                self._timeout_handler = None

        def shift(self, delay: float) -> None:
            """Advance timeout on delay seconds.
            The delay can be negative.
            Raise RuntimeError if shift is called when deadline is not scheduled
            """
            deadline = self._deadline
            if deadline is None:
                raise RuntimeError(
                    "cannot shift timeout if deadline is not scheduled")
            self.update(deadline + delay)

        def update(self, deadline: float) -> None:
            """Set deadline to absolute value.
            deadline argument points on the time in the same clock system
            as loop.time().
            If new deadline is in the past the timeout is raised immediately.
            Please note: it is not POSIX time but a time with
            undefined starting base, e.g. the time of the system power on.
            """
            if self._state == _State.EXIT:
                raise RuntimeError(
                    "cannot reschedule after exit from context manager")
            if self._state == _State.TIMEOUT:
                raise RuntimeError("cannot reschedule expired timeout")
            if self._timeout_handler is not None:
                self._timeout_handler.cancel()
            self._deadline = deadline
            # Before __aenter__ (INIT) there is nothing to schedule yet;
            # _do_enter will call _reschedule.
            if self._state != _State.INIT:
                self._reschedule()

        def _reschedule(self) -> None:
            assert self._state == _State.ENTER
            deadline = self._deadline
            if deadline is None:
                return

            now = self._loop.time()
            if self._timeout_handler is not None:
                self._timeout_handler.cancel()

            task = asyncio.current_task()
            if deadline <= now:
                # Already past due: fire on the next loop iteration.
                self._timeout_handler = self._loop.call_soon(
                    self._on_timeout, task)
            else:
                self._timeout_handler = self._loop.call_at(
                    deadline, self._on_timeout, task)

        def _do_enter(self) -> None:
            if self._state != _State.INIT:
                raise RuntimeError(f"invalid state {self._state.value}")
            self._state = _State.ENTER
            self._reschedule()

        def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
            # Translate the CancelledError injected by _on_timeout into
            # the TimeoutError callers expect.
            if exc_type is asyncio.CancelledError and \
                    self._state == _State.TIMEOUT:
                self._timeout_handler = None
                raise asyncio.TimeoutError
            # timeout has not expired
            self._state = _State.EXIT
            self._reject()
            return None

        def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None:
            # Cancel the owning task; __aexit__ re-raises as TimeoutError.
            if task:
                task.cancel()
            self._state = _State.TIMEOUT
            # drop the reference early
            self._timeout_handler = None
.venv/lib/python3.11/site-packages/vllm/engine/llm_engine.py
ADDED
|
@@ -0,0 +1,2025 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import time
|
| 5 |
+
from collections import Counter as collectionsCounter
|
| 6 |
+
from collections import deque
|
| 7 |
+
from contextlib import contextmanager
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from functools import partial
|
| 10 |
+
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
|
| 11 |
+
List, Mapping, NamedTuple, Optional)
|
| 12 |
+
from typing import Sequence as GenericSequence
|
| 13 |
+
from typing import Set, Type, Union, cast, overload
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from typing_extensions import TypeVar, deprecated
|
| 17 |
+
|
| 18 |
+
import vllm.envs as envs
|
| 19 |
+
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
| 20 |
+
ObservabilityConfig, ParallelConfig, SchedulerConfig,
|
| 21 |
+
VllmConfig)
|
| 22 |
+
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
|
| 23 |
+
SchedulerOutputs)
|
| 24 |
+
from vllm.engine.arg_utils import EngineArgs
|
| 25 |
+
from vllm.engine.metrics_types import StatLoggerBase, Stats
|
| 26 |
+
from vllm.engine.output_processor.interfaces import (
|
| 27 |
+
SequenceGroupOutputProcessor)
|
| 28 |
+
from vllm.engine.output_processor.stop_checker import StopChecker
|
| 29 |
+
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
| 30 |
+
from vllm.entrypoints.openai.logits_processors import (
|
| 31 |
+
get_logits_processors as get_openai_logits_processors)
|
| 32 |
+
from vllm.executor.executor_base import ExecutorBase
|
| 33 |
+
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
| 34 |
+
PromptType, SingletonInputsAdapter)
|
| 35 |
+
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
|
| 36 |
+
from vllm.inputs.preprocess import InputPreprocessor
|
| 37 |
+
from vllm.logger import init_logger
|
| 38 |
+
from vllm.logits_process import get_bad_words_logits_processors
|
| 39 |
+
from vllm.lora.request import LoRARequest
|
| 40 |
+
from vllm.model_executor.guided_decoding import (
|
| 41 |
+
get_local_guided_decoding_logits_processor)
|
| 42 |
+
from vllm.model_executor.layers.sampler import SamplerOutput
|
| 43 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
| 44 |
+
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
|
| 45 |
+
RequestOutputFactory)
|
| 46 |
+
from vllm.pooling_params import PoolingParams
|
| 47 |
+
from vllm.prompt_adapter.request import PromptAdapterRequest
|
| 48 |
+
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
| 49 |
+
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
|
| 50 |
+
PoolingSequenceGroupOutput, Sequence, SequenceGroup,
|
| 51 |
+
SequenceGroupBase, SequenceGroupMetadata,
|
| 52 |
+
SequenceGroupOutput, SequenceStatus)
|
| 53 |
+
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
|
| 54 |
+
init_tracer)
|
| 55 |
+
from vllm.transformers_utils.detokenizer import Detokenizer
|
| 56 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 57 |
+
from vllm.transformers_utils.tokenizer_group import (
|
| 58 |
+
BaseTokenizerGroup, init_tokenizer_from_configs)
|
| 59 |
+
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
| 60 |
+
usage_message)
|
| 61 |
+
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
|
| 62 |
+
from vllm.version import __version__ as VLLM_VERSION
|
| 63 |
+
|
| 64 |
+
logger = init_logger(__name__)
|
| 65 |
+
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
| 66 |
+
|
| 67 |
+
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
|
| 68 |
+
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class SchedulerOutputState:
|
| 73 |
+
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
|
| 74 |
+
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
|
| 75 |
+
scheduler_outputs: Optional[SchedulerOutputs] = None
|
| 76 |
+
allow_async_output_proc: bool = False
|
| 77 |
+
last_output: Optional[SamplerOutput] = None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class OutputData(NamedTuple):
|
| 81 |
+
outputs: List[SamplerOutput]
|
| 82 |
+
seq_group_metadata_list: List[SequenceGroupMetadata]
|
| 83 |
+
scheduler_outputs: SchedulerOutputs
|
| 84 |
+
is_async: bool
|
| 85 |
+
is_last_step: bool
|
| 86 |
+
# Indicates if this output is from the first step of the
|
| 87 |
+
# multi-step. When multi-step is disabled, this is always
|
| 88 |
+
# set to True.
|
| 89 |
+
# is_first_step_output is invalid when `outputs` has
|
| 90 |
+
# outputs from multiple steps.
|
| 91 |
+
is_first_step_output: Optional[bool]
|
| 92 |
+
skip: List[int]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class SchedulerContext:
|
| 96 |
+
|
| 97 |
+
def __init__(self, multi_step_stream_outputs: bool = False):
|
| 98 |
+
self.output_queue: Deque[OutputData] = deque()
|
| 99 |
+
self.request_outputs: List[Union[RequestOutput,
|
| 100 |
+
PoolingRequestOutput]] = []
|
| 101 |
+
self.seq_group_metadata_list: Optional[
|
| 102 |
+
List[SequenceGroupMetadata]] = None
|
| 103 |
+
self.scheduler_outputs: Optional[SchedulerOutputs] = None
|
| 104 |
+
|
| 105 |
+
self.multi_step_stream_outputs: bool = multi_step_stream_outputs
|
| 106 |
+
|
| 107 |
+
def append_output(self, outputs: List[SamplerOutput],
|
| 108 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
| 109 |
+
scheduler_outputs: SchedulerOutputs, is_async: bool,
|
| 110 |
+
is_last_step: bool,
|
| 111 |
+
is_first_step_output: Optional[bool]):
|
| 112 |
+
self.output_queue.append(
|
| 113 |
+
OutputData(outputs=outputs,
|
| 114 |
+
seq_group_metadata_list=seq_group_metadata_list,
|
| 115 |
+
scheduler_outputs=scheduler_outputs,
|
| 116 |
+
is_async=is_async,
|
| 117 |
+
is_last_step=is_last_step,
|
| 118 |
+
is_first_step_output=is_first_step_output,
|
| 119 |
+
skip=[]))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class LLMEngine:
|
| 123 |
+
"""An LLM engine that receives requests and generates texts.
|
| 124 |
+
|
| 125 |
+
This is the main class for the vLLM engine. It receives requests
|
| 126 |
+
from clients and generates texts from the LLM. It includes a tokenizer, a
|
| 127 |
+
language model (possibly distributed across multiple GPUs), and GPU memory
|
| 128 |
+
space allocated for intermediate states (aka KV cache). This class utilizes
|
| 129 |
+
iteration-level scheduling and efficient memory management to maximize the
|
| 130 |
+
serving throughput.
|
| 131 |
+
|
| 132 |
+
The :class:`~vllm.LLM` class wraps this class for offline batched inference
|
| 133 |
+
and the :class:`AsyncLLMEngine` class wraps this class for online serving.
|
| 134 |
+
|
| 135 |
+
The config arguments are derived from :class:`~vllm.EngineArgs`. (See
|
| 136 |
+
:ref:`engine-args`)
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
model_config: The configuration related to the LLM model.
|
| 140 |
+
cache_config: The configuration related to the KV cache memory
|
| 141 |
+
management.
|
| 142 |
+
parallel_config: The configuration related to distributed execution.
|
| 143 |
+
scheduler_config: The configuration related to the request scheduler.
|
| 144 |
+
device_config: The configuration related to the device.
|
| 145 |
+
lora_config (Optional): The configuration related to serving multi-LoRA.
|
| 146 |
+
speculative_config (Optional): The configuration related to speculative
|
| 147 |
+
decoding.
|
| 148 |
+
executor_class: The model executor class for managing distributed
|
| 149 |
+
execution.
|
| 150 |
+
prompt_adapter_config (Optional): The configuration related to serving
|
| 151 |
+
prompt adapters.
|
| 152 |
+
log_stats: Whether to log statistics.
|
| 153 |
+
usage_context: Specified entry point, used for usage info collection.
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
DO_VALIDATE_OUTPUT: ClassVar[bool] = False
|
| 157 |
+
"""A flag to toggle whether to validate the type of request output."""
|
| 158 |
+
|
| 159 |
+
@classmethod
|
| 160 |
+
@contextmanager
|
| 161 |
+
def enable_output_validation(cls):
|
| 162 |
+
cls.DO_VALIDATE_OUTPUT = True
|
| 163 |
+
|
| 164 |
+
yield
|
| 165 |
+
|
| 166 |
+
cls.DO_VALIDATE_OUTPUT = False
|
| 167 |
+
|
| 168 |
+
@classmethod
|
| 169 |
+
def validate_output(
|
| 170 |
+
cls,
|
| 171 |
+
output: object,
|
| 172 |
+
output_type: Type[_O],
|
| 173 |
+
) -> _O:
|
| 174 |
+
do_validate = cls.DO_VALIDATE_OUTPUT
|
| 175 |
+
|
| 176 |
+
if ((TYPE_CHECKING or do_validate)
|
| 177 |
+
and not isinstance(output, output_type)):
|
| 178 |
+
raise TypeError(f"Expected output of type {output_type}, "
|
| 179 |
+
f"but found type {type(output)}")
|
| 180 |
+
|
| 181 |
+
return cast(_O, output)
|
| 182 |
+
|
| 183 |
+
@classmethod
|
| 184 |
+
def validate_outputs(
|
| 185 |
+
cls,
|
| 186 |
+
outputs: GenericSequence[object],
|
| 187 |
+
output_type: Type[_O],
|
| 188 |
+
) -> List[_O]:
|
| 189 |
+
do_validate = cls.DO_VALIDATE_OUTPUT
|
| 190 |
+
|
| 191 |
+
outputs_: List[_O]
|
| 192 |
+
if TYPE_CHECKING or do_validate:
|
| 193 |
+
outputs_ = []
|
| 194 |
+
for output in outputs:
|
| 195 |
+
if not isinstance(output, output_type):
|
| 196 |
+
raise TypeError(f"Expected output of type {output_type}, "
|
| 197 |
+
f"but found type {type(output)}")
|
| 198 |
+
|
| 199 |
+
outputs_.append(output)
|
| 200 |
+
else:
|
| 201 |
+
outputs_ = outputs
|
| 202 |
+
|
| 203 |
+
return outputs_
|
| 204 |
+
|
| 205 |
+
tokenizer: Optional[BaseTokenizerGroup]
|
| 206 |
+
|
| 207 |
+
def __init__(
|
| 208 |
+
self,
|
| 209 |
+
vllm_config: VllmConfig,
|
| 210 |
+
executor_class: Type[ExecutorBase],
|
| 211 |
+
log_stats: bool,
|
| 212 |
+
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
| 213 |
+
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
| 214 |
+
input_registry: InputRegistry = INPUT_REGISTRY,
|
| 215 |
+
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
| 216 |
+
use_cached_outputs: bool = False,
|
| 217 |
+
) -> None:
|
| 218 |
+
|
| 219 |
+
self.vllm_config = vllm_config
|
| 220 |
+
self.model_config = vllm_config.model_config
|
| 221 |
+
self.cache_config = vllm_config.cache_config
|
| 222 |
+
self.lora_config = vllm_config.lora_config
|
| 223 |
+
self.parallel_config = vllm_config.parallel_config
|
| 224 |
+
self.scheduler_config = vllm_config.scheduler_config
|
| 225 |
+
self.device_config = vllm_config.device_config
|
| 226 |
+
self.speculative_config = vllm_config.speculative_config # noqa
|
| 227 |
+
self.load_config = vllm_config.load_config
|
| 228 |
+
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
|
| 229 |
+
)
|
| 230 |
+
self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
|
| 231 |
+
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
logger.info(
|
| 235 |
+
"Initializing a V0 LLM engine (v%s) with config: %s, "
|
| 236 |
+
"use_cached_outputs=%s, ",
|
| 237 |
+
VLLM_VERSION,
|
| 238 |
+
vllm_config,
|
| 239 |
+
use_cached_outputs,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
self.log_stats = log_stats
|
| 243 |
+
self.use_cached_outputs = use_cached_outputs
|
| 244 |
+
|
| 245 |
+
if not self.model_config.skip_tokenizer_init:
|
| 246 |
+
self.tokenizer = self._init_tokenizer()
|
| 247 |
+
self.detokenizer = Detokenizer(self.tokenizer)
|
| 248 |
+
tokenizer_group = self.get_tokenizer_group()
|
| 249 |
+
else:
|
| 250 |
+
self.tokenizer = None
|
| 251 |
+
self.detokenizer = None
|
| 252 |
+
tokenizer_group = None
|
| 253 |
+
|
| 254 |
+
# Ensure that the function doesn't contain a reference to self,
|
| 255 |
+
# to avoid engine GC issues
|
| 256 |
+
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
|
| 257 |
+
assert tokenizer_group, ("tokenizer_group cannot be None, "
|
| 258 |
+
"make sure skip_tokenizer_init is False")
|
| 259 |
+
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
|
| 260 |
+
|
| 261 |
+
self.seq_counter = Counter()
|
| 262 |
+
self.generation_config_fields = (
|
| 263 |
+
self.model_config.try_get_generation_config())
|
| 264 |
+
|
| 265 |
+
self.input_preprocessor = InputPreprocessor(self.model_config,
|
| 266 |
+
self.tokenizer,
|
| 267 |
+
mm_registry)
|
| 268 |
+
|
| 269 |
+
self.input_registry = input_registry
|
| 270 |
+
self.input_processor = input_registry.create_input_processor(
|
| 271 |
+
self.model_config)
|
| 272 |
+
|
| 273 |
+
self.model_executor = executor_class(vllm_config=vllm_config, )
|
| 274 |
+
|
| 275 |
+
if self.model_config.runner_type != "pooling":
|
| 276 |
+
self._initialize_kv_caches()
|
| 277 |
+
|
| 278 |
+
# If usage stat is enabled, collect relevant info.
|
| 279 |
+
if is_usage_stats_enabled():
|
| 280 |
+
from vllm.model_executor.model_loader import (
|
| 281 |
+
get_architecture_class_name)
|
| 282 |
+
usage_message.report_usage(
|
| 283 |
+
get_architecture_class_name(self.model_config),
|
| 284 |
+
usage_context,
|
| 285 |
+
extra_kvs={
|
| 286 |
+
# Common configuration
|
| 287 |
+
"dtype":
|
| 288 |
+
str(self.model_config.dtype),
|
| 289 |
+
"tensor_parallel_size":
|
| 290 |
+
self.parallel_config.tensor_parallel_size,
|
| 291 |
+
"block_size":
|
| 292 |
+
self.cache_config.block_size,
|
| 293 |
+
"gpu_memory_utilization":
|
| 294 |
+
self.cache_config.gpu_memory_utilization,
|
| 295 |
+
|
| 296 |
+
# Quantization
|
| 297 |
+
"quantization":
|
| 298 |
+
self.model_config.quantization,
|
| 299 |
+
"kv_cache_dtype":
|
| 300 |
+
str(self.cache_config.cache_dtype),
|
| 301 |
+
|
| 302 |
+
# Feature flags
|
| 303 |
+
"enable_lora":
|
| 304 |
+
bool(self.lora_config),
|
| 305 |
+
"enable_prompt_adapter":
|
| 306 |
+
bool(self.prompt_adapter_config),
|
| 307 |
+
"enable_prefix_caching":
|
| 308 |
+
self.cache_config.enable_prefix_caching,
|
| 309 |
+
"enforce_eager":
|
| 310 |
+
self.model_config.enforce_eager,
|
| 311 |
+
"disable_custom_all_reduce":
|
| 312 |
+
self.parallel_config.disable_custom_all_reduce,
|
| 313 |
+
})
|
| 314 |
+
|
| 315 |
+
if self.tokenizer:
|
| 316 |
+
# Ping the tokenizer to ensure liveness if it runs in a
|
| 317 |
+
# different process.
|
| 318 |
+
self.tokenizer.ping()
|
| 319 |
+
|
| 320 |
+
self.cached_scheduler_outputs = [
|
| 321 |
+
SchedulerOutputState()
|
| 322 |
+
for _ in range(self.parallel_config.pipeline_parallel_size)
|
| 323 |
+
]
|
| 324 |
+
|
| 325 |
+
self.scheduler_contexts = [
|
| 326 |
+
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
|
| 327 |
+
multi_step_stream_outputs)
|
| 328 |
+
for _ in range(self.parallel_config.pipeline_parallel_size)
|
| 329 |
+
]
|
| 330 |
+
|
| 331 |
+
if self.model_config.use_async_output_proc:
|
| 332 |
+
process_model_outputs = weak_bind(self._process_model_outputs)
|
| 333 |
+
|
| 334 |
+
self.async_callbacks = [
|
| 335 |
+
partial(process_model_outputs,
|
| 336 |
+
ctx=self.scheduler_contexts[v_id])
|
| 337 |
+
for v_id in range(self.parallel_config.pipeline_parallel_size)
|
| 338 |
+
]
|
| 339 |
+
else:
|
| 340 |
+
self.async_callbacks = []
|
| 341 |
+
|
| 342 |
+
# Currently used by AsyncLLMEngine to ensure quick append
|
| 343 |
+
# of request outputs to asyncio queues
|
| 344 |
+
self.process_request_outputs_callback: Optional[Callable] = None
|
| 345 |
+
|
| 346 |
+
# Create the scheduler.
|
| 347 |
+
# NOTE: the cache_config here have been updated with the numbers of
|
| 348 |
+
# GPU and CPU blocks, which are profiled in the distributed executor.
|
| 349 |
+
self.scheduler = [
|
| 350 |
+
Scheduler(
|
| 351 |
+
self.scheduler_config, self.cache_config, self.lora_config,
|
| 352 |
+
self.parallel_config.pipeline_parallel_size,
|
| 353 |
+
self.async_callbacks[v_id]
|
| 354 |
+
if self.model_config.use_async_output_proc else None)
|
| 355 |
+
for v_id in range(self.parallel_config.pipeline_parallel_size)
|
| 356 |
+
]
|
| 357 |
+
|
| 358 |
+
# Metric Logging.
|
| 359 |
+
if self.log_stats:
|
| 360 |
+
if stat_loggers is not None:
|
| 361 |
+
self.stat_loggers = stat_loggers
|
| 362 |
+
else:
|
| 363 |
+
# Lazy import for prometheus multiprocessing.
|
| 364 |
+
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
| 365 |
+
# before prometheus_client is imported.
|
| 366 |
+
# See https://prometheus.github.io/client_python/multiprocess/
|
| 367 |
+
from vllm.engine.metrics import (LoggingStatLogger,
|
| 368 |
+
PrometheusStatLogger)
|
| 369 |
+
|
| 370 |
+
self.stat_loggers = {
|
| 371 |
+
"logging":
|
| 372 |
+
LoggingStatLogger(
|
| 373 |
+
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
|
| 374 |
+
vllm_config=vllm_config),
|
| 375 |
+
"prometheus":
|
| 376 |
+
PrometheusStatLogger(
|
| 377 |
+
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
|
| 378 |
+
labels=dict(
|
| 379 |
+
model_name=self.model_config.served_model_name),
|
| 380 |
+
vllm_config=vllm_config),
|
| 381 |
+
}
|
| 382 |
+
self.stat_loggers["prometheus"].info("cache_config",
|
| 383 |
+
self.cache_config)
|
| 384 |
+
|
| 385 |
+
self.tracer = None
|
| 386 |
+
if self.observability_config.otlp_traces_endpoint:
|
| 387 |
+
self.tracer = init_tracer(
|
| 388 |
+
"vllm.llm_engine",
|
| 389 |
+
self.observability_config.otlp_traces_endpoint)
|
| 390 |
+
|
| 391 |
+
# Create sequence output processor, e.g. for beam search or
|
| 392 |
+
# speculative decoding.
|
| 393 |
+
self.output_processor = (
|
| 394 |
+
SequenceGroupOutputProcessor.create_output_processor(
|
| 395 |
+
self.scheduler_config,
|
| 396 |
+
self.detokenizer,
|
| 397 |
+
self.scheduler,
|
| 398 |
+
self.seq_counter,
|
| 399 |
+
get_tokenizer_for_seq,
|
| 400 |
+
stop_checker=StopChecker(
|
| 401 |
+
self.scheduler_config.max_model_len,
|
| 402 |
+
get_tokenizer_for_seq,
|
| 403 |
+
),
|
| 404 |
+
))
|
| 405 |
+
|
| 406 |
+
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
|
| 407 |
+
|
| 408 |
+
def _initialize_kv_caches(self) -> None:
|
| 409 |
+
"""Initialize the KV cache in the worker(s).
|
| 410 |
+
|
| 411 |
+
The workers will determine the number of blocks in both the GPU cache
|
| 412 |
+
and the swap CPU cache.
|
| 413 |
+
"""
|
| 414 |
+
start = time.time()
|
| 415 |
+
num_gpu_blocks, num_cpu_blocks = (
|
| 416 |
+
self.model_executor.determine_num_available_blocks())
|
| 417 |
+
|
| 418 |
+
if self.cache_config.num_gpu_blocks_override is not None:
|
| 419 |
+
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
|
| 420 |
+
logger.info(
|
| 421 |
+
"Overriding num_gpu_blocks=%d with "
|
| 422 |
+
"num_gpu_blocks_override=%d", num_gpu_blocks,
|
| 423 |
+
num_gpu_blocks_override)
|
| 424 |
+
num_gpu_blocks = num_gpu_blocks_override
|
| 425 |
+
|
| 426 |
+
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
| 427 |
+
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
| 428 |
+
|
| 429 |
+
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
| 430 |
+
elapsed = time.time() - start
|
| 431 |
+
logger.info(("init engine (profile, create kv cache, "
|
| 432 |
+
"warmup model) took %.2f seconds"), elapsed)
|
| 433 |
+
|
| 434 |
+
@classmethod
|
| 435 |
+
def _get_executor_cls(cls,
|
| 436 |
+
engine_config: VllmConfig) -> Type[ExecutorBase]:
|
| 437 |
+
distributed_executor_backend = (
|
| 438 |
+
engine_config.parallel_config.distributed_executor_backend)
|
| 439 |
+
# Initialize the cluster and specify the executor class.
|
| 440 |
+
if isinstance(distributed_executor_backend, type):
|
| 441 |
+
if not issubclass(distributed_executor_backend, ExecutorBase):
|
| 442 |
+
raise TypeError(
|
| 443 |
+
"distributed_executor_backend must be a subclass of "
|
| 444 |
+
f"ExecutorBase. Got {distributed_executor_backend}.")
|
| 445 |
+
executor_class = distributed_executor_backend
|
| 446 |
+
elif engine_config.parallel_config.world_size > 1:
|
| 447 |
+
if distributed_executor_backend == "ray":
|
| 448 |
+
from vllm.executor.ray_distributed_executor import (
|
| 449 |
+
RayDistributedExecutor)
|
| 450 |
+
executor_class = RayDistributedExecutor
|
| 451 |
+
elif distributed_executor_backend == "mp":
|
| 452 |
+
from vllm.executor.mp_distributed_executor import (
|
| 453 |
+
MultiprocessingDistributedExecutor)
|
| 454 |
+
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
|
| 455 |
+
"multiprocessing distributed executor backend does not "
|
| 456 |
+
"support VLLM_USE_RAY_SPMD_WORKER=1")
|
| 457 |
+
executor_class = MultiprocessingDistributedExecutor
|
| 458 |
+
elif distributed_executor_backend == "uni":
|
| 459 |
+
# JAX-style, single-process, multi-device executor.
|
| 460 |
+
from vllm.executor.uniproc_executor import UniProcExecutor
|
| 461 |
+
executor_class = UniProcExecutor
|
| 462 |
+
elif distributed_executor_backend == "external_launcher":
|
| 463 |
+
# executor with external launcher
|
| 464 |
+
from vllm.executor.uniproc_executor import ( # noqa
|
| 465 |
+
ExecutorWithExternalLauncher)
|
| 466 |
+
executor_class = ExecutorWithExternalLauncher
|
| 467 |
+
else:
|
| 468 |
+
from vllm.executor.uniproc_executor import UniProcExecutor
|
| 469 |
+
executor_class = UniProcExecutor
|
| 470 |
+
return executor_class
|
| 471 |
+
|
| 472 |
+
@classmethod
|
| 473 |
+
def from_engine_args(
|
| 474 |
+
cls,
|
| 475 |
+
engine_args: EngineArgs,
|
| 476 |
+
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
| 477 |
+
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
| 478 |
+
) -> "LLMEngine":
|
| 479 |
+
"""Creates an LLM engine from the engine arguments."""
|
| 480 |
+
# Create the engine configs.
|
| 481 |
+
engine_config = engine_args.create_engine_config(usage_context)
|
| 482 |
+
executor_class = cls._get_executor_cls(engine_config)
|
| 483 |
+
# Create the LLM engine.
|
| 484 |
+
engine = cls(
|
| 485 |
+
vllm_config=engine_config,
|
| 486 |
+
executor_class=executor_class,
|
| 487 |
+
log_stats=not engine_args.disable_log_stats,
|
| 488 |
+
usage_context=usage_context,
|
| 489 |
+
stat_loggers=stat_loggers,
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
return engine
|
| 493 |
+
|
| 494 |
+
def __reduce__(self):
|
| 495 |
+
# This is to ensure that the LLMEngine is not referenced in
|
| 496 |
+
# the closure used to initialize Ray worker actors
|
| 497 |
+
raise RuntimeError("LLMEngine should not be pickled!")
|
| 498 |
+
|
| 499 |
+
def __del__(self):
|
| 500 |
+
# Shutdown model executor when engine is garbage collected
|
| 501 |
+
# Use getattr since __init__ can fail before the field is set
|
| 502 |
+
if model_executor := getattr(self, "model_executor", None):
|
| 503 |
+
model_executor.shutdown()
|
| 504 |
+
|
| 505 |
+
def get_tokenizer_group(
|
| 506 |
+
self,
|
| 507 |
+
group_type: Type[_G] = BaseTokenizerGroup,
|
| 508 |
+
) -> _G:
|
| 509 |
+
tokenizer_group = self.tokenizer
|
| 510 |
+
|
| 511 |
+
if tokenizer_group is None:
|
| 512 |
+
raise ValueError("Unable to get tokenizer because "
|
| 513 |
+
"skip_tokenizer_init is True")
|
| 514 |
+
if not isinstance(tokenizer_group, group_type):
|
| 515 |
+
raise TypeError("Invalid type of tokenizer group. "
|
| 516 |
+
f"Expected type: {group_type}, but "
|
| 517 |
+
f"found type: {type(tokenizer_group)}")
|
| 518 |
+
|
| 519 |
+
return tokenizer_group
|
| 520 |
+
|
| 521 |
+
def get_tokenizer(
|
| 522 |
+
self,
|
| 523 |
+
lora_request: Optional[LoRARequest] = None,
|
| 524 |
+
) -> AnyTokenizer:
|
| 525 |
+
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
|
| 526 |
+
|
| 527 |
+
def _init_tokenizer(self) -> BaseTokenizerGroup:
|
| 528 |
+
return init_tokenizer_from_configs(
|
| 529 |
+
model_config=self.model_config,
|
| 530 |
+
scheduler_config=self.scheduler_config,
|
| 531 |
+
parallel_config=self.parallel_config,
|
| 532 |
+
lora_config=self.lora_config)
|
| 533 |
+
|
| 534 |
+
def _verify_args(self) -> None:
|
| 535 |
+
self.model_config.verify_with_parallel_config(self.parallel_config)
|
| 536 |
+
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
| 537 |
+
if self.lora_config:
|
| 538 |
+
self.lora_config.verify_with_model_config(self.model_config)
|
| 539 |
+
self.lora_config.verify_with_scheduler_config(
|
| 540 |
+
self.scheduler_config)
|
| 541 |
+
if self.prompt_adapter_config:
|
| 542 |
+
self.prompt_adapter_config.verify_with_model_config(
|
| 543 |
+
self.model_config)
|
| 544 |
+
|
| 545 |
+
def _add_processed_request(
|
| 546 |
+
self,
|
| 547 |
+
request_id: str,
|
| 548 |
+
processed_inputs: ProcessorInputs,
|
| 549 |
+
params: Union[SamplingParams, PoolingParams],
|
| 550 |
+
arrival_time: float,
|
| 551 |
+
lora_request: Optional[LoRARequest],
|
| 552 |
+
prompt_adapter_request: Optional[PromptAdapterRequest],
|
| 553 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 554 |
+
priority: int = 0,
|
| 555 |
+
) -> Optional[SequenceGroup]:
|
| 556 |
+
"""Add a processed request to the engine's request pool.
|
| 557 |
+
return the created sequence group.
|
| 558 |
+
"""
|
| 559 |
+
if isinstance(params, SamplingParams) and params.n > 1:
|
| 560 |
+
ParallelSampleSequenceGroup.add_request(
|
| 561 |
+
request_id,
|
| 562 |
+
self,
|
| 563 |
+
params,
|
| 564 |
+
processed_inputs=processed_inputs,
|
| 565 |
+
arrival_time=arrival_time,
|
| 566 |
+
lora_request=lora_request,
|
| 567 |
+
trace_headers=trace_headers,
|
| 568 |
+
prompt_adapter_request=prompt_adapter_request,
|
| 569 |
+
priority=priority,
|
| 570 |
+
)
|
| 571 |
+
return None
|
| 572 |
+
|
| 573 |
+
self._validate_model_inputs(processed_inputs, lora_request)
|
| 574 |
+
# Create the sequences.
|
| 575 |
+
block_size = self.cache_config.block_size
|
| 576 |
+
seq_id = next(self.seq_counter)
|
| 577 |
+
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
| 578 |
+
|
| 579 |
+
if is_encoder_decoder_inputs(processed_inputs):
|
| 580 |
+
decoder_inputs = processed_inputs["decoder"]
|
| 581 |
+
encoder_inputs = processed_inputs["encoder"]
|
| 582 |
+
else:
|
| 583 |
+
decoder_inputs = processed_inputs
|
| 584 |
+
encoder_inputs = None
|
| 585 |
+
|
| 586 |
+
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
|
| 587 |
+
lora_request, prompt_adapter_request)
|
| 588 |
+
|
| 589 |
+
encoder_seq = (None if encoder_inputs is None else Sequence(
|
| 590 |
+
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
|
| 591 |
+
prompt_adapter_request))
|
| 592 |
+
|
| 593 |
+
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
| 594 |
+
if isinstance(params, SamplingParams):
|
| 595 |
+
seq_group = self._create_sequence_group_with_sampling(
|
| 596 |
+
request_id,
|
| 597 |
+
seq,
|
| 598 |
+
params,
|
| 599 |
+
arrival_time=arrival_time,
|
| 600 |
+
lora_request=lora_request,
|
| 601 |
+
trace_headers=trace_headers,
|
| 602 |
+
prompt_adapter_request=prompt_adapter_request,
|
| 603 |
+
encoder_seq=encoder_seq,
|
| 604 |
+
priority=priority)
|
| 605 |
+
elif isinstance(params, PoolingParams):
|
| 606 |
+
seq_group = self._create_sequence_group_with_pooling(
|
| 607 |
+
request_id,
|
| 608 |
+
seq,
|
| 609 |
+
params,
|
| 610 |
+
arrival_time=arrival_time,
|
| 611 |
+
lora_request=lora_request,
|
| 612 |
+
prompt_adapter_request=prompt_adapter_request,
|
| 613 |
+
encoder_seq=encoder_seq,
|
| 614 |
+
priority=priority)
|
| 615 |
+
else:
|
| 616 |
+
raise ValueError(
|
| 617 |
+
"Either SamplingParams or PoolingParams must be provided.")
|
| 618 |
+
|
| 619 |
+
# Add the sequence group to the scheduler with least unfinished seqs.
|
| 620 |
+
costs = [
|
| 621 |
+
scheduler.get_num_unfinished_seq_groups()
|
| 622 |
+
for scheduler in self.scheduler
|
| 623 |
+
]
|
| 624 |
+
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
|
| 625 |
+
min_cost_scheduler.add_seq_group(seq_group)
|
| 626 |
+
|
| 627 |
+
return seq_group
|
| 628 |
+
|
| 629 |
+
def stop_remote_worker_execution_loop(self) -> None:
|
| 630 |
+
self.model_executor.stop_remote_worker_execution_loop()
|
| 631 |
+
|
| 632 |
+
@overload
|
| 633 |
+
def add_request(
|
| 634 |
+
self,
|
| 635 |
+
request_id: str,
|
| 636 |
+
prompt: PromptType,
|
| 637 |
+
params: Union[SamplingParams, PoolingParams],
|
| 638 |
+
arrival_time: Optional[float] = None,
|
| 639 |
+
lora_request: Optional[LoRARequest] = None,
|
| 640 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 641 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 642 |
+
priority: int = 0,
|
| 643 |
+
) -> None:
|
| 644 |
+
...
|
| 645 |
+
|
| 646 |
+
@overload
|
| 647 |
+
@deprecated("'inputs' will be renamed to 'prompt")
|
| 648 |
+
def add_request(
|
| 649 |
+
self,
|
| 650 |
+
request_id: str,
|
| 651 |
+
*,
|
| 652 |
+
inputs: PromptType,
|
| 653 |
+
params: Union[SamplingParams, PoolingParams],
|
| 654 |
+
arrival_time: Optional[float] = None,
|
| 655 |
+
lora_request: Optional[LoRARequest] = None,
|
| 656 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 657 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 658 |
+
priority: int = 0,
|
| 659 |
+
) -> None:
|
| 660 |
+
...
|
| 661 |
+
|
| 662 |
+
@deprecate_kwargs(
|
| 663 |
+
"inputs",
|
| 664 |
+
additional_message="Please use the 'prompt' parameter instead.",
|
| 665 |
+
)
|
| 666 |
+
def add_request(
|
| 667 |
+
self,
|
| 668 |
+
request_id: str,
|
| 669 |
+
prompt: Optional[PromptType] = None,
|
| 670 |
+
params: Optional[Union[SamplingParams, PoolingParams]] = None,
|
| 671 |
+
arrival_time: Optional[float] = None,
|
| 672 |
+
lora_request: Optional[LoRARequest] = None,
|
| 673 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 674 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 675 |
+
priority: int = 0,
|
| 676 |
+
*,
|
| 677 |
+
inputs: Optional[PromptType] = None, # DEPRECATED
|
| 678 |
+
) -> None:
|
| 679 |
+
"""Add a request to the engine's request pool.
|
| 680 |
+
|
| 681 |
+
The request is added to the request pool and will be processed by the
|
| 682 |
+
scheduler as `engine.step()` is called. The exact scheduling policy is
|
| 683 |
+
determined by the scheduler.
|
| 684 |
+
|
| 685 |
+
Args:
|
| 686 |
+
request_id: The unique ID of the request.
|
| 687 |
+
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
|
| 688 |
+
for more details about the format of each input.
|
| 689 |
+
params: Parameters for sampling or pooling.
|
| 690 |
+
:class:`~vllm.SamplingParams` for text generation.
|
| 691 |
+
:class:`~vllm.PoolingParams` for pooling.
|
| 692 |
+
arrival_time: The arrival time of the request. If None, we use
|
| 693 |
+
the current monotonic time.
|
| 694 |
+
lora_request: The LoRA request to add.
|
| 695 |
+
trace_headers: OpenTelemetry trace headers.
|
| 696 |
+
prompt_adapter_request: The prompt adapter request to add.
|
| 697 |
+
priority: The priority of the request.
|
| 698 |
+
Only applicable with priority scheduling.
|
| 699 |
+
|
| 700 |
+
Details:
|
| 701 |
+
- Set arrival_time to the current time if it is None.
|
| 702 |
+
- Set prompt_token_ids to the encoded prompt if it is None.
|
| 703 |
+
- Create `n` number of :class:`~vllm.Sequence` objects.
|
| 704 |
+
- Create a :class:`~vllm.SequenceGroup` object
|
| 705 |
+
from the list of :class:`~vllm.Sequence`.
|
| 706 |
+
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
|
| 707 |
+
|
| 708 |
+
Example:
|
| 709 |
+
>>> # initialize engine
|
| 710 |
+
>>> engine = LLMEngine.from_engine_args(engine_args)
|
| 711 |
+
>>> # set request arguments
|
| 712 |
+
>>> example_prompt = "Who is the president of the United States?"
|
| 713 |
+
>>> sampling_params = SamplingParams(temperature=0.0)
|
| 714 |
+
>>> request_id = 0
|
| 715 |
+
>>>
|
| 716 |
+
>>> # add the request to the engine
|
| 717 |
+
>>> engine.add_request(
|
| 718 |
+
>>> str(request_id),
|
| 719 |
+
>>> example_prompt,
|
| 720 |
+
>>> SamplingParams(temperature=0.0))
|
| 721 |
+
>>> # continue the request processing
|
| 722 |
+
>>> ...
|
| 723 |
+
"""
|
| 724 |
+
if inputs is not None:
|
| 725 |
+
prompt = inputs
|
| 726 |
+
assert prompt is not None and params is not None
|
| 727 |
+
|
| 728 |
+
if lora_request is not None and not self.lora_config:
|
| 729 |
+
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
| 730 |
+
"not enabled!")
|
| 731 |
+
|
| 732 |
+
if priority != 0 and not self.scheduler_config.policy == "priority":
|
| 733 |
+
raise ValueError(f"Got priority {priority} but "
|
| 734 |
+
"Priority scheduling is not enabled.")
|
| 735 |
+
|
| 736 |
+
if isinstance(params, SamplingParams) \
|
| 737 |
+
and (params.guided_decoding or params.logits_processors) \
|
| 738 |
+
and self.scheduler_config.num_scheduler_steps > 1:
|
| 739 |
+
raise ValueError(
|
| 740 |
+
"Guided decoding and logits processors are not supported "
|
| 741 |
+
"in multi-step decoding")
|
| 742 |
+
|
| 743 |
+
if arrival_time is None:
|
| 744 |
+
arrival_time = time.time()
|
| 745 |
+
|
| 746 |
+
if self.tokenizer is not None:
|
| 747 |
+
self._validate_token_prompt(
|
| 748 |
+
prompt,
|
| 749 |
+
tokenizer=self.get_tokenizer(lora_request=lora_request))
|
| 750 |
+
|
| 751 |
+
preprocessed_inputs = self.input_preprocessor.preprocess(
|
| 752 |
+
prompt,
|
| 753 |
+
request_id=request_id,
|
| 754 |
+
lora_request=lora_request,
|
| 755 |
+
prompt_adapter_request=prompt_adapter_request,
|
| 756 |
+
)
|
| 757 |
+
processed_inputs = self.input_processor(preprocessed_inputs)
|
| 758 |
+
|
| 759 |
+
self._add_processed_request(
|
| 760 |
+
request_id=request_id,
|
| 761 |
+
processed_inputs=processed_inputs,
|
| 762 |
+
params=params,
|
| 763 |
+
arrival_time=arrival_time,
|
| 764 |
+
lora_request=lora_request,
|
| 765 |
+
prompt_adapter_request=prompt_adapter_request,
|
| 766 |
+
trace_headers=trace_headers,
|
| 767 |
+
priority=priority,
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
def _validate_token_prompt(self, prompt: PromptType,
|
| 771 |
+
tokenizer: AnyTokenizer):
|
| 772 |
+
# Guard against out-of-vocab tokens.
|
| 773 |
+
# For some tokenizers, tokenizer.decode will happily return empty text
|
| 774 |
+
# for token ids that are out of vocab, and we don't detect token ids
|
| 775 |
+
# that are greater than the max token id before running the model.
|
| 776 |
+
# However, these token ids will later crash a cuda kernel at runtime
|
| 777 |
+
# with an index out of bounds error. This will crash the entire engine.
|
| 778 |
+
# This needs to happen before multimodal input pre-processing, which
|
| 779 |
+
# may add dummy <image> tokens that aren't part of the tokenizer's
|
| 780 |
+
# vocabulary.
|
| 781 |
+
if is_token_prompt(prompt):
|
| 782 |
+
prompt_ids = prompt["prompt_token_ids"]
|
| 783 |
+
if len(prompt_ids) == 0:
|
| 784 |
+
# Empty prompt check is handled later
|
| 785 |
+
return
|
| 786 |
+
max_input_id = max(prompt_ids)
|
| 787 |
+
if max_input_id > tokenizer.max_token_id:
|
| 788 |
+
raise ValueError(
|
| 789 |
+
"Token id {} is out of vocabulary".format(max_input_id))
|
| 790 |
+
|
| 791 |
+
def _create_sequence_group_with_sampling(
|
| 792 |
+
self,
|
| 793 |
+
request_id: str,
|
| 794 |
+
seq: Sequence,
|
| 795 |
+
sampling_params: SamplingParams,
|
| 796 |
+
arrival_time: float,
|
| 797 |
+
lora_request: Optional[LoRARequest],
|
| 798 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 799 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 800 |
+
encoder_seq: Optional[Sequence] = None,
|
| 801 |
+
priority: int = 0,
|
| 802 |
+
) -> SequenceGroup:
|
| 803 |
+
"""Creates a SequenceGroup with SamplingParams."""
|
| 804 |
+
max_logprobs = self.get_model_config().max_logprobs
|
| 805 |
+
if (sampling_params.logprobs
|
| 806 |
+
and sampling_params.logprobs > max_logprobs) or (
|
| 807 |
+
sampling_params.prompt_logprobs
|
| 808 |
+
and sampling_params.prompt_logprobs > max_logprobs):
|
| 809 |
+
raise ValueError(f"Cannot request more than "
|
| 810 |
+
f"{max_logprobs} logprobs.")
|
| 811 |
+
|
| 812 |
+
sampling_params = self._build_logits_processors(
|
| 813 |
+
sampling_params, lora_request)
|
| 814 |
+
|
| 815 |
+
# Defensive copy of SamplingParams, which are used by the sampler,
|
| 816 |
+
# this doesn't deep-copy LogitsProcessor objects
|
| 817 |
+
sampling_params = sampling_params.clone()
|
| 818 |
+
|
| 819 |
+
sampling_params.update_from_generation_config(
|
| 820 |
+
self.generation_config_fields, seq.eos_token_id)
|
| 821 |
+
|
| 822 |
+
# Create the sequence group.
|
| 823 |
+
seq_group = SequenceGroup(
|
| 824 |
+
request_id=request_id,
|
| 825 |
+
seqs=[seq],
|
| 826 |
+
arrival_time=arrival_time,
|
| 827 |
+
sampling_params=sampling_params,
|
| 828 |
+
lora_request=lora_request,
|
| 829 |
+
trace_headers=trace_headers,
|
| 830 |
+
prompt_adapter_request=prompt_adapter_request,
|
| 831 |
+
encoder_seq=encoder_seq,
|
| 832 |
+
priority=priority)
|
| 833 |
+
|
| 834 |
+
return seq_group
|
| 835 |
+
|
| 836 |
+
def _create_sequence_group_with_pooling(
|
| 837 |
+
self,
|
| 838 |
+
request_id: str,
|
| 839 |
+
seq: Sequence,
|
| 840 |
+
pooling_params: PoolingParams,
|
| 841 |
+
arrival_time: float,
|
| 842 |
+
lora_request: Optional[LoRARequest],
|
| 843 |
+
prompt_adapter_request: Optional[PromptAdapterRequest],
|
| 844 |
+
encoder_seq: Optional[Sequence] = None,
|
| 845 |
+
priority: int = 0,
|
| 846 |
+
) -> SequenceGroup:
|
| 847 |
+
"""Creates a SequenceGroup with PoolingParams."""
|
| 848 |
+
# Defensive copy of PoolingParams, which are used by the pooler
|
| 849 |
+
pooling_params = pooling_params.clone()
|
| 850 |
+
# Create the sequence group.
|
| 851 |
+
seq_group = SequenceGroup(
|
| 852 |
+
request_id=request_id,
|
| 853 |
+
seqs=[seq],
|
| 854 |
+
arrival_time=arrival_time,
|
| 855 |
+
lora_request=lora_request,
|
| 856 |
+
pooling_params=pooling_params,
|
| 857 |
+
prompt_adapter_request=prompt_adapter_request,
|
| 858 |
+
encoder_seq=encoder_seq,
|
| 859 |
+
priority=priority)
|
| 860 |
+
return seq_group
|
| 861 |
+
|
| 862 |
+
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
| 863 |
+
"""Aborts a request(s) with the given ID.
|
| 864 |
+
|
| 865 |
+
Args:
|
| 866 |
+
request_id: The ID(s) of the request to abort.
|
| 867 |
+
|
| 868 |
+
Details:
|
| 869 |
+
- Refer to the
|
| 870 |
+
:meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
|
| 871 |
+
from class :class:`~vllm.core.scheduler.Scheduler`.
|
| 872 |
+
|
| 873 |
+
Example:
|
| 874 |
+
>>> # initialize engine and add a request with request_id
|
| 875 |
+
>>> request_id = str(0)
|
| 876 |
+
>>> # abort the request
|
| 877 |
+
>>> engine.abort_request(request_id)
|
| 878 |
+
"""
|
| 879 |
+
for scheduler in self.scheduler:
|
| 880 |
+
scheduler.abort_seq_group(request_id)
|
| 881 |
+
|
| 882 |
+
def get_model_config(self) -> ModelConfig:
|
| 883 |
+
"""Gets the model configuration."""
|
| 884 |
+
return self.model_config
|
| 885 |
+
|
| 886 |
+
def get_parallel_config(self) -> ParallelConfig:
|
| 887 |
+
"""Gets the parallel configuration."""
|
| 888 |
+
return self.parallel_config
|
| 889 |
+
|
| 890 |
+
def get_decoding_config(self) -> DecodingConfig:
|
| 891 |
+
"""Gets the decoding configuration."""
|
| 892 |
+
return self.decoding_config
|
| 893 |
+
|
| 894 |
+
def get_scheduler_config(self) -> SchedulerConfig:
|
| 895 |
+
"""Gets the scheduler configuration."""
|
| 896 |
+
return self.scheduler_config
|
| 897 |
+
|
| 898 |
+
def get_lora_config(self) -> LoRAConfig:
|
| 899 |
+
"""Gets the LoRA configuration."""
|
| 900 |
+
return self.lora_config
|
| 901 |
+
|
| 902 |
+
def get_num_unfinished_requests(self) -> int:
|
| 903 |
+
"""Gets the number of unfinished requests."""
|
| 904 |
+
return sum(scheduler.get_num_unfinished_seq_groups()
|
| 905 |
+
for scheduler in self.scheduler)
|
| 906 |
+
|
| 907 |
+
def has_unfinished_requests(self) -> bool:
|
| 908 |
+
"""Returns True if there are unfinished requests."""
|
| 909 |
+
return any(scheduler.has_unfinished_seqs()
|
| 910 |
+
for scheduler in self.scheduler)
|
| 911 |
+
|
| 912 |
+
def has_unfinished_requests_for_virtual_engine(
|
| 913 |
+
self, virtual_engine: int) -> bool:
|
| 914 |
+
"""
|
| 915 |
+
Returns True if there are unfinished requests for the virtual engine.
|
| 916 |
+
"""
|
| 917 |
+
return self.scheduler[virtual_engine].has_unfinished_seqs()
|
| 918 |
+
|
| 919 |
+
def reset_prefix_cache(self) -> bool:
|
| 920 |
+
"""Reset prefix cache for all devices."""
|
| 921 |
+
|
| 922 |
+
success = True
|
| 923 |
+
for scheduler in self.scheduler:
|
| 924 |
+
success = success and scheduler.reset_prefix_cache()
|
| 925 |
+
return success
|
| 926 |
+
|
| 927 |
+
@staticmethod
|
| 928 |
+
def _process_sequence_group_outputs(
|
| 929 |
+
seq_group: SequenceGroup,
|
| 930 |
+
outputs: List[PoolingSequenceGroupOutput],
|
| 931 |
+
) -> None:
|
| 932 |
+
seq_group.pooled_data = outputs[0].data
|
| 933 |
+
|
| 934 |
+
for seq in seq_group.get_seqs():
|
| 935 |
+
seq.status = SequenceStatus.FINISHED_STOPPED
|
| 936 |
+
|
| 937 |
+
return
|
| 938 |
+
|
| 939 |
+
def _update_num_computed_tokens_for_multi_step_prefill(
|
| 940 |
+
self, seq_group: SequenceGroup,
|
| 941 |
+
seq_group_meta: SequenceGroupMetadata,
|
| 942 |
+
is_first_step_output: Optional[bool]):
|
| 943 |
+
"""
|
| 944 |
+
This function updates num_computed_tokens for prompt sequences
|
| 945 |
+
when Multi-Step is enabled.
|
| 946 |
+
|
| 947 |
+
seq_group: SequenceGroup to update the num_computed_tokens for.
|
| 948 |
+
seq_group_meta: Metadata of the given SequenceGroup.
|
| 949 |
+
is_first_step_output: Optional[bool] -
|
| 950 |
+
When available, is_first_step_output indicates if the appended
|
| 951 |
+
output token is the output of the first-step in multi-step.
|
| 952 |
+
A value of None indicates that outputs from all steps in
|
| 953 |
+
in multi-step are submitted in a single burst.
|
| 954 |
+
"""
|
| 955 |
+
|
| 956 |
+
assert self.scheduler_config.is_multi_step
|
| 957 |
+
|
| 958 |
+
if not seq_group_meta.is_prompt:
|
| 959 |
+
# num_computed_token updates for multi-step decodes happen after
|
| 960 |
+
# the tokens are appended to the sequence.
|
| 961 |
+
return
|
| 962 |
+
|
| 963 |
+
do_update: bool = False
|
| 964 |
+
if self.scheduler_config.chunked_prefill_enabled:
|
| 965 |
+
# In multi-step + chunked-prefill case, the prompt sequences
|
| 966 |
+
# that are scheduled are fully processed in the first step.
|
| 967 |
+
do_update = is_first_step_output is None or is_first_step_output
|
| 968 |
+
else:
|
| 969 |
+
# Normal multi-step decoding case. In this case prompt-sequences
|
| 970 |
+
# are actually single-stepped. Always update in this case.
|
| 971 |
+
assert seq_group.state.num_steps == 1
|
| 972 |
+
do_update = True
|
| 973 |
+
|
| 974 |
+
if do_update:
|
| 975 |
+
seq_group.update_num_computed_tokens(
|
| 976 |
+
seq_group_meta.token_chunk_size)
|
| 977 |
+
|
| 978 |
+
def _process_model_outputs(self,
|
| 979 |
+
ctx: SchedulerContext,
|
| 980 |
+
request_id: Optional[str] = None) -> None:
|
| 981 |
+
"""Apply the model output to the sequences in the scheduled seq groups
|
| 982 |
+
and return responses.
|
| 983 |
+
|
| 984 |
+
ctx: The virtual engine context to work on
|
| 985 |
+
request_id: If provided, then only this request is going to be processed
|
| 986 |
+
"""
|
| 987 |
+
|
| 988 |
+
now = time.time()
|
| 989 |
+
|
| 990 |
+
if len(ctx.output_queue) == 0:
|
| 991 |
+
return None
|
| 992 |
+
|
| 993 |
+
# Get pending async postprocessor
|
| 994 |
+
if request_id:
|
| 995 |
+
# When we process only one request, no pop is required
|
| 996 |
+
# (since later we will process all of the rest)
|
| 997 |
+
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
|
| 998 |
+
is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
|
| 999 |
+
else:
|
| 1000 |
+
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
|
| 1001 |
+
is_last_step, is_first_step_output,
|
| 1002 |
+
skip) = ctx.output_queue.popleft()
|
| 1003 |
+
|
| 1004 |
+
# Sanity check
|
| 1005 |
+
assert len(seq_group_metadata_list) == len(
|
| 1006 |
+
scheduler_outputs.scheduled_seq_groups)
|
| 1007 |
+
|
| 1008 |
+
has_multiple_outputs: bool = len(outputs) > 1
|
| 1009 |
+
outputs_by_sequence_group: List[List[SequenceGroupOutput]]
|
| 1010 |
+
if has_multiple_outputs:
|
| 1011 |
+
assert self.scheduler_config.is_multi_step or \
|
| 1012 |
+
self.speculative_config
|
| 1013 |
+
# Organize outputs by [step][sequence group] instead of
|
| 1014 |
+
# [sequence group][step].
|
| 1015 |
+
if self.scheduler_config.is_multi_step:
|
| 1016 |
+
outputs_by_sequence_group = create_output_by_sequence_group(
|
| 1017 |
+
outputs, len(seq_group_metadata_list))
|
| 1018 |
+
elif self.speculative_config:
|
| 1019 |
+
# Decodes are multi-steps while prefills are not, outputting at
|
| 1020 |
+
# most 1 token. Separate them so that we can trigger chunk
|
| 1021 |
+
# processing without having to pad or copy over prompts K times
|
| 1022 |
+
# to match decodes structure (costly with prompt_logprobs).
|
| 1023 |
+
num_prefills = sum(sg.is_prompt
|
| 1024 |
+
for sg in seq_group_metadata_list)
|
| 1025 |
+
prefills, decodes = outputs[:num_prefills], outputs[
|
| 1026 |
+
num_prefills:]
|
| 1027 |
+
outputs_by_sequence_group = create_output_by_sequence_group(
|
| 1028 |
+
decodes,
|
| 1029 |
+
num_seq_groups=len(seq_group_metadata_list) - num_prefills)
|
| 1030 |
+
outputs_by_sequence_group = [p.outputs for p in prefills
|
| 1031 |
+
] + outputs_by_sequence_group
|
| 1032 |
+
# We have outputs for multiple steps submitted in a single burst,
|
| 1033 |
+
# so invalidate is_first_step_output.
|
| 1034 |
+
is_first_step_output = None
|
| 1035 |
+
else:
|
| 1036 |
+
outputs_by_sequence_group = outputs
|
| 1037 |
+
|
| 1038 |
+
# Determine the requests we need to operate on
|
| 1039 |
+
if request_id:
|
| 1040 |
+
indices = []
|
| 1041 |
+
for i, seq_group_meta in enumerate(seq_group_metadata_list):
|
| 1042 |
+
if seq_group_meta.request_id == request_id:
|
| 1043 |
+
assert i not in skip # Cannot be called twice
|
| 1044 |
+
indices.append(i)
|
| 1045 |
+
break
|
| 1046 |
+
|
| 1047 |
+
# If the request_id was not found, then it means that
|
| 1048 |
+
# this is a new request that has no pending async
|
| 1049 |
+
# postprocessor
|
| 1050 |
+
if not indices:
|
| 1051 |
+
return
|
| 1052 |
+
else:
|
| 1053 |
+
indices = range(len(seq_group_metadata_list)) # type: ignore
|
| 1054 |
+
|
| 1055 |
+
finished_before: List[int] = []
|
| 1056 |
+
finished_now: List[int] = []
|
| 1057 |
+
for i in indices:
|
| 1058 |
+
if i in skip:
|
| 1059 |
+
continue
|
| 1060 |
+
|
| 1061 |
+
seq_group_meta = seq_group_metadata_list[i]
|
| 1062 |
+
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
|
| 1063 |
+
|
| 1064 |
+
seq_group: SequenceGroup = scheduled_seq_group.seq_group
|
| 1065 |
+
|
| 1066 |
+
if seq_group.is_finished():
|
| 1067 |
+
finished_before.append(i)
|
| 1068 |
+
continue
|
| 1069 |
+
|
| 1070 |
+
output: List[SequenceGroupOutput]
|
| 1071 |
+
if has_multiple_outputs:
|
| 1072 |
+
output = outputs_by_sequence_group[i]
|
| 1073 |
+
else:
|
| 1074 |
+
output = [outputs_by_sequence_group[0][i]]
|
| 1075 |
+
|
| 1076 |
+
if not is_async:
|
| 1077 |
+
if self.scheduler_config.is_multi_step:
|
| 1078 |
+
# Updates happen only if the sequence is prefill
|
| 1079 |
+
self._update_num_computed_tokens_for_multi_step_prefill(
|
| 1080 |
+
seq_group, seq_group_meta, is_first_step_output)
|
| 1081 |
+
else:
|
| 1082 |
+
seq_group.update_num_computed_tokens(
|
| 1083 |
+
seq_group_meta.token_chunk_size or 0)
|
| 1084 |
+
|
| 1085 |
+
if outputs:
|
| 1086 |
+
for o in outputs:
|
| 1087 |
+
if (isinstance(o, SamplerOutput)
|
| 1088 |
+
and seq_group.metrics is not None):
|
| 1089 |
+
if seq_group.metrics.model_forward_time is not None:
|
| 1090 |
+
seq_group.metrics.model_forward_time += (
|
| 1091 |
+
o.model_forward_time or 0)
|
| 1092 |
+
else:
|
| 1093 |
+
seq_group.metrics.model_forward_time = (
|
| 1094 |
+
o.model_forward_time)
|
| 1095 |
+
if seq_group.metrics.model_execute_time is not None:
|
| 1096 |
+
seq_group.metrics.model_execute_time += (
|
| 1097 |
+
o.model_execute_time or 0)
|
| 1098 |
+
else:
|
| 1099 |
+
seq_group.metrics.model_execute_time = (
|
| 1100 |
+
o.model_execute_time)
|
| 1101 |
+
|
| 1102 |
+
if self.model_config.runner_type == "pooling":
|
| 1103 |
+
self._process_sequence_group_outputs(seq_group, output)
|
| 1104 |
+
else:
|
| 1105 |
+
self.output_processor.process_prompt_logprob(seq_group, output)
|
| 1106 |
+
if seq_group_meta.do_sample:
|
| 1107 |
+
self.output_processor.process_outputs(
|
| 1108 |
+
seq_group, output, is_async)
|
| 1109 |
+
|
| 1110 |
+
if seq_group.is_finished():
|
| 1111 |
+
finished_now.append(i)
|
| 1112 |
+
|
| 1113 |
+
# Generate outputs for the requests that finished this iteration
|
| 1114 |
+
for i in finished_now:
|
| 1115 |
+
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
|
| 1116 |
+
|
| 1117 |
+
seq_group = scheduled_seq_group.seq_group
|
| 1118 |
+
seq_group.maybe_set_first_token_time(now)
|
| 1119 |
+
if not seq_group.is_prefill():
|
| 1120 |
+
seq_group.set_last_token_time(now)
|
| 1121 |
+
request_output = RequestOutputFactory.create(
|
| 1122 |
+
seq_group,
|
| 1123 |
+
self.seq_id_to_seq_group,
|
| 1124 |
+
use_cache=self.use_cached_outputs)
|
| 1125 |
+
if request_output:
|
| 1126 |
+
ctx.request_outputs.append(request_output)
|
| 1127 |
+
|
| 1128 |
+
# When we process a single request, we skip it for the next time,
|
| 1129 |
+
# and invoke the request output callback (if there was final output)
|
| 1130 |
+
if request_id:
|
| 1131 |
+
assert len(indices) == 1
|
| 1132 |
+
skip.append(indices[0])
|
| 1133 |
+
|
| 1134 |
+
if (finished_now
|
| 1135 |
+
and self.process_request_outputs_callback is not None):
|
| 1136 |
+
self.process_request_outputs_callback(ctx.request_outputs)
|
| 1137 |
+
ctx.request_outputs.clear()
|
| 1138 |
+
return
|
| 1139 |
+
|
| 1140 |
+
# Free currently finished requests
|
| 1141 |
+
if finished_now:
|
| 1142 |
+
for scheduler in self.scheduler:
|
| 1143 |
+
scheduler.free_finished_seq_groups()
|
| 1144 |
+
|
| 1145 |
+
# For multi-step without streaming, don't create outputs each iteration
|
| 1146 |
+
if not is_last_step and not ctx.multi_step_stream_outputs:
|
| 1147 |
+
# Immediately process request outputs here (if callback is given)
|
| 1148 |
+
if (finished_now
|
| 1149 |
+
and self.process_request_outputs_callback is not None):
|
| 1150 |
+
self.process_request_outputs_callback(ctx.request_outputs)
|
| 1151 |
+
ctx.request_outputs.clear()
|
| 1152 |
+
return
|
| 1153 |
+
|
| 1154 |
+
# Create the outputs
|
| 1155 |
+
for i in indices:
|
| 1156 |
+
if i in skip or i in finished_before or i in finished_now:
|
| 1157 |
+
continue # Avoids double processing
|
| 1158 |
+
|
| 1159 |
+
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
|
| 1160 |
+
|
| 1161 |
+
seq_group = scheduled_seq_group.seq_group
|
| 1162 |
+
seq_group.maybe_set_first_token_time(now)
|
| 1163 |
+
if not seq_group.is_prefill():
|
| 1164 |
+
seq_group.set_last_token_time(now)
|
| 1165 |
+
request_output = RequestOutputFactory.create(
|
| 1166 |
+
seq_group,
|
| 1167 |
+
self.seq_id_to_seq_group,
|
| 1168 |
+
use_cache=self.use_cached_outputs)
|
| 1169 |
+
if request_output:
|
| 1170 |
+
ctx.request_outputs.append(request_output)
|
| 1171 |
+
|
| 1172 |
+
# For multi-step with streaming, create outputs each iteration
|
| 1173 |
+
if not is_last_step and ctx.multi_step_stream_outputs:
|
| 1174 |
+
# Immediately process request outputs here (if callback is given)
|
| 1175 |
+
if self.process_request_outputs_callback is not None:
|
| 1176 |
+
self.process_request_outputs_callback(ctx.request_outputs)
|
| 1177 |
+
ctx.request_outputs.clear()
|
| 1178 |
+
return
|
| 1179 |
+
|
| 1180 |
+
for seq_group in scheduler_outputs.ignored_seq_groups:
|
| 1181 |
+
params = seq_group.sampling_params
|
| 1182 |
+
if params is not None and params.output_kind == (
|
| 1183 |
+
RequestOutputKind.DELTA) and not seq_group.is_finished():
|
| 1184 |
+
continue
|
| 1185 |
+
|
| 1186 |
+
request_output = RequestOutputFactory.create(
|
| 1187 |
+
seq_group,
|
| 1188 |
+
self.seq_id_to_seq_group,
|
| 1189 |
+
use_cache=self.use_cached_outputs,
|
| 1190 |
+
)
|
| 1191 |
+
if request_output:
|
| 1192 |
+
ctx.request_outputs.append(request_output)
|
| 1193 |
+
|
| 1194 |
+
# Immediately process request outputs here (if callback is given)
|
| 1195 |
+
if (ctx.request_outputs
|
| 1196 |
+
and self.process_request_outputs_callback is not None):
|
| 1197 |
+
self.process_request_outputs_callback(ctx.request_outputs)
|
| 1198 |
+
ctx.request_outputs.clear()
|
| 1199 |
+
|
| 1200 |
+
# For async case, we need to record the stats here.
|
| 1201 |
+
# For non-async case, the stats are done in the
|
| 1202 |
+
# LLMEngine/AsyncLLMEngine directly
|
| 1203 |
+
if is_async:
|
| 1204 |
+
# Log stats.
|
| 1205 |
+
self.do_log_stats(scheduler_outputs, outputs, finished_before,
|
| 1206 |
+
skip)
|
| 1207 |
+
|
| 1208 |
+
# Tracing
|
| 1209 |
+
self.do_tracing(scheduler_outputs, finished_before)
|
| 1210 |
+
|
| 1211 |
+
return None
|
| 1212 |
+
|
| 1213 |
+
def _advance_to_next_step(
|
| 1214 |
+
self, output: List[SamplerOutput],
|
| 1215 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
| 1216 |
+
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
|
| 1217 |
+
"""Given model output from a single run, append the tokens to the
|
| 1218 |
+
sequences. This is normally done inside output processor, but it is
|
| 1219 |
+
required if the worker is to perform async forward pass to next step.
|
| 1220 |
+
"""
|
| 1221 |
+
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
|
| 1222 |
+
zip(seq_group_metadata_list, output, scheduled_seq_groups):
|
| 1223 |
+
seq_group = scheduled_seq_group.seq_group
|
| 1224 |
+
|
| 1225 |
+
if seq_group.is_finished():
|
| 1226 |
+
continue
|
| 1227 |
+
|
| 1228 |
+
if self.scheduler_config.is_multi_step:
|
| 1229 |
+
# Updates happen only if the sequence is prefill
|
| 1230 |
+
self._update_num_computed_tokens_for_multi_step_prefill(
|
| 1231 |
+
seq_group, seq_group_metadata,
|
| 1232 |
+
seq_group.state.num_steps == 1)
|
| 1233 |
+
else:
|
| 1234 |
+
token_chunk_size = (seq_group_metadata.token_chunk_size
|
| 1235 |
+
if seq_group_metadata.token_chunk_size
|
| 1236 |
+
is not None else 0)
|
| 1237 |
+
seq_group.update_num_computed_tokens(token_chunk_size)
|
| 1238 |
+
|
| 1239 |
+
if seq_group_metadata.do_sample:
|
| 1240 |
+
assert len(sequence_group_outputs.samples) == 1, (
|
| 1241 |
+
"Async output processor expects a single sample"
|
| 1242 |
+
" (i.e sampling_params.n == 1)")
|
| 1243 |
+
sample = sequence_group_outputs.samples[0]
|
| 1244 |
+
|
| 1245 |
+
assert len(seq_group.seqs) == 1
|
| 1246 |
+
seq = seq_group.seqs[0]
|
| 1247 |
+
|
| 1248 |
+
if self.scheduler_config.is_multi_step:
|
| 1249 |
+
is_prefill_append = seq.data.get_num_uncomputed_tokens(
|
| 1250 |
+
) == 0
|
| 1251 |
+
seq.append_token_id(sample.output_token, sample.logprobs)
|
| 1252 |
+
if not is_prefill_append:
|
| 1253 |
+
seq_group.update_num_computed_tokens(1)
|
| 1254 |
+
else:
|
| 1255 |
+
seq.append_token_id(sample.output_token, sample.logprobs)
|
| 1256 |
+
|
| 1257 |
+
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
|
| 1258 |
+
"""Performs one decoding iteration and returns newly generated results.
|
| 1259 |
+
|
| 1260 |
+
.. figure:: https://i.imgur.com/sv2HssD.png
|
| 1261 |
+
:alt: Overview of the step function
|
| 1262 |
+
:align: center
|
| 1263 |
+
|
| 1264 |
+
Overview of the step function.
|
| 1265 |
+
|
| 1266 |
+
Details:
|
| 1267 |
+
- Step 1: Schedules the sequences to be executed in the next
|
| 1268 |
+
iteration and the token blocks to be swapped in/out/copy.
|
| 1269 |
+
|
| 1270 |
+
- Depending on the scheduling policy,
|
| 1271 |
+
sequences may be `preempted/reordered`.
|
| 1272 |
+
- A Sequence Group (SG) refer to a group of sequences
|
| 1273 |
+
that are generated from the same prompt.
|
| 1274 |
+
|
| 1275 |
+
- Step 2: Calls the distributed executor to execute the model.
|
| 1276 |
+
- Step 3: Processes the model output. This mainly includes:
|
| 1277 |
+
|
| 1278 |
+
- Decodes the relevant outputs.
|
| 1279 |
+
- Updates the scheduled sequence groups with model outputs
|
| 1280 |
+
based on its `sampling parameters` (`use_beam_search` or not).
|
| 1281 |
+
- Frees the finished sequence groups.
|
| 1282 |
+
|
| 1283 |
+
- Finally, it creates and returns the newly generated results.
|
| 1284 |
+
|
| 1285 |
+
Example:
|
| 1286 |
+
>>> # Please see the example/ folder for more detailed examples.
|
| 1287 |
+
>>>
|
| 1288 |
+
>>> # initialize engine and request arguments
|
| 1289 |
+
>>> engine = LLMEngine.from_engine_args(engine_args)
|
| 1290 |
+
>>> example_inputs = [(0, "What is LLM?",
|
| 1291 |
+
>>> SamplingParams(temperature=0.0))]
|
| 1292 |
+
>>>
|
| 1293 |
+
>>> # Start the engine with an event loop
|
| 1294 |
+
>>> while True:
|
| 1295 |
+
>>> if example_inputs:
|
| 1296 |
+
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
|
| 1297 |
+
>>> engine.add_request(str(req_id),prompt,sampling_params)
|
| 1298 |
+
>>>
|
| 1299 |
+
>>> # continue the request processing
|
| 1300 |
+
>>> request_outputs = engine.step()
|
| 1301 |
+
>>> for request_output in request_outputs:
|
| 1302 |
+
>>> if request_output.finished:
|
| 1303 |
+
>>> # return or show the request output
|
| 1304 |
+
>>>
|
| 1305 |
+
>>> if not (engine.has_unfinished_requests() or example_inputs):
|
| 1306 |
+
>>> break
|
| 1307 |
+
"""
|
| 1308 |
+
if self.parallel_config.pipeline_parallel_size > 1:
|
| 1309 |
+
raise NotImplementedError(
|
| 1310 |
+
"Pipeline parallelism is only supported through AsyncLLMEngine "
|
| 1311 |
+
"as performance will be severely degraded otherwise.")
|
| 1312 |
+
|
| 1313 |
+
# For llm_engine, there is no pipeline parallel support, so the engine
|
| 1314 |
+
# used is always 0.
|
| 1315 |
+
virtual_engine = 0
|
| 1316 |
+
|
| 1317 |
+
# These are cached outputs from previous iterations. None if on first
|
| 1318 |
+
# iteration
|
| 1319 |
+
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
|
| 1320 |
+
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
|
| 1321 |
+
scheduler_outputs = cached_outputs.scheduler_outputs
|
| 1322 |
+
allow_async_output_proc = cached_outputs.allow_async_output_proc
|
| 1323 |
+
|
| 1324 |
+
ctx = self.scheduler_contexts[virtual_engine]
|
| 1325 |
+
|
| 1326 |
+
# Clear outputs for each new scheduler iteration
|
| 1327 |
+
ctx.request_outputs.clear()
|
| 1328 |
+
|
| 1329 |
+
# Skip the scheduler if there are any remaining steps in the seq groups.
|
| 1330 |
+
# This ensures that the scheduler is only called again when the current
|
| 1331 |
+
# batch has completed.
|
| 1332 |
+
if not self._has_remaining_steps(seq_group_metadata_list):
|
| 1333 |
+
# Schedule iteration
|
| 1334 |
+
(seq_group_metadata_list, scheduler_outputs,
|
| 1335 |
+
allow_async_output_proc
|
| 1336 |
+
) = self.scheduler[virtual_engine].schedule()
|
| 1337 |
+
|
| 1338 |
+
ctx.seq_group_metadata_list = seq_group_metadata_list
|
| 1339 |
+
ctx.scheduler_outputs = scheduler_outputs
|
| 1340 |
+
|
| 1341 |
+
finished_requests_ids = self.scheduler[
|
| 1342 |
+
virtual_engine].get_and_reset_finished_requests_ids()
|
| 1343 |
+
|
| 1344 |
+
# Maybe switch from async mode to sync mode
|
| 1345 |
+
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
| 1346 |
+
self._process_model_outputs(ctx=ctx)
|
| 1347 |
+
|
| 1348 |
+
if (self.scheduler_config.is_multi_step
|
| 1349 |
+
and scheduler_outputs.num_lookahead_slots > 0):
|
| 1350 |
+
# cache the scheduler outputs for the next iteration if we have
|
| 1351 |
+
# lookahead slots
|
| 1352 |
+
self._cache_scheduler_outputs_for_multi_step(
|
| 1353 |
+
virtual_engine, seq_group_metadata_list, scheduler_outputs,
|
| 1354 |
+
allow_async_output_proc)
|
| 1355 |
+
else:
|
| 1356 |
+
finished_requests_ids = list()
|
| 1357 |
+
|
| 1358 |
+
assert seq_group_metadata_list is not None
|
| 1359 |
+
assert scheduler_outputs is not None
|
| 1360 |
+
|
| 1361 |
+
if not scheduler_outputs.is_empty():
|
| 1362 |
+
|
| 1363 |
+
# Check if we have a cached last_output from the previous iteration.
|
| 1364 |
+
# For supporting PP this is probably the best way to pass the
|
| 1365 |
+
# sampled_token_ids, as a separate broadcast over all the PP stages
|
| 1366 |
+
# will cause one virtual engine's microbatch to block the pipeline.
|
| 1367 |
+
last_sampled_token_ids = \
|
| 1368 |
+
self._get_last_sampled_token_ids(virtual_engine)
|
| 1369 |
+
|
| 1370 |
+
execute_model_req = ExecuteModelRequest(
|
| 1371 |
+
seq_group_metadata_list=seq_group_metadata_list,
|
| 1372 |
+
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
| 1373 |
+
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
| 1374 |
+
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
| 1375 |
+
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
|
| 1376 |
+
running_queue_size=scheduler_outputs.running_queue_size,
|
| 1377 |
+
finished_requests_ids=finished_requests_ids,
|
| 1378 |
+
# We use ExecuteModelRequest to pass the last sampled_token_ids
|
| 1379 |
+
# to each of the non-last PP stages for in-place prepare_input.
|
| 1380 |
+
last_sampled_token_ids=last_sampled_token_ids)
|
| 1381 |
+
|
| 1382 |
+
if allow_async_output_proc:
|
| 1383 |
+
execute_model_req.async_callback = self.async_callbacks[
|
| 1384 |
+
virtual_engine]
|
| 1385 |
+
|
| 1386 |
+
outputs = self.model_executor.execute_model(
|
| 1387 |
+
execute_model_req=execute_model_req)
|
| 1388 |
+
|
| 1389 |
+
# We need to do this here so that last step's sampled_token_ids can
|
| 1390 |
+
# be passed to the next iteration for PP.
|
| 1391 |
+
if self.scheduler_config.is_multi_step:
|
| 1392 |
+
self._update_cached_scheduler_output(virtual_engine, outputs)
|
| 1393 |
+
else:
|
| 1394 |
+
# Nothing scheduled => If there is pending async postprocessor,
|
| 1395 |
+
# then finish it here.
|
| 1396 |
+
if len(ctx.output_queue) > 0:
|
| 1397 |
+
self._process_model_outputs(ctx=ctx)
|
| 1398 |
+
# No outputs in this case
|
| 1399 |
+
outputs = []
|
| 1400 |
+
|
| 1401 |
+
# Finish the current step for all the sequence groups.
|
| 1402 |
+
if self.scheduler_config.is_multi_step:
|
| 1403 |
+
for seq_group in seq_group_metadata_list:
|
| 1404 |
+
seq_group.finish_step()
|
| 1405 |
+
|
| 1406 |
+
if not self._has_remaining_steps(seq_group_metadata_list):
|
| 1407 |
+
# clear the cache if we have finished all the steps.
|
| 1408 |
+
if self.scheduler_config.is_multi_step:
|
| 1409 |
+
self.cached_scheduler_outputs[0] = SchedulerOutputState()
|
| 1410 |
+
|
| 1411 |
+
# is_first_step_output is True only when the num_steps of all
|
| 1412 |
+
# the sequences are 1. When the num_steps > 1,
|
| 1413 |
+
# multi_step_model_runner does the first-step output append.
|
| 1414 |
+
is_first_step_output: bool = False if not seq_group_metadata_list \
|
| 1415 |
+
else seq_group_metadata_list[0].state.num_steps == 1
|
| 1416 |
+
|
| 1417 |
+
# Add results to the output_queue
|
| 1418 |
+
ctx.append_output(outputs=outputs,
|
| 1419 |
+
seq_group_metadata_list=seq_group_metadata_list,
|
| 1420 |
+
scheduler_outputs=scheduler_outputs,
|
| 1421 |
+
is_async=allow_async_output_proc,
|
| 1422 |
+
is_last_step=True,
|
| 1423 |
+
is_first_step_output=is_first_step_output)
|
| 1424 |
+
|
| 1425 |
+
if outputs and allow_async_output_proc:
|
| 1426 |
+
assert len(outputs) == 1, (
|
| 1427 |
+
"Async postprocessor expects only a single output set")
|
| 1428 |
+
|
| 1429 |
+
self._advance_to_next_step(
|
| 1430 |
+
outputs[0], seq_group_metadata_list,
|
| 1431 |
+
scheduler_outputs.scheduled_seq_groups)
|
| 1432 |
+
|
| 1433 |
+
# Check if need to run the usual non-async path
|
| 1434 |
+
if not allow_async_output_proc:
|
| 1435 |
+
self._process_model_outputs(ctx=ctx)
|
| 1436 |
+
|
| 1437 |
+
# Log stats.
|
| 1438 |
+
self.do_log_stats(scheduler_outputs, outputs)
|
| 1439 |
+
|
| 1440 |
+
# Tracing
|
| 1441 |
+
self.do_tracing(scheduler_outputs)
|
| 1442 |
+
else:
|
| 1443 |
+
# Multi-step case
|
| 1444 |
+
return ctx.request_outputs
|
| 1445 |
+
|
| 1446 |
+
if not self.has_unfinished_requests():
|
| 1447 |
+
# Drain async postprocessor (if exists)
|
| 1448 |
+
if len(ctx.output_queue) > 0:
|
| 1449 |
+
self._process_model_outputs(ctx=ctx)
|
| 1450 |
+
assert len(ctx.output_queue) == 0
|
| 1451 |
+
|
| 1452 |
+
# Stop the execute model loop in parallel workers until there are
|
| 1453 |
+
# more requests to process. This avoids waiting indefinitely in
|
| 1454 |
+
# torch.distributed ops which may otherwise timeout, and unblocks
|
| 1455 |
+
# the RPC thread in the workers so that they can process any other
|
| 1456 |
+
# queued control plane messages, such as add/remove lora adapters.
|
| 1457 |
+
logger.debug("Stopping remote worker execution loop.")
|
| 1458 |
+
self.model_executor.stop_remote_worker_execution_loop()
|
| 1459 |
+
|
| 1460 |
+
return ctx.request_outputs
|
| 1461 |
+
|
| 1462 |
+
def _has_remaining_steps(
|
| 1463 |
+
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
|
| 1464 |
+
) -> bool:
|
| 1465 |
+
if (not self.scheduler_config.is_multi_step
|
| 1466 |
+
or not seq_group_metadata_list):
|
| 1467 |
+
return False
|
| 1468 |
+
|
| 1469 |
+
# TODO(will) this is a sanity check for nowto make sure that all the
|
| 1470 |
+
# seqs are on the same steps. Eventually we will want to do some sort of
|
| 1471 |
+
# dynamic scheduling when doing multi-step decoding.
|
| 1472 |
+
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
|
| 1473 |
+
if any([
|
| 1474 |
+
seq_group.state.remaining_steps != ref_remaining_steps
|
| 1475 |
+
for seq_group in seq_group_metadata_list[1:]
|
| 1476 |
+
]):
|
| 1477 |
+
raise AssertionError("All running sequence groups should "
|
| 1478 |
+
"have the same remaining steps.")
|
| 1479 |
+
|
| 1480 |
+
return ref_remaining_steps > 0
|
| 1481 |
+
|
| 1482 |
+
def _cache_scheduler_outputs_for_multi_step(
|
| 1483 |
+
self, virtual_engine: int,
|
| 1484 |
+
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
| 1485 |
+
scheduler_outputs: SchedulerOutputs,
|
| 1486 |
+
allow_async_output_proc: bool) -> None:
|
| 1487 |
+
co = self.cached_scheduler_outputs[virtual_engine]
|
| 1488 |
+
|
| 1489 |
+
co.seq_group_metadata_list = seq_group_metadata_list
|
| 1490 |
+
co.scheduler_outputs = scheduler_outputs
|
| 1491 |
+
co.allow_async_output_proc = allow_async_output_proc
|
| 1492 |
+
co.last_output = None
|
| 1493 |
+
|
| 1494 |
+
def _update_cached_scheduler_output(
|
| 1495 |
+
self, virtual_engine: int,
|
| 1496 |
+
output: List[Optional[SamplerOutput]]) -> None:
|
| 1497 |
+
if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
|
| 1498 |
+
and output[0] is not None):
|
| 1499 |
+
last_output = output[-1]
|
| 1500 |
+
assert last_output is not None
|
| 1501 |
+
assert last_output.sampled_token_ids_cpu is not None
|
| 1502 |
+
assert last_output.sampled_token_ids is None
|
| 1503 |
+
assert last_output.sampled_token_probs is None
|
| 1504 |
+
self.cached_scheduler_outputs[
|
| 1505 |
+
virtual_engine].last_output = last_output
|
| 1506 |
+
|
| 1507 |
+
def _get_last_sampled_token_ids(
|
| 1508 |
+
self, virtual_engine: int) -> Optional[torch.Tensor]:
|
| 1509 |
+
cached_last_output = self.cached_scheduler_outputs[
|
| 1510 |
+
virtual_engine].last_output
|
| 1511 |
+
if (self.scheduler_config.is_multi_step
|
| 1512 |
+
and self.parallel_config.pipeline_parallel_size > 1
|
| 1513 |
+
and cached_last_output is not None
|
| 1514 |
+
and cached_last_output.sampled_token_ids_cpu is not None):
|
| 1515 |
+
return cached_last_output.sampled_token_ids_cpu
|
| 1516 |
+
return None
|
| 1517 |
+
|
| 1518 |
+
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
|
| 1519 |
+
if not self.log_stats:
|
| 1520 |
+
raise RuntimeError(
|
| 1521 |
+
"Stat logging is disabled. Set `disable_log_stats=False` "
|
| 1522 |
+
"argument to enable.")
|
| 1523 |
+
if logger_name in self.stat_loggers:
|
| 1524 |
+
raise KeyError(f"Logger with name {logger_name} already exists.")
|
| 1525 |
+
self.stat_loggers[logger_name] = logger
|
| 1526 |
+
|
| 1527 |
+
def remove_logger(self, logger_name: str) -> None:
|
| 1528 |
+
if not self.log_stats:
|
| 1529 |
+
raise RuntimeError(
|
| 1530 |
+
"Stat logging is disabled. Set `disable_log_stats=False` "
|
| 1531 |
+
"argument to enable.")
|
| 1532 |
+
if logger_name not in self.stat_loggers:
|
| 1533 |
+
raise KeyError(f"Logger with name {logger_name} does not exist.")
|
| 1534 |
+
del self.stat_loggers[logger_name]
|
| 1535 |
+
|
| 1536 |
+
def do_log_stats(self,
|
| 1537 |
+
scheduler_outputs: Optional[SchedulerOutputs] = None,
|
| 1538 |
+
model_output: Optional[List[SamplerOutput]] = None,
|
| 1539 |
+
finished_before: Optional[List[int]] = None,
|
| 1540 |
+
skip: Optional[List[int]] = None) -> None:
|
| 1541 |
+
"""Forced log when no requests active."""
|
| 1542 |
+
if self.log_stats:
|
| 1543 |
+
stats = self._get_stats(scheduler_outputs, model_output,
|
| 1544 |
+
finished_before, skip)
|
| 1545 |
+
for logger in self.stat_loggers.values():
|
| 1546 |
+
logger.log(stats)
|
| 1547 |
+
|
| 1548 |
+
def _get_stats(self,
|
| 1549 |
+
scheduler_outputs: Optional[SchedulerOutputs],
|
| 1550 |
+
model_output: Optional[List[SamplerOutput]] = None,
|
| 1551 |
+
finished_before: Optional[List[int]] = None,
|
| 1552 |
+
skip: Optional[List[int]] = None) -> Stats:
|
| 1553 |
+
"""Get Stats to be Logged to Prometheus.
|
| 1554 |
+
|
| 1555 |
+
Args:
|
| 1556 |
+
scheduler_outputs: Optional, used to populate metrics related to
|
| 1557 |
+
the scheduled batch,
|
| 1558 |
+
model_output: Optional, used to emit speculative decoding metrics
|
| 1559 |
+
which are created by the workers.
|
| 1560 |
+
finished_before: Optional, indices of sequences that were finished
|
| 1561 |
+
before. These sequences will be ignored.
|
| 1562 |
+
skip: Optional, indices of sequences that were preempted. These
|
| 1563 |
+
sequences will be ignored.
|
| 1564 |
+
"""
|
| 1565 |
+
now = time.time()
|
| 1566 |
+
|
| 1567 |
+
# System State
|
| 1568 |
+
# Scheduler State
|
| 1569 |
+
num_running_sys = sum(
|
| 1570 |
+
len(scheduler.running) for scheduler in self.scheduler)
|
| 1571 |
+
num_swapped_sys = sum(
|
| 1572 |
+
len(scheduler.swapped) for scheduler in self.scheduler)
|
| 1573 |
+
num_waiting_sys = sum(
|
| 1574 |
+
len(scheduler.waiting) for scheduler in self.scheduler)
|
| 1575 |
+
|
| 1576 |
+
# KV Cache Usage in %
|
| 1577 |
+
num_total_gpu = self.cache_config.num_gpu_blocks
|
| 1578 |
+
gpu_cache_usage_sys = 0.
|
| 1579 |
+
if num_total_gpu: # Guard against both None and 0
|
| 1580 |
+
num_free_gpu = sum(
|
| 1581 |
+
scheduler.block_manager.get_num_free_gpu_blocks()
|
| 1582 |
+
for scheduler in self.scheduler)
|
| 1583 |
+
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
|
| 1584 |
+
|
| 1585 |
+
num_total_cpu = self.cache_config.num_cpu_blocks
|
| 1586 |
+
cpu_cache_usage_sys = 0.
|
| 1587 |
+
if num_total_cpu: # Guard against both None and 0
|
| 1588 |
+
num_free_cpu = sum(
|
| 1589 |
+
scheduler.block_manager.get_num_free_cpu_blocks()
|
| 1590 |
+
for scheduler in self.scheduler)
|
| 1591 |
+
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
|
| 1592 |
+
|
| 1593 |
+
# Prefix Cache Hit Rate. Note that we always use
|
| 1594 |
+
# the cache hit rate of the first virtual engine.
|
| 1595 |
+
cpu_prefix_cache_hit_rate = self.scheduler[
|
| 1596 |
+
0].get_prefix_cache_hit_rate(Device.CPU)
|
| 1597 |
+
gpu_prefix_cache_hit_rate = self.scheduler[
|
| 1598 |
+
0].get_prefix_cache_hit_rate(Device.GPU)
|
| 1599 |
+
|
| 1600 |
+
# Iteration stats
|
| 1601 |
+
num_prompt_tokens_iter = 0
|
| 1602 |
+
num_generation_tokens_iter = 0
|
| 1603 |
+
num_tokens_iter = 0
|
| 1604 |
+
time_to_first_tokens_iter: List[float] = []
|
| 1605 |
+
time_per_output_tokens_iter: List[float] = []
|
| 1606 |
+
num_preemption_iter = (0 if scheduler_outputs is None else
|
| 1607 |
+
scheduler_outputs.preempted)
|
| 1608 |
+
|
| 1609 |
+
# Request stats
|
| 1610 |
+
# Latency
|
| 1611 |
+
time_e2e_requests: List[float] = []
|
| 1612 |
+
time_queue_requests: List[float] = []
|
| 1613 |
+
time_inference_requests: List[float] = []
|
| 1614 |
+
time_prefill_requests: List[float] = []
|
| 1615 |
+
time_decode_requests: List[float] = []
|
| 1616 |
+
time_in_queue_requests: List[float] = []
|
| 1617 |
+
model_forward_time_requests: List[float] = []
|
| 1618 |
+
model_execute_time_requests: List[float] = []
|
| 1619 |
+
# Metadata
|
| 1620 |
+
num_prompt_tokens_requests: List[int] = []
|
| 1621 |
+
num_generation_tokens_requests: List[int] = []
|
| 1622 |
+
n_requests: List[int] = []
|
| 1623 |
+
max_num_generation_tokens_requests: List[int] = []
|
| 1624 |
+
max_tokens_requests: List[int] = []
|
| 1625 |
+
finished_reason_requests: List[str] = []
|
| 1626 |
+
|
| 1627 |
+
# Lora requests
|
| 1628 |
+
running_lora_adapters = dict(
|
| 1629 |
+
collectionsCounter([
|
| 1630 |
+
running_request.lora_request.lora_name
|
| 1631 |
+
for scheduler in self.scheduler
|
| 1632 |
+
for running_request in scheduler.running
|
| 1633 |
+
if running_request.lora_request
|
| 1634 |
+
]))
|
| 1635 |
+
waiting_lora_adapters = dict(
|
| 1636 |
+
collectionsCounter([
|
| 1637 |
+
waiting_request.lora_request.lora_name
|
| 1638 |
+
for scheduler in self.scheduler
|
| 1639 |
+
for waiting_request in scheduler.waiting
|
| 1640 |
+
if waiting_request.lora_request
|
| 1641 |
+
]))
|
| 1642 |
+
max_lora_stat = "0"
|
| 1643 |
+
if self.lora_config:
|
| 1644 |
+
max_lora_stat = str(self.lora_config.max_loras)
|
| 1645 |
+
|
| 1646 |
+
# NOTE: This loop assumes prefill seq_groups are before
|
| 1647 |
+
# decode seq_groups in scheduled_seq_groups.
|
| 1648 |
+
if scheduler_outputs is not None:
|
| 1649 |
+
# For async postprocessor, already finished sequences need to be
|
| 1650 |
+
# not counted (to avoid double counting)
|
| 1651 |
+
actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore
|
| 1652 |
+
|
| 1653 |
+
num_generation_tokens_from_prefill_groups = 0
|
| 1654 |
+
# NOTE: if scheduler_outputs.num_prefill_groups > 0 and
|
| 1655 |
+
# the len of scheduler_outputs.scheduled_seq_groups is !=
|
| 1656 |
+
# scheduler_outputs.num_prefill_groups, this means that
|
| 1657 |
+
# chunked prefills have been detected.
|
| 1658 |
+
|
| 1659 |
+
for idx, scheduled_seq_group in enumerate(
|
| 1660 |
+
scheduler_outputs.scheduled_seq_groups):
|
| 1661 |
+
# Skip double logging when using async output proc
|
| 1662 |
+
if finished_before and idx in finished_before:
|
| 1663 |
+
actual_num_batched_tokens -= 1
|
| 1664 |
+
continue
|
| 1665 |
+
|
| 1666 |
+
# Currently, skip == preempted sequences, so we need to skip
|
| 1667 |
+
# their log stats
|
| 1668 |
+
if skip and idx in skip:
|
| 1669 |
+
continue
|
| 1670 |
+
|
| 1671 |
+
group_was_prefill = idx < scheduler_outputs.num_prefill_groups
|
| 1672 |
+
seq_group = scheduled_seq_group.seq_group
|
| 1673 |
+
|
| 1674 |
+
# NOTE: a seq_group that completed all of its prefill tokens
|
| 1675 |
+
# in the last iteration will have seq_group.is_prefill() = False
|
| 1676 |
+
# with group_was_prefill = True
|
| 1677 |
+
if group_was_prefill:
|
| 1678 |
+
# Number of prompt tokens.
|
| 1679 |
+
num_prompt_tokens_iter += (
|
| 1680 |
+
scheduled_seq_group.token_chunk_size)
|
| 1681 |
+
|
| 1682 |
+
# If the seq_group just finished the prefill state
|
| 1683 |
+
# get TTFT.
|
| 1684 |
+
if not seq_group.is_prefill():
|
| 1685 |
+
latency = seq_group.get_last_token_latency()
|
| 1686 |
+
time_to_first_tokens_iter.append(latency)
|
| 1687 |
+
|
| 1688 |
+
# One generation token per finished prefill.
|
| 1689 |
+
num_generation_tokens_from_prefill_groups += (
|
| 1690 |
+
seq_group.num_seqs())
|
| 1691 |
+
else:
|
| 1692 |
+
# TPOTs.
|
| 1693 |
+
latency = seq_group.get_last_token_latency()
|
| 1694 |
+
time_per_output_tokens_iter.append(latency)
|
| 1695 |
+
if seq_group.state.current_step == 0:
|
| 1696 |
+
# For async_output_proc, the do_log_stats()
|
| 1697 |
+
# is called following init_multi_step(), which
|
| 1698 |
+
# sets the current_step to zero.
|
| 1699 |
+
actual_num_batched_tokens +=\
|
| 1700 |
+
seq_group.state.num_steps - 1
|
| 1701 |
+
else:
|
| 1702 |
+
actual_num_batched_tokens +=\
|
| 1703 |
+
seq_group.state.current_step - 1
|
| 1704 |
+
|
| 1705 |
+
# Because of chunked prefill, we can have a single sequence
|
| 1706 |
+
# group that does multiple prompt_runs. To prevent logging
|
| 1707 |
+
# the same metadata more than once per request, we standardize
|
| 1708 |
+
# on logging request level information for finished requests,
|
| 1709 |
+
# which can only happen once.
|
| 1710 |
+
if seq_group.is_finished():
|
| 1711 |
+
# Latency timings
|
| 1712 |
+
time_e2e_requests.append(now -
|
| 1713 |
+
seq_group.metrics.arrival_time)
|
| 1714 |
+
if (seq_group.metrics.first_scheduled_time is not None and
|
| 1715 |
+
seq_group.metrics.first_token_time is not None):
|
| 1716 |
+
time_queue_requests.append(
|
| 1717 |
+
seq_group.metrics.first_scheduled_time -
|
| 1718 |
+
seq_group.metrics.arrival_time)
|
| 1719 |
+
time_prefill_requests.append(
|
| 1720 |
+
seq_group.metrics.first_token_time -
|
| 1721 |
+
seq_group.metrics.first_scheduled_time)
|
| 1722 |
+
time_decode_requests.append(
|
| 1723 |
+
now - seq_group.metrics.first_token_time)
|
| 1724 |
+
time_inference_requests.append(
|
| 1725 |
+
now - seq_group.metrics.first_scheduled_time)
|
| 1726 |
+
if seq_group.metrics.time_in_queue is not None:
|
| 1727 |
+
time_in_queue_requests.append(
|
| 1728 |
+
seq_group.metrics.time_in_queue)
|
| 1729 |
+
if seq_group.metrics.model_forward_time is not None:
|
| 1730 |
+
model_forward_time_requests.append(
|
| 1731 |
+
seq_group.metrics.model_forward_time)
|
| 1732 |
+
if seq_group.metrics.model_execute_time is not None:
|
| 1733 |
+
model_execute_time_requests.append(
|
| 1734 |
+
seq_group.metrics.model_execute_time * 1000)
|
| 1735 |
+
# Metadata
|
| 1736 |
+
num_prompt_tokens_requests.append(
|
| 1737 |
+
len(seq_group.prompt_token_ids))
|
| 1738 |
+
num_generation_tokens_requests.extend([
|
| 1739 |
+
seq.get_output_len()
|
| 1740 |
+
for seq in seq_group.get_finished_seqs()
|
| 1741 |
+
])
|
| 1742 |
+
max_num_generation_tokens_requests.append(
|
| 1743 |
+
max(seq.get_output_len()
|
| 1744 |
+
for seq in seq_group.get_seqs()))
|
| 1745 |
+
if seq_group.sampling_params is not None:
|
| 1746 |
+
n_requests.append(seq_group.sampling_params.n)
|
| 1747 |
+
max_tokens_requests.append(
|
| 1748 |
+
seq_group.sampling_params.max_tokens)
|
| 1749 |
+
finished_reason_requests.extend([
|
| 1750 |
+
SequenceStatus.get_finished_reason(seq.status)
|
| 1751 |
+
for seq in seq_group.get_finished_seqs()
|
| 1752 |
+
])
|
| 1753 |
+
|
| 1754 |
+
# Number of generation tokens.
|
| 1755 |
+
# num_batched_tokens equals the number of prompt_tokens plus the
|
| 1756 |
+
# number of decode_tokens in a single iteration. So,
|
| 1757 |
+
# num_generation_tokens = num_batched_tokens - num_prompt_tokens
|
| 1758 |
+
# + num_generation_tokens_from_prefill_groups (since we generate
|
| 1759 |
+
# one token on prefills on iters where the prefill finishes).
|
| 1760 |
+
num_generation_tokens_iter = (
|
| 1761 |
+
actual_num_batched_tokens - num_prompt_tokens_iter +
|
| 1762 |
+
num_generation_tokens_from_prefill_groups)
|
| 1763 |
+
num_tokens_iter = (num_generation_tokens_iter +
|
| 1764 |
+
num_prompt_tokens_iter)
|
| 1765 |
+
# Spec decode, if enabled, emits specialized metrics from the worker in
|
| 1766 |
+
# sampler output.
|
| 1767 |
+
if model_output and isinstance(model_output[0], SamplerOutput) and (
|
| 1768 |
+
model_output[0].spec_decode_worker_metrics is not None):
|
| 1769 |
+
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
|
| 1770 |
+
else:
|
| 1771 |
+
spec_decode_metrics = None
|
| 1772 |
+
|
| 1773 |
+
return Stats(
|
| 1774 |
+
now=now,
|
| 1775 |
+
# System stats
|
| 1776 |
+
# Scheduler State
|
| 1777 |
+
num_running_sys=num_running_sys,
|
| 1778 |
+
num_swapped_sys=num_swapped_sys,
|
| 1779 |
+
num_waiting_sys=num_waiting_sys,
|
| 1780 |
+
# KV Cache Usage in %
|
| 1781 |
+
gpu_cache_usage_sys=gpu_cache_usage_sys,
|
| 1782 |
+
cpu_cache_usage_sys=cpu_cache_usage_sys,
|
| 1783 |
+
# Prefix Cache Hit Rate
|
| 1784 |
+
cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
|
| 1785 |
+
gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
|
| 1786 |
+
|
| 1787 |
+
# Iteration stats
|
| 1788 |
+
num_prompt_tokens_iter=num_prompt_tokens_iter,
|
| 1789 |
+
num_generation_tokens_iter=num_generation_tokens_iter,
|
| 1790 |
+
num_tokens_iter=num_tokens_iter,
|
| 1791 |
+
time_to_first_tokens_iter=time_to_first_tokens_iter,
|
| 1792 |
+
time_per_output_tokens_iter=time_per_output_tokens_iter,
|
| 1793 |
+
spec_decode_metrics=spec_decode_metrics,
|
| 1794 |
+
num_preemption_iter=num_preemption_iter,
|
| 1795 |
+
|
| 1796 |
+
# Request stats
|
| 1797 |
+
# Latency
|
| 1798 |
+
time_e2e_requests=time_e2e_requests,
|
| 1799 |
+
time_queue_requests=time_queue_requests,
|
| 1800 |
+
time_inference_requests=time_inference_requests,
|
| 1801 |
+
time_prefill_requests=time_prefill_requests,
|
| 1802 |
+
time_decode_requests=time_decode_requests,
|
| 1803 |
+
time_in_queue_requests=time_in_queue_requests,
|
| 1804 |
+
model_forward_time_requests=model_forward_time_requests,
|
| 1805 |
+
model_execute_time_requests=model_execute_time_requests,
|
| 1806 |
+
# Metadata
|
| 1807 |
+
num_prompt_tokens_requests=num_prompt_tokens_requests,
|
| 1808 |
+
num_generation_tokens_requests=num_generation_tokens_requests,
|
| 1809 |
+
max_num_generation_tokens_requests=
|
| 1810 |
+
max_num_generation_tokens_requests,
|
| 1811 |
+
n_requests=n_requests,
|
| 1812 |
+
max_tokens_requests=max_tokens_requests,
|
| 1813 |
+
finished_reason_requests=finished_reason_requests,
|
| 1814 |
+
max_lora=str(max_lora_stat),
|
| 1815 |
+
waiting_lora_adapters=list(waiting_lora_adapters.keys()),
|
| 1816 |
+
running_lora_adapters=list(running_lora_adapters.keys()))
|
| 1817 |
+
|
| 1818 |
+
def add_lora(self, lora_request: LoRARequest) -> bool:
|
| 1819 |
+
return self.model_executor.add_lora(lora_request)
|
| 1820 |
+
|
| 1821 |
+
def remove_lora(self, lora_id: int) -> bool:
|
| 1822 |
+
return self.model_executor.remove_lora(lora_id)
|
| 1823 |
+
|
| 1824 |
+
def list_loras(self) -> Set[int]:
|
| 1825 |
+
return self.model_executor.list_loras()
|
| 1826 |
+
|
| 1827 |
+
def pin_lora(self, lora_id: int) -> bool:
|
| 1828 |
+
return self.model_executor.pin_lora(lora_id)
|
| 1829 |
+
|
| 1830 |
+
def add_prompt_adapter(
|
| 1831 |
+
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
| 1832 |
+
return self.model_executor.add_prompt_adapter(prompt_adapter_request)
|
| 1833 |
+
|
| 1834 |
+
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
| 1835 |
+
return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
|
| 1836 |
+
|
| 1837 |
+
def list_prompt_adapters(self) -> List[int]:
|
| 1838 |
+
return self.model_executor.list_prompt_adapters()
|
| 1839 |
+
|
| 1840 |
+
def start_profile(self) -> None:
|
| 1841 |
+
self.model_executor.start_profile()
|
| 1842 |
+
|
| 1843 |
+
def stop_profile(self) -> None:
|
| 1844 |
+
self.model_executor.stop_profile()
|
| 1845 |
+
|
| 1846 |
+
def sleep(self, level: int = 1) -> None:
|
| 1847 |
+
assert self.vllm_config.model_config.enable_sleep_mode, (
|
| 1848 |
+
"Sleep mode is not enabled in the model config")
|
| 1849 |
+
self.model_executor.sleep(level=level)
|
| 1850 |
+
|
| 1851 |
+
def wake_up(self) -> None:
|
| 1852 |
+
assert self.vllm_config.model_config.enable_sleep_mode, (
|
| 1853 |
+
"Sleep mode is not enabled in the model config")
|
| 1854 |
+
self.model_executor.wake_up()
|
| 1855 |
+
|
| 1856 |
+
def check_health(self) -> None:
|
| 1857 |
+
if self.tokenizer:
|
| 1858 |
+
self.tokenizer.check_health()
|
| 1859 |
+
self.model_executor.check_health()
|
| 1860 |
+
|
| 1861 |
+
def is_tracing_enabled(self) -> bool:
|
| 1862 |
+
return self.tracer is not None
|
| 1863 |
+
|
| 1864 |
+
def do_tracing(self,
|
| 1865 |
+
scheduler_outputs: SchedulerOutputs,
|
| 1866 |
+
finished_before: Optional[List[int]] = None) -> None:
|
| 1867 |
+
if self.tracer is None:
|
| 1868 |
+
return
|
| 1869 |
+
|
| 1870 |
+
for idx, scheduled_seq_group in enumerate(
|
| 1871 |
+
scheduler_outputs.scheduled_seq_groups):
|
| 1872 |
+
# Skip double tracing when using async output proc
|
| 1873 |
+
if finished_before and idx in finished_before:
|
| 1874 |
+
continue
|
| 1875 |
+
|
| 1876 |
+
seq_group = scheduled_seq_group.seq_group
|
| 1877 |
+
if seq_group.is_finished():
|
| 1878 |
+
self.create_trace_span(seq_group)
|
| 1879 |
+
|
| 1880 |
+
def create_trace_span(self, seq_group: SequenceGroup) -> None:
|
| 1881 |
+
if self.tracer is None or seq_group.sampling_params is None:
|
| 1882 |
+
return
|
| 1883 |
+
arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)
|
| 1884 |
+
|
| 1885 |
+
trace_context = extract_trace_context(seq_group.trace_headers)
|
| 1886 |
+
|
| 1887 |
+
with self.tracer.start_as_current_span(
|
| 1888 |
+
"llm_request",
|
| 1889 |
+
kind=SpanKind.SERVER,
|
| 1890 |
+
context=trace_context,
|
| 1891 |
+
start_time=arrival_time_nano_seconds) as seq_span:
|
| 1892 |
+
metrics = seq_group.metrics
|
| 1893 |
+
ttft = metrics.first_token_time - metrics.arrival_time
|
| 1894 |
+
e2e_time = metrics.finished_time - metrics.arrival_time
|
| 1895 |
+
seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
|
| 1896 |
+
self.model_config.model)
|
| 1897 |
+
seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
|
| 1898 |
+
seq_group.request_id)
|
| 1899 |
+
seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
|
| 1900 |
+
seq_group.sampling_params.temperature)
|
| 1901 |
+
seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
|
| 1902 |
+
seq_group.sampling_params.top_p)
|
| 1903 |
+
seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
|
| 1904 |
+
seq_group.sampling_params.max_tokens)
|
| 1905 |
+
seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
|
| 1906 |
+
seq_group.sampling_params.n)
|
| 1907 |
+
seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
|
| 1908 |
+
seq_group.num_seqs())
|
| 1909 |
+
seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
|
| 1910 |
+
len(seq_group.prompt_token_ids))
|
| 1911 |
+
seq_span.set_attribute(
|
| 1912 |
+
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
|
| 1913 |
+
sum([
|
| 1914 |
+
seq.get_output_len()
|
| 1915 |
+
for seq in seq_group.get_finished_seqs()
|
| 1916 |
+
]))
|
| 1917 |
+
seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
|
| 1918 |
+
metrics.time_in_queue)
|
| 1919 |
+
seq_span.set_attribute(
|
| 1920 |
+
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
|
| 1921 |
+
seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
|
| 1922 |
+
if metrics.scheduler_time is not None:
|
| 1923 |
+
seq_span.set_attribute(
|
| 1924 |
+
SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
|
| 1925 |
+
metrics.scheduler_time)
|
| 1926 |
+
if metrics.model_forward_time is not None:
|
| 1927 |
+
seq_span.set_attribute(
|
| 1928 |
+
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
|
| 1929 |
+
metrics.model_forward_time / 1000.0)
|
| 1930 |
+
if metrics.model_execute_time is not None:
|
| 1931 |
+
seq_span.set_attribute(
|
| 1932 |
+
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
|
| 1933 |
+
metrics.model_execute_time)
|
| 1934 |
+
|
| 1935 |
+
def _validate_model_inputs(self, inputs: ProcessorInputs,
|
| 1936 |
+
lora_request: Optional[LoRARequest]):
|
| 1937 |
+
if is_encoder_decoder_inputs(inputs):
|
| 1938 |
+
# For encoder-decoder multimodal models, the max_prompt_len
|
| 1939 |
+
# restricts the decoder prompt length
|
| 1940 |
+
prompt_inputs = inputs["decoder" if self.model_config.
|
| 1941 |
+
is_multimodal_model else "encoder"]
|
| 1942 |
+
else:
|
| 1943 |
+
prompt_inputs = inputs
|
| 1944 |
+
|
| 1945 |
+
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
|
| 1946 |
+
|
| 1947 |
+
if prompt_ids is None or len(prompt_ids) == 0:
|
| 1948 |
+
raise ValueError("Prompt cannot be empty")
|
| 1949 |
+
|
| 1950 |
+
if self.model_config.is_multimodal_model:
|
| 1951 |
+
max_prompt_len = self.model_config.max_model_len
|
| 1952 |
+
|
| 1953 |
+
if len(prompt_ids) > max_prompt_len:
|
| 1954 |
+
raise ValueError(
|
| 1955 |
+
f"The prompt (total length {len(prompt_ids)}) is too long "
|
| 1956 |
+
f"to fit into the model (context length {max_prompt_len}). "
|
| 1957 |
+
"Make sure that `max_model_len` is no smaller than the "
|
| 1958 |
+
"number of text tokens plus multimodal tokens. For image "
|
| 1959 |
+
"inputs, the number of image tokens depends on the number "
|
| 1960 |
+
"of images, and possibly their aspect ratios as well.")
|
| 1961 |
+
|
| 1962 |
+
# TODO: Find out how many placeholder tokens are there so we can
|
| 1963 |
+
# check that chunked prefill does not truncate them
|
| 1964 |
+
# max_batch_len = self.scheduler_config.max_num_batched_tokens
|
| 1965 |
+
|
| 1966 |
+
def _build_logits_processors(
|
| 1967 |
+
self, sampling_params: SamplingParams,
|
| 1968 |
+
lora_request: Optional[LoRARequest]) -> SamplingParams:
|
| 1969 |
+
"""Constructs logits processors based on the guided_decoding,
|
| 1970 |
+
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
|
| 1971 |
+
those fields and adds the constructed logits processors to the
|
| 1972 |
+
logits_processors field. Returns the modified sampling params."""
|
| 1973 |
+
|
| 1974 |
+
logits_processors = []
|
| 1975 |
+
|
| 1976 |
+
if sampling_params.guided_decoding is not None:
|
| 1977 |
+
# Defensively copy sampling params since guided decoding logits
|
| 1978 |
+
# processors can have different state for each request
|
| 1979 |
+
sampling_params = copy.copy(sampling_params)
|
| 1980 |
+
guided_decoding = sampling_params.guided_decoding
|
| 1981 |
+
|
| 1982 |
+
logger.debug(
|
| 1983 |
+
"Building guided decoding logits processor in "
|
| 1984 |
+
"LLMEngine. Params: %s", guided_decoding)
|
| 1985 |
+
|
| 1986 |
+
tokenizer = self.get_tokenizer(lora_request=lora_request)
|
| 1987 |
+
guided_decoding.backend = guided_decoding.backend or \
|
| 1988 |
+
self.decoding_config.guided_decoding_backend
|
| 1989 |
+
|
| 1990 |
+
processor = get_local_guided_decoding_logits_processor(
|
| 1991 |
+
guided_params=guided_decoding,
|
| 1992 |
+
tokenizer=tokenizer,
|
| 1993 |
+
model_config=self.model_config)
|
| 1994 |
+
if processor:
|
| 1995 |
+
logits_processors.append(processor)
|
| 1996 |
+
|
| 1997 |
+
# Unset so this doesn't get passed down to the model
|
| 1998 |
+
sampling_params.guided_decoding = None
|
| 1999 |
+
|
| 2000 |
+
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
|
| 2001 |
+
tokenizer = self.get_tokenizer(lora_request=lora_request)
|
| 2002 |
+
|
| 2003 |
+
processors = get_openai_logits_processors(
|
| 2004 |
+
logit_bias=sampling_params.logit_bias,
|
| 2005 |
+
allowed_token_ids=sampling_params.allowed_token_ids,
|
| 2006 |
+
tokenizer=tokenizer)
|
| 2007 |
+
logits_processors.extend(processors)
|
| 2008 |
+
|
| 2009 |
+
# Unset so these don't get passed down to the model
|
| 2010 |
+
sampling_params.logit_bias = None
|
| 2011 |
+
sampling_params.allowed_token_ids = None
|
| 2012 |
+
|
| 2013 |
+
if len(sampling_params.bad_words) > 0:
|
| 2014 |
+
tokenizer = self.get_tokenizer(lora_request)
|
| 2015 |
+
processors = get_bad_words_logits_processors(
|
| 2016 |
+
bad_words=sampling_params.bad_words, tokenizer=tokenizer)
|
| 2017 |
+
logits_processors.extend(processors)
|
| 2018 |
+
|
| 2019 |
+
if logits_processors:
|
| 2020 |
+
if sampling_params.logits_processors is None:
|
| 2021 |
+
sampling_params.logits_processors = logits_processors
|
| 2022 |
+
else:
|
| 2023 |
+
sampling_params.logits_processors.extend(logits_processors)
|
| 2024 |
+
|
| 2025 |
+
return sampling_params
|
.venv/lib/python3.11/site-packages/vllm/engine/metrics.py
ADDED
|
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
from typing import TYPE_CHECKING
|
| 5 |
+
from typing import Counter as CollectionsCounter
|
| 6 |
+
from typing import Dict, List, Optional, Type, Union, cast
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import prometheus_client
|
| 10 |
+
|
| 11 |
+
from vllm.config import VllmConfig
|
| 12 |
+
from vllm.engine.metrics_types import (StatLoggerBase, Stats,
|
| 13 |
+
SupportsMetricsInfo)
|
| 14 |
+
from vllm.executor.ray_utils import ray
|
| 15 |
+
from vllm.logger import init_logger
|
| 16 |
+
|
| 17 |
+
if ray is not None:
|
| 18 |
+
from ray.util import metrics as ray_metrics
|
| 19 |
+
else:
|
| 20 |
+
ray_metrics = None
|
| 21 |
+
|
| 22 |
+
if TYPE_CHECKING:
|
| 23 |
+
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
| 24 |
+
|
| 25 |
+
logger = init_logger(__name__)
|
| 26 |
+
|
| 27 |
+
prometheus_client.disable_created_metrics()
|
| 28 |
+
|
| 29 |
+
# The begin-* and end* here are used by the documentation generator
|
| 30 |
+
# to extract the metrics definitions.
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# begin-metrics-definitions
|
| 34 |
+
class Metrics:
|
| 35 |
+
"""
|
| 36 |
+
vLLM uses a multiprocessing-based frontend for the OpenAI server.
|
| 37 |
+
This means that we need to run prometheus_client in multiprocessing mode
|
| 38 |
+
See https://prometheus.github.io/client_python/multiprocess/ for more
|
| 39 |
+
details on limitations.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
labelname_finish_reason = "finished_reason"
|
| 43 |
+
labelname_waiting_lora_adapters = "waiting_lora_adapters"
|
| 44 |
+
labelname_running_lora_adapters = "running_lora_adapters"
|
| 45 |
+
labelname_max_lora = "max_lora"
|
| 46 |
+
_gauge_cls = prometheus_client.Gauge
|
| 47 |
+
_counter_cls = prometheus_client.Counter
|
| 48 |
+
_histogram_cls = prometheus_client.Histogram
|
| 49 |
+
|
| 50 |
+
def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
|
| 51 |
+
# Unregister any existing vLLM collectors (for CI/CD)
|
| 52 |
+
self._unregister_vllm_metrics()
|
| 53 |
+
|
| 54 |
+
max_model_len = vllm_config.model_config.max_model_len
|
| 55 |
+
|
| 56 |
+
# System stats
|
| 57 |
+
# Scheduler State
|
| 58 |
+
self.gauge_scheduler_running = self._gauge_cls(
|
| 59 |
+
name="vllm:num_requests_running",
|
| 60 |
+
documentation="Number of requests currently running on GPU.",
|
| 61 |
+
labelnames=labelnames,
|
| 62 |
+
multiprocess_mode="sum")
|
| 63 |
+
self.gauge_scheduler_waiting = self._gauge_cls(
|
| 64 |
+
name="vllm:num_requests_waiting",
|
| 65 |
+
documentation="Number of requests waiting to be processed.",
|
| 66 |
+
labelnames=labelnames,
|
| 67 |
+
multiprocess_mode="sum")
|
| 68 |
+
self.gauge_lora_info = self._gauge_cls(
|
| 69 |
+
name="vllm:lora_requests_info",
|
| 70 |
+
documentation="Running stats on lora requests.",
|
| 71 |
+
labelnames=[
|
| 72 |
+
self.labelname_running_lora_adapters,
|
| 73 |
+
self.labelname_max_lora,
|
| 74 |
+
self.labelname_waiting_lora_adapters,
|
| 75 |
+
],
|
| 76 |
+
multiprocess_mode="livemostrecent",
|
| 77 |
+
)
|
| 78 |
+
self.gauge_scheduler_swapped = self._gauge_cls(
|
| 79 |
+
name="vllm:num_requests_swapped",
|
| 80 |
+
documentation="Number of requests swapped to CPU.",
|
| 81 |
+
labelnames=labelnames,
|
| 82 |
+
multiprocess_mode="sum")
|
| 83 |
+
# KV Cache Usage in %
|
| 84 |
+
self.gauge_gpu_cache_usage = self._gauge_cls(
|
| 85 |
+
name="vllm:gpu_cache_usage_perc",
|
| 86 |
+
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
|
| 87 |
+
labelnames=labelnames,
|
| 88 |
+
multiprocess_mode="sum")
|
| 89 |
+
self.gauge_cpu_cache_usage = self._gauge_cls(
|
| 90 |
+
name="vllm:cpu_cache_usage_perc",
|
| 91 |
+
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
|
| 92 |
+
labelnames=labelnames,
|
| 93 |
+
multiprocess_mode="sum")
|
| 94 |
+
# Prefix caching block hit rate
|
| 95 |
+
self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
|
| 96 |
+
name="vllm:cpu_prefix_cache_hit_rate",
|
| 97 |
+
documentation="CPU prefix cache block hit rate.",
|
| 98 |
+
labelnames=labelnames,
|
| 99 |
+
multiprocess_mode="sum")
|
| 100 |
+
self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
|
| 101 |
+
name="vllm:gpu_prefix_cache_hit_rate",
|
| 102 |
+
documentation="GPU prefix cache block hit rate.",
|
| 103 |
+
labelnames=labelnames,
|
| 104 |
+
multiprocess_mode="sum")
|
| 105 |
+
|
| 106 |
+
# Iteration stats
|
| 107 |
+
self.counter_num_preemption = self._counter_cls(
|
| 108 |
+
name="vllm:num_preemptions_total",
|
| 109 |
+
documentation="Cumulative number of preemption from the engine.",
|
| 110 |
+
labelnames=labelnames)
|
| 111 |
+
self.counter_prompt_tokens = self._counter_cls(
|
| 112 |
+
name="vllm:prompt_tokens_total",
|
| 113 |
+
documentation="Number of prefill tokens processed.",
|
| 114 |
+
labelnames=labelnames)
|
| 115 |
+
self.counter_generation_tokens = self._counter_cls(
|
| 116 |
+
name="vllm:generation_tokens_total",
|
| 117 |
+
documentation="Number of generation tokens processed.",
|
| 118 |
+
labelnames=labelnames)
|
| 119 |
+
self.counter_tokens = self._counter_cls(
|
| 120 |
+
name="vllm:tokens_total",
|
| 121 |
+
documentation="Number of prefill plus generation tokens processed.",
|
| 122 |
+
labelnames=labelnames)
|
| 123 |
+
buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
|
| 124 |
+
if not vllm_config.model_config.enforce_eager:
|
| 125 |
+
buckets = vllm_config.compilation_config.\
|
| 126 |
+
cudagraph_capture_sizes.copy()
|
| 127 |
+
buckets.sort()
|
| 128 |
+
self.histogram_iteration_tokens = self._histogram_cls(
|
| 129 |
+
name="vllm:iteration_tokens_total",
|
| 130 |
+
documentation="Histogram of number of tokens per engine_step.",
|
| 131 |
+
labelnames=labelnames,
|
| 132 |
+
buckets=buckets)
|
| 133 |
+
self.histogram_time_to_first_token = self._histogram_cls(
|
| 134 |
+
name="vllm:time_to_first_token_seconds",
|
| 135 |
+
documentation="Histogram of time to first token in seconds.",
|
| 136 |
+
labelnames=labelnames,
|
| 137 |
+
buckets=[
|
| 138 |
+
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
|
| 139 |
+
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
|
| 140 |
+
])
|
| 141 |
+
self.histogram_time_per_output_token = self._histogram_cls(
|
| 142 |
+
name="vllm:time_per_output_token_seconds",
|
| 143 |
+
documentation="Histogram of time per output token in seconds.",
|
| 144 |
+
labelnames=labelnames,
|
| 145 |
+
buckets=[
|
| 146 |
+
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
|
| 147 |
+
1.0, 2.5
|
| 148 |
+
])
|
| 149 |
+
|
| 150 |
+
# Request stats
|
| 151 |
+
# Latency
|
| 152 |
+
request_latency_buckets = [
|
| 153 |
+
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
|
| 154 |
+
40.0, 50.0, 60.0
|
| 155 |
+
]
|
| 156 |
+
self.histogram_e2e_time_request = self._histogram_cls(
|
| 157 |
+
name="vllm:e2e_request_latency_seconds",
|
| 158 |
+
documentation="Histogram of end to end request latency in seconds.",
|
| 159 |
+
labelnames=labelnames,
|
| 160 |
+
buckets=request_latency_buckets)
|
| 161 |
+
self.histogram_queue_time_request = self._histogram_cls(
|
| 162 |
+
name="vllm:request_queue_time_seconds",
|
| 163 |
+
documentation=
|
| 164 |
+
"Histogram of time spent in WAITING phase for request.",
|
| 165 |
+
labelnames=labelnames,
|
| 166 |
+
buckets=request_latency_buckets)
|
| 167 |
+
self.histogram_inference_time_request = self._histogram_cls(
|
| 168 |
+
name="vllm:request_inference_time_seconds",
|
| 169 |
+
documentation=
|
| 170 |
+
"Histogram of time spent in RUNNING phase for request.",
|
| 171 |
+
labelnames=labelnames,
|
| 172 |
+
buckets=request_latency_buckets)
|
| 173 |
+
self.histogram_prefill_time_request = self._histogram_cls(
|
| 174 |
+
name="vllm:request_prefill_time_seconds",
|
| 175 |
+
documentation=
|
| 176 |
+
"Histogram of time spent in PREFILL phase for request.",
|
| 177 |
+
labelnames=labelnames,
|
| 178 |
+
buckets=request_latency_buckets)
|
| 179 |
+
self.histogram_decode_time_request = self._histogram_cls(
|
| 180 |
+
name="vllm:request_decode_time_seconds",
|
| 181 |
+
documentation=
|
| 182 |
+
"Histogram of time spent in DECODE phase for request.",
|
| 183 |
+
labelnames=labelnames,
|
| 184 |
+
buckets=request_latency_buckets)
|
| 185 |
+
self.histogram_time_in_queue_request = self._histogram_cls(
|
| 186 |
+
name="vllm:time_in_queue_requests",
|
| 187 |
+
documentation=
|
| 188 |
+
"Histogram of time the request spent in the queue in seconds.",
|
| 189 |
+
labelnames=labelnames,
|
| 190 |
+
buckets=request_latency_buckets)
|
| 191 |
+
self.histogram_model_forward_time_request = self._histogram_cls(
|
| 192 |
+
name="vllm:model_forward_time_milliseconds",
|
| 193 |
+
documentation=
|
| 194 |
+
"Histogram of time spent in the model forward pass in ms.",
|
| 195 |
+
labelnames=labelnames,
|
| 196 |
+
buckets=build_1_2_3_5_8_buckets(3000))
|
| 197 |
+
self.histogram_model_execute_time_request = self._histogram_cls(
|
| 198 |
+
name="vllm:model_execute_time_milliseconds",
|
| 199 |
+
documentation=
|
| 200 |
+
"Histogram of time spent in the model execute function in ms.",
|
| 201 |
+
labelnames=labelnames,
|
| 202 |
+
buckets=build_1_2_3_5_8_buckets(3000))
|
| 203 |
+
# Metadata
|
| 204 |
+
self.histogram_num_prompt_tokens_request = self._histogram_cls(
|
| 205 |
+
name="vllm:request_prompt_tokens",
|
| 206 |
+
documentation="Number of prefill tokens processed.",
|
| 207 |
+
labelnames=labelnames,
|
| 208 |
+
buckets=build_1_2_5_buckets(max_model_len),
|
| 209 |
+
)
|
| 210 |
+
self.histogram_num_generation_tokens_request = \
|
| 211 |
+
self._histogram_cls(
|
| 212 |
+
name="vllm:request_generation_tokens",
|
| 213 |
+
documentation="Number of generation tokens processed.",
|
| 214 |
+
labelnames=labelnames,
|
| 215 |
+
buckets=build_1_2_5_buckets(max_model_len),
|
| 216 |
+
)
|
| 217 |
+
self.histogram_max_num_generation_tokens_request = self._histogram_cls(
|
| 218 |
+
name="vllm:request_max_num_generation_tokens",
|
| 219 |
+
documentation=
|
| 220 |
+
"Histogram of maximum number of requested generation tokens.",
|
| 221 |
+
labelnames=labelnames,
|
| 222 |
+
buckets=build_1_2_5_buckets(max_model_len))
|
| 223 |
+
self.histogram_n_request = self._histogram_cls(
|
| 224 |
+
name="vllm:request_params_n",
|
| 225 |
+
documentation="Histogram of the n request parameter.",
|
| 226 |
+
labelnames=labelnames,
|
| 227 |
+
buckets=[1, 2, 5, 10, 20],
|
| 228 |
+
)
|
| 229 |
+
self.histogram_max_tokens_request = self._histogram_cls(
|
| 230 |
+
name="vllm:request_params_max_tokens",
|
| 231 |
+
documentation="Histogram of the max_tokens request parameter.",
|
| 232 |
+
labelnames=labelnames,
|
| 233 |
+
buckets=build_1_2_5_buckets(max_model_len),
|
| 234 |
+
)
|
| 235 |
+
self.counter_request_success = self._counter_cls(
|
| 236 |
+
name="vllm:request_success_total",
|
| 237 |
+
documentation="Count of successfully processed requests.",
|
| 238 |
+
labelnames=labelnames + [Metrics.labelname_finish_reason])
|
| 239 |
+
|
| 240 |
+
# Speculatie decoding stats
|
| 241 |
+
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
|
| 242 |
+
name="vllm:spec_decode_draft_acceptance_rate",
|
| 243 |
+
documentation="Speulative token acceptance rate.",
|
| 244 |
+
labelnames=labelnames,
|
| 245 |
+
multiprocess_mode="sum")
|
| 246 |
+
self.gauge_spec_decode_efficiency = self._gauge_cls(
|
| 247 |
+
name="vllm:spec_decode_efficiency",
|
| 248 |
+
documentation="Speculative decoding system efficiency.",
|
| 249 |
+
labelnames=labelnames,
|
| 250 |
+
multiprocess_mode="sum")
|
| 251 |
+
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
|
| 252 |
+
name="vllm:spec_decode_num_accepted_tokens_total",
|
| 253 |
+
documentation="Number of accepted tokens.",
|
| 254 |
+
labelnames=labelnames))
|
| 255 |
+
self.counter_spec_decode_num_draft_tokens = self._counter_cls(
|
| 256 |
+
name="vllm:spec_decode_num_draft_tokens_total",
|
| 257 |
+
documentation="Number of draft tokens.",
|
| 258 |
+
labelnames=labelnames)
|
| 259 |
+
self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
|
| 260 |
+
name="vllm:spec_decode_num_emitted_tokens_total",
|
| 261 |
+
documentation="Number of emitted tokens.",
|
| 262 |
+
labelnames=labelnames))
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# end-metrics-definitions
|
| 266 |
+
|
| 267 |
+
def _unregister_vllm_metrics(self) -> None:
|
| 268 |
+
for collector in list(prometheus_client.REGISTRY._collector_to_names):
|
| 269 |
+
if hasattr(collector, "_name") and "vllm" in collector._name:
|
| 270 |
+
prometheus_client.REGISTRY.unregister(collector)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class _RayGaugeWrapper:
|
| 274 |
+
"""Wraps around ray.util.metrics.Gauge to provide same API as
|
| 275 |
+
prometheus_client.Gauge"""
|
| 276 |
+
|
| 277 |
+
def __init__(self,
|
| 278 |
+
name: str,
|
| 279 |
+
documentation: str = "",
|
| 280 |
+
labelnames: Optional[List[str]] = None,
|
| 281 |
+
multiprocess_mode: str = ""):
|
| 282 |
+
del multiprocess_mode
|
| 283 |
+
labelnames_tuple = tuple(labelnames) if labelnames else None
|
| 284 |
+
self._gauge = ray_metrics.Gauge(name=name,
|
| 285 |
+
description=documentation,
|
| 286 |
+
tag_keys=labelnames_tuple)
|
| 287 |
+
|
| 288 |
+
def labels(self, **labels):
|
| 289 |
+
self._gauge.set_default_tags(labels)
|
| 290 |
+
return self
|
| 291 |
+
|
| 292 |
+
def set(self, value: Union[int, float]):
|
| 293 |
+
return self._gauge.set(value)
|
| 294 |
+
|
| 295 |
+
def set_to_current_time(self):
|
| 296 |
+
# ray metrics doesn't have set_to_current time, https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
|
| 297 |
+
return self._gauge.set(time.time())
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class _RayCounterWrapper:
|
| 301 |
+
"""Wraps around ray.util.metrics.Counter to provide same API as
|
| 302 |
+
prometheus_client.Counter"""
|
| 303 |
+
|
| 304 |
+
def __init__(self,
|
| 305 |
+
name: str,
|
| 306 |
+
documentation: str = "",
|
| 307 |
+
labelnames: Optional[List[str]] = None):
|
| 308 |
+
labelnames_tuple = tuple(labelnames) if labelnames else None
|
| 309 |
+
self._counter = ray_metrics.Counter(name=name,
|
| 310 |
+
description=documentation,
|
| 311 |
+
tag_keys=labelnames_tuple)
|
| 312 |
+
|
| 313 |
+
def labels(self, **labels):
|
| 314 |
+
self._counter.set_default_tags(labels)
|
| 315 |
+
return self
|
| 316 |
+
|
| 317 |
+
def inc(self, value: Union[int, float] = 1.0):
|
| 318 |
+
if value == 0:
|
| 319 |
+
return
|
| 320 |
+
return self._counter.inc(value)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class _RayHistogramWrapper:
|
| 324 |
+
"""Wraps around ray.util.metrics.Histogram to provide same API as
|
| 325 |
+
prometheus_client.Histogram"""
|
| 326 |
+
|
| 327 |
+
def __init__(self,
|
| 328 |
+
name: str,
|
| 329 |
+
documentation: str = "",
|
| 330 |
+
labelnames: Optional[List[str]] = None,
|
| 331 |
+
buckets: Optional[List[float]] = None):
|
| 332 |
+
labelnames_tuple = tuple(labelnames) if labelnames else None
|
| 333 |
+
boundaries = buckets if buckets else []
|
| 334 |
+
self._histogram = ray_metrics.Histogram(name=name,
|
| 335 |
+
description=documentation,
|
| 336 |
+
tag_keys=labelnames_tuple,
|
| 337 |
+
boundaries=boundaries)
|
| 338 |
+
|
| 339 |
+
def labels(self, **labels):
|
| 340 |
+
self._histogram.set_default_tags(labels)
|
| 341 |
+
return self
|
| 342 |
+
|
| 343 |
+
def observe(self, value: Union[int, float]):
|
| 344 |
+
return self._histogram.observe(value)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class RayMetrics(Metrics):
|
| 348 |
+
"""
|
| 349 |
+
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
|
| 350 |
+
Provides the same metrics as Metrics but uses Ray's util.metrics library.
|
| 351 |
+
"""
|
| 352 |
+
_gauge_cls: Type[prometheus_client.Gauge] = cast(
|
| 353 |
+
Type[prometheus_client.Gauge], _RayGaugeWrapper)
|
| 354 |
+
_counter_cls: Type[prometheus_client.Counter] = cast(
|
| 355 |
+
Type[prometheus_client.Counter], _RayCounterWrapper)
|
| 356 |
+
_histogram_cls: Type[prometheus_client.Histogram] = cast(
|
| 357 |
+
Type[prometheus_client.Histogram], _RayHistogramWrapper)
|
| 358 |
+
|
| 359 |
+
def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
|
| 360 |
+
if ray_metrics is None:
|
| 361 |
+
raise ImportError("RayMetrics requires Ray to be installed.")
|
| 362 |
+
super().__init__(labelnames, vllm_config)
|
| 363 |
+
|
| 364 |
+
def _unregister_vllm_metrics(self) -> None:
|
| 365 |
+
# No-op on purpose
|
| 366 |
+
pass
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
|
| 370 |
+
"""
|
| 371 |
+
Builds a list of buckets with increasing powers of 10 multiplied by
|
| 372 |
+
mantissa values until the value exceeds the specified maximum.
|
| 373 |
+
|
| 374 |
+
"""
|
| 375 |
+
exponent = 0
|
| 376 |
+
buckets: List[int] = []
|
| 377 |
+
while True:
|
| 378 |
+
for m in mantissa_lst:
|
| 379 |
+
value = m * 10**exponent
|
| 380 |
+
if value <= max_value:
|
| 381 |
+
buckets.append(value)
|
| 382 |
+
else:
|
| 383 |
+
return buckets
|
| 384 |
+
exponent += 1
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def build_1_2_5_buckets(max_value: int) -> List[int]:
|
| 388 |
+
"""
|
| 389 |
+
Example:
|
| 390 |
+
>>> build_1_2_5_buckets(100)
|
| 391 |
+
[1, 2, 5, 10, 20, 50, 100]
|
| 392 |
+
"""
|
| 393 |
+
return build_buckets([1, 2, 5], max_value)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def build_1_2_3_5_8_buckets(max_value: int) -> List[int]:
|
| 397 |
+
"""
|
| 398 |
+
Example:
|
| 399 |
+
>>> build_1_2_3_5_8_buckets(100)
|
| 400 |
+
[1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100]
|
| 401 |
+
"""
|
| 402 |
+
return build_buckets([1, 2, 3, 5, 8], max_value)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def local_interval_elapsed(now: float, last_log: float,
|
| 406 |
+
local_interval: float) -> bool:
|
| 407 |
+
elapsed_time = now - last_log
|
| 408 |
+
return elapsed_time > local_interval
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def get_throughput(tracked_stats: List[int], now: float,
|
| 412 |
+
last_log: float) -> float:
|
| 413 |
+
return float(np.sum(tracked_stats) / (now - last_log))
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class LoggingStatLogger(StatLoggerBase):
|
| 417 |
+
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
|
| 418 |
+
|
| 419 |
+
def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
|
| 420 |
+
super().__init__(local_interval, vllm_config)
|
| 421 |
+
self.last_prompt_throughput: Optional[float] = None
|
| 422 |
+
self.last_generation_throughput: Optional[float] = None
|
| 423 |
+
|
| 424 |
+
def log(self, stats: Stats) -> None:
|
| 425 |
+
"""Called by LLMEngine.
|
| 426 |
+
Logs to Stdout every self.local_interval seconds."""
|
| 427 |
+
|
| 428 |
+
# Save tracked stats for token counters.
|
| 429 |
+
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
| 430 |
+
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
| 431 |
+
|
| 432 |
+
# Update spec decode metrics
|
| 433 |
+
self.maybe_update_spec_decode_metrics(stats)
|
| 434 |
+
|
| 435 |
+
# Log locally every local_interval seconds.
|
| 436 |
+
if local_interval_elapsed(stats.now, self.last_local_log,
|
| 437 |
+
self.local_interval):
|
| 438 |
+
# Compute summary metrics for tracked stats (and log them
|
| 439 |
+
# to promethus if applicable).
|
| 440 |
+
prompt_throughput = get_throughput(self.num_prompt_tokens,
|
| 441 |
+
now=stats.now,
|
| 442 |
+
last_log=self.last_local_log)
|
| 443 |
+
generation_throughput = get_throughput(
|
| 444 |
+
self.num_generation_tokens,
|
| 445 |
+
now=stats.now,
|
| 446 |
+
last_log=self.last_local_log)
|
| 447 |
+
|
| 448 |
+
log_fn = logger.info
|
| 449 |
+
if not any((prompt_throughput, generation_throughput,
|
| 450 |
+
self.last_prompt_throughput,
|
| 451 |
+
self.last_generation_throughput)):
|
| 452 |
+
# Avoid log noise on an idle production system
|
| 453 |
+
log_fn = logger.debug
|
| 454 |
+
|
| 455 |
+
log_fn(
|
| 456 |
+
"Avg prompt throughput: %.1f tokens/s, "
|
| 457 |
+
"Avg generation throughput: %.1f tokens/s, "
|
| 458 |
+
"Running: %d reqs, Swapped: %d reqs, "
|
| 459 |
+
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
|
| 460 |
+
"CPU KV cache usage: %.1f%%.",
|
| 461 |
+
prompt_throughput,
|
| 462 |
+
generation_throughput,
|
| 463 |
+
stats.num_running_sys,
|
| 464 |
+
stats.num_swapped_sys,
|
| 465 |
+
stats.num_waiting_sys,
|
| 466 |
+
stats.gpu_cache_usage_sys * 100,
|
| 467 |
+
stats.cpu_cache_usage_sys * 100,
|
| 468 |
+
)
|
| 469 |
+
if (stats.cpu_prefix_cache_hit_rate >= 0
|
| 470 |
+
or stats.gpu_prefix_cache_hit_rate >= 0):
|
| 471 |
+
log_fn(
|
| 472 |
+
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
|
| 473 |
+
stats.gpu_prefix_cache_hit_rate * 100,
|
| 474 |
+
stats.cpu_prefix_cache_hit_rate * 100,
|
| 475 |
+
)
|
| 476 |
+
if self.spec_decode_metrics is not None:
|
| 477 |
+
log_fn(
|
| 478 |
+
self._format_spec_decode_metrics_str(
|
| 479 |
+
self.spec_decode_metrics))
|
| 480 |
+
|
| 481 |
+
self._reset(stats, prompt_throughput, generation_throughput)
|
| 482 |
+
|
| 483 |
+
def _reset(self, stats, prompt_throughput, generation_throughput) -> None:
    """Clear the per-interval accumulators and remember this flush.

    Records ``stats.now`` as the start of the next logging interval and
    keeps the throughputs just reported so the next interval can detect
    an idle system (see the throughput check in ``log``).
    """
    # Fresh token accumulators for the upcoming interval.
    self.num_prompt_tokens = []
    self.num_generation_tokens = []
    # Mark when this interval's stats were flushed.
    self.last_local_log = stats.now
    # Spec-decode metrics are one-shot; drop them after logging.
    self.spec_decode_metrics = None
    # Remembered so the next log call can tell "idle" from "active".
    self.last_prompt_throughput = prompt_throughput
    self.last_generation_throughput = generation_throughput
|
| 491 |
+
|
| 492 |
+
def _format_spec_decode_metrics_str(
        self, metrics: "SpecDecodeWorkerMetrics") -> str:
    """Render speculative-decoding metrics as a single log line."""
    return ("Speculative metrics: "
            f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
            f"System efficiency: {metrics.system_efficiency:.3f}, "
            f"Number of speculative tokens: {metrics.num_spec_tokens}, "
            f"Number of accepted tokens: {metrics.accepted_tokens}, "
            f"Number of draft tokens: {metrics.draft_tokens}, "
            f"Number of emitted tokens: {metrics.emitted_tokens}.")
|
| 502 |
+
|
| 503 |
+
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
    """Log static metadata; not supported by this logger."""
    raise NotImplementedError
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class PrometheusStatLogger(StatLoggerBase):
    """Stat logger used by LLMEngine to export metrics to Prometheus."""

    # Subclasses may override these to swap the metrics backend
    # (see RayPrometheusStatLogger).
    _metrics_cls = Metrics
    _gauge_cls = prometheus_client.Gauge

    def __init__(self, local_interval: float, labels: Dict[str, str],
                 vllm_config: VllmConfig) -> None:
        super().__init__(local_interval, vllm_config)
        # Labels attached to every metric emitted by this logger.
        self.labels = labels
        self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
                                         vllm_config=vllm_config)

    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
        """Set a labelled gauge to ``data``."""
        gauge.labels(**self.labels).set(data)

    def _log_counter(self, counter, data: Union[int, float]) -> None:
        """Increment a labelled counter by ``data``.

        Negative increments are dropped (Prometheus counters raise
        ValueError on them) with a warning.
        """
        if data < 0:
            logger.warning("Skipping negative increment of %g to %s", data,
                           counter)
            return
        counter.labels(**self.labels).inc(data)

    def _log_counter_labels(self, counter, data: CollectionsCounter,
                            label_key: str) -> None:
        """Increment ``counter`` once per (label, count) entry of ``data``."""
        for label, count in data.items():
            counter.labels(**{**self.labels, label_key: label}).inc(count)

    def _log_histogram(self, histogram, data: Union[List[int],
                                                    List[float]]) -> None:
        """Observe every element of ``data`` in a labelled histogram."""
        for value in data:
            histogram.labels(**self.labels).observe(value)

    def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None:
        # Info-style gauge: the payload is carried entirely in the labels;
        # the value is just "now".
        gauge.labels(**data).set_to_current_time()

    def _log_prometheus(self, stats: Stats) -> None:
        """Push all fields of ``stats`` to their Prometheus metrics."""
        # --- System state ---
        self._log_gauge(self.metrics.gauge_scheduler_running,
                        stats.num_running_sys)
        self._log_gauge(self.metrics.gauge_scheduler_swapped,
                        stats.num_swapped_sys)
        self._log_gauge(self.metrics.gauge_scheduler_waiting,
                        stats.num_waiting_sys)
        self._log_gauge(self.metrics.gauge_gpu_cache_usage,
                        stats.gpu_cache_usage_sys)
        self._log_gauge(self.metrics.gauge_cpu_cache_usage,
                        stats.cpu_cache_usage_sys)
        self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
                        stats.cpu_prefix_cache_hit_rate)
        self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
                        stats.gpu_prefix_cache_hit_rate)
        # Including max-lora in metric, in future this property of lora
        # config maybe extended to be dynamic.
        lora_info = {
            self.metrics.labelname_running_lora_adapters:
            ",".join(stats.running_lora_adapters),
            self.metrics.labelname_waiting_lora_adapters:
            ",".join(stats.waiting_lora_adapters),
            self.metrics.labelname_max_lora:
            stats.max_lora,
        }
        self._log_gauge_string(self.metrics.gauge_lora_info, lora_info)
        # --- Iteration level data ---
        self._log_counter(self.metrics.counter_num_preemption,
                          stats.num_preemption_iter)
        self._log_counter(self.metrics.counter_prompt_tokens,
                          stats.num_prompt_tokens_iter)
        self._log_counter(self.metrics.counter_generation_tokens,
                          stats.num_generation_tokens_iter)
        self._log_histogram(self.metrics.histogram_iteration_tokens,
                            [stats.num_tokens_iter])
        self._log_histogram(self.metrics.histogram_time_to_first_token,
                            stats.time_to_first_tokens_iter)
        self._log_histogram(self.metrics.histogram_time_per_output_token,
                            stats.time_per_output_tokens_iter)

        # --- Request level data ---
        # Latency
        self._log_histogram(self.metrics.histogram_e2e_time_request,
                            stats.time_e2e_requests)
        self._log_histogram(self.metrics.histogram_queue_time_request,
                            stats.time_queue_requests)
        self._log_histogram(self.metrics.histogram_inference_time_request,
                            stats.time_inference_requests)
        self._log_histogram(self.metrics.histogram_prefill_time_request,
                            stats.time_prefill_requests)
        self._log_histogram(self.metrics.histogram_decode_time_request,
                            stats.time_decode_requests)
        self._log_histogram(self.metrics.histogram_time_in_queue_request,
                            stats.time_in_queue_requests)
        self._log_histogram(self.metrics.histogram_model_forward_time_request,
                            stats.model_forward_time_requests)
        self._log_histogram(self.metrics.histogram_model_execute_time_request,
                            stats.model_execute_time_requests)
        # Metadata
        finished_reason_counter = CollectionsCounter(
            stats.finished_reason_requests)
        self._log_counter_labels(self.metrics.counter_request_success,
                                 finished_reason_counter,
                                 Metrics.labelname_finish_reason)
        self._log_histogram(self.metrics.histogram_num_prompt_tokens_request,
                            stats.num_prompt_tokens_requests)
        self._log_histogram(
            self.metrics.histogram_num_generation_tokens_request,
            stats.num_generation_tokens_requests)
        self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
        self._log_histogram(
            self.metrics.histogram_max_num_generation_tokens_request,
            stats.max_num_generation_tokens_requests)
        self._log_histogram(self.metrics.histogram_max_tokens_request,
                            stats.max_tokens_requests)

    def log(self, stats: Stats):
        """Logs to prometheus and tracked stats every iteration."""
        self._log_prometheus(stats)

        # Accumulate token counts for this interval.
        self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
        self.num_generation_tokens.append(stats.num_generation_tokens_iter)

        # Spec-decode metrics may arrive between intervals; stash them.
        self.maybe_update_spec_decode_metrics(stats)

        # Every local_interval seconds, flush interval-scoped metrics.
        if local_interval_elapsed(stats.now, self.last_local_log,
                                  self.local_interval):
            if self.spec_decode_metrics is not None:
                self._log_gauge(
                    self.metrics.gauge_spec_decode_draft_acceptance_rate,
                    self.spec_decode_metrics.draft_acceptance_rate)
                self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
                                self.spec_decode_metrics.system_efficiency)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_accepted_tokens,
                    self.spec_decode_metrics.accepted_tokens)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_draft_tokens,
                    self.spec_decode_metrics.draft_tokens)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_emitted_tokens,
                    self.spec_decode_metrics.emitted_tokens)

            # Reset tracked stats for next interval.
            self.num_prompt_tokens = []
            self.num_generation_tokens = []
            self.last_local_log = stats.now
            self.spec_decode_metrics = None

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        # Info type metrics are syntactic sugar for a gauge permanently set
        # to 1. Since prometheus multiprocessing mode does not support Info,
        # emulate it here with a gauge whose labels carry the payload.
        if type == "cache_config":
            metrics_info = obj.metrics_info()
            info_gauge = self._gauge_cls(
                name="vllm:cache_config_info",
                documentation="Information of the LLMEngine CacheConfig",
                labelnames=metrics_info.keys(),
                multiprocess_mode="mostrecent")
            info_gauge.labels(**metrics_info).set(1)
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
class RayPrometheusStatLogger(PrometheusStatLogger):
    """Prometheus stat logger backed by Ray metrics instead of
    prometheus_client."""

    # Swap the metrics backend; everything else is inherited.
    _metrics_cls = RayMetrics

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        # Ray metrics expose no Info/Gauge equivalent for static metadata,
        # so this is a deliberate no-op.
        return None
|
.venv/lib/python3.11/site-packages/vllm/engine/metrics_types.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
These types are defined in this file to avoid importing vllm.engine.metrics
|
| 4 |
+
and therefore importing prometheus_client.
|
| 5 |
+
|
| 6 |
+
This is required due to usage of Prometheus multiprocess mode to enable
|
| 7 |
+
metrics after splitting out the uvicorn process from the engine process.
|
| 8 |
+
|
| 9 |
+
Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR
|
| 10 |
+
before prometheus_client is imported. Typically, this is done by setting
|
| 11 |
+
the env variable before launch, but since we are a library, we need to
|
| 12 |
+
do this in Python code and lazily import prometheus_client.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import time
|
| 16 |
+
from abc import ABC, abstractmethod
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Dict, List, Optional, Protocol
|
| 19 |
+
|
| 20 |
+
from vllm.config import VllmConfig
|
| 21 |
+
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class Stats:
    """Snapshot of engine statistics, created by LLMEngine each iteration
    for consumption by a StatLogger."""

    # Timestamp (seconds) at which this snapshot was taken.
    now: float

    # --- System stats (suffix: _sys) ---
    # Scheduler queue sizes.
    num_running_sys: int
    num_waiting_sys: int
    num_swapped_sys: int
    # KV cache usage as a fraction in [0, 1].
    gpu_cache_usage_sys: float
    cpu_cache_usage_sys: float
    # Prefix-caching block hit rates.
    cpu_prefix_cache_hit_rate: float
    gpu_prefix_cache_hit_rate: float

    # --- Iteration stats (suffix: _iter) ---
    num_prompt_tokens_iter: int
    num_generation_tokens_iter: int
    num_tokens_iter: int
    time_to_first_tokens_iter: List[float]
    time_per_output_tokens_iter: List[float]
    num_preemption_iter: int

    # --- Request stats (suffix: _requests) ---
    # Latency (one entry per finished request).
    time_e2e_requests: List[float]
    time_queue_requests: List[float]
    time_inference_requests: List[float]
    time_prefill_requests: List[float]
    time_decode_requests: List[float]
    time_in_queue_requests: List[float]
    model_forward_time_requests: List[float]
    model_execute_time_requests: List[float]
    # Metadata.
    num_prompt_tokens_requests: List[int]
    num_generation_tokens_requests: List[int]
    n_requests: List[int]
    max_num_generation_tokens_requests: List[int]
    max_tokens_requests: List[int]
    finished_reason_requests: List[str]
    waiting_lora_adapters: List[str]
    running_lora_adapters: List[str]
    max_lora: str

    # Speculative-decoding metrics, when available this iteration.
    spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class SupportsMetricsInfo(Protocol):
    """Structural type for objects that can report static metrics info."""

    def metrics_info(self) -> Dict[str, str]:
        """Return a flat string-to-string mapping of metadata."""
        ...
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class StatLoggerBase(ABC):
    """Abstract base class for engine stat loggers."""

    def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
        # Per-interval accumulators, flushed by concrete subclasses.
        self.num_prompt_tokens: List[int] = []
        self.num_generation_tokens: List[int] = []
        # Wall-clock time of the last local flush.
        self.last_local_log = time.time()
        # Minimum seconds between local log flushes.
        self.local_interval = local_interval
        # Most recent spec-decode metrics, if any (see
        # maybe_update_spec_decode_metrics).
        self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None

    @abstractmethod
    def log(self, stats: Stats) -> None:
        """Consume one Stats snapshot."""
        raise NotImplementedError

    @abstractmethod
    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        """Record static metadata of kind ``type``."""
        raise NotImplementedError

    def maybe_update_spec_decode_metrics(self, stats: Stats):
        """Save spec decode metrics (since they are unlikely
        to be emitted at same time as log interval)."""
        if stats.spec_decode_metrics is not None:
            self.spec_decode_metrics = stats.spec_decode_metrics
|
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__init__.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import uuid
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import List, Mapping, Optional, Union, overload
|
| 7 |
+
|
| 8 |
+
from typing_extensions import deprecated
|
| 9 |
+
|
| 10 |
+
from vllm import PoolingParams
|
| 11 |
+
from vllm.inputs import PromptType
|
| 12 |
+
from vllm.lora.request import LoRARequest
|
| 13 |
+
from vllm.outputs import RequestOutput
|
| 14 |
+
from vllm.prompt_adapter.request import PromptAdapterRequest
|
| 15 |
+
from vllm.sampling_params import SamplingParams
|
| 16 |
+
from vllm.utils import deprecate_kwargs
|
| 17 |
+
|
| 18 |
+
# Sentinel payload acknowledging a successful RPC operation.
VLLM_RPC_SUCCESS_STR = "SUCCESS"

# Suffixes appended to the base IPC path to derive the per-purpose
# zmq socket paths shared by the engine and client processes.
IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class MQEngineDeadError(RuntimeError):
    """Raised when the multiprocessing engine loop is no longer running."""
    pass
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class RPCProcessRequest:
    """RPC payload asking the engine process to handle one request.

    Carries the prompt, its sampling/pooling parameters and optional
    adapter/tracing metadata across the zmq IPC boundary.
    """
    prompt: PromptType
    params: Union[SamplingParams, PoolingParams]
    request_id: str
    lora_request: Optional[LoRARequest] = None
    trace_headers: Optional[Mapping[str, str]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    priority: int = 0

    @overload
    def __init__(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    # Fixed: the deprecation message was missing its closing quote
    # ("'prompt" -> "'prompt'").
    @deprecated("'inputs' will be renamed to 'prompt'")
    def __init__(
        self,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    def __init__(
        self,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        request_id: Optional[str] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        # Accept the legacy 'inputs' keyword as an alias for 'prompt'.
        if inputs is not None:
            prompt = inputs
        assert (prompt is not None and params is not None
                and request_id is not None)

        super().__init__()

        self.prompt = prompt
        self.params = params
        self.request_id = request_id
        self.lora_request = lora_request
        self.trace_headers = trace_headers
        self.prompt_adapter_request = prompt_adapter_request
        self.priority = priority
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@dataclass
class RPCError:
    """Error payload sent from the engine back to the client.

    ``request_id`` is None for engine-level (non-request) failures;
    ``is_engine_errored`` marks whether the whole engine is dead.
    """
    request_id: Optional[str]
    is_engine_errored: bool
    exception: BaseException
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@dataclass
class RPCAbortRequest:
    """Ask the engine to abort the request with this id."""
    request_id: str
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class RPCStartupRequest(Enum):
    """Startup-time handshake messages from client to engine."""
    IS_SERVER_READY = 1
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@dataclass
class RPCStartupResponse:
    """Engine's reply to the startup handshake."""
    # Whether OpenTelemetry tracing is enabled on the engine side.
    tracing_enabled: bool
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class RPCUProfileRequest(Enum):
    """Toggle profiling on the engine process."""
    START_PROFILE = 1
    STOP_PROFILE = 2
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class RPCResetPrefixCacheRequest(Enum):
    """Ask the engine to clear its prefix cache."""
    RESET_PREFIX_CACHE = 1
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@dataclass
class RPCLoadAdapterRequest:
    """Ask the engine to load a LoRA adapter."""
    lora_request: LoRARequest
    # Defaults to a fresh UUID so each load request is individually trackable.
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@dataclass
class RPCAdapterLoadedResponse:
    """Acknowledgement that the adapter load with this id completed."""
    request_id: str
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Every request type the client may push over the input socket.
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
                      RPCUProfileRequest, RPCLoadAdapterRequest,
                      RPCResetPrefixCacheRequest]

# Every payload type the engine may push over the output socket.
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
                          RPCError]
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def ENGINE_DEAD_ERROR(
        error: Optional[BaseException] = None) -> MQEngineDeadError:
    """Build the error raised when the engine loop has stopped.

    When the original cause is known it is appended to the message so
    callers see it without needing the engine-side traceback.
    """
    prefix = ("Engine loop is not running. Inspect the stacktrace to "
              "find the original error")
    if error is None:
        return MQEngineDeadError(prefix)
    return MQEngineDeadError(f"{prefix}: {repr(error)}.")
|
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (7.71 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/client.cpython-311.pyc
ADDED
|
Binary file (32.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/__pycache__/engine.cpython-311.pyc
ADDED
|
Binary file (20.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/client.py
ADDED
|
@@ -0,0 +1,707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import copy
|
| 5 |
+
import pickle
|
| 6 |
+
from contextlib import contextmanager, suppress
|
| 7 |
+
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
|
| 8 |
+
Optional, Union, cast, overload)
|
| 9 |
+
|
| 10 |
+
import cloudpickle
|
| 11 |
+
import psutil
|
| 12 |
+
import zmq
|
| 13 |
+
import zmq.asyncio
|
| 14 |
+
from typing_extensions import deprecated
|
| 15 |
+
from zmq import Frame # type: ignore[attr-defined]
|
| 16 |
+
from zmq.asyncio import Socket
|
| 17 |
+
|
| 18 |
+
from vllm import PoolingParams
|
| 19 |
+
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
|
| 20 |
+
from vllm.core.scheduler import SchedulerOutputs
|
| 21 |
+
from vllm.engine.arg_utils import AsyncEngineArgs
|
| 22 |
+
# yapf conflicts with isort for this block
|
| 23 |
+
# yapf: disable
|
| 24 |
+
from vllm.engine.async_llm_engine import (
|
| 25 |
+
build_guided_decoding_logits_processor_async)
|
| 26 |
+
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
| 27 |
+
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
| 28 |
+
IPC_OUTPUT_EXT, RPC_REQUEST_T,
|
| 29 |
+
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
| 30 |
+
RPCAdapterLoadedResponse, RPCError,
|
| 31 |
+
RPCLoadAdapterRequest,
|
| 32 |
+
RPCProcessRequest,
|
| 33 |
+
RPCResetPrefixCacheRequest,
|
| 34 |
+
RPCStartupRequest, RPCStartupResponse,
|
| 35 |
+
RPCUProfileRequest)
|
| 36 |
+
from vllm.engine.protocol import EngineClient
|
| 37 |
+
# yapf: enable
|
| 38 |
+
from vllm.envs import VLLM_RPC_TIMEOUT
|
| 39 |
+
from vllm.inputs import PromptType
|
| 40 |
+
from vllm.inputs.preprocess import InputPreprocessor
|
| 41 |
+
from vllm.logger import init_logger
|
| 42 |
+
from vllm.lora.request import LoRARequest
|
| 43 |
+
from vllm.model_executor.layers.sampler import SamplerOutput
|
| 44 |
+
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
| 45 |
+
from vllm.prompt_adapter.request import PromptAdapterRequest
|
| 46 |
+
from vllm.sampling_params import SamplingParams
|
| 47 |
+
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
| 48 |
+
from vllm.utils import deprecate_kwargs
|
| 49 |
+
|
| 50 |
+
logger = init_logger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class MQClientClosedError(Exception):
    """Exception class raised when the client is used post-close.

    The client can be closed, which closes the ZMQ context. This normally
    happens on server shutdown. In some cases, methods like abort and
    do_log_stats will still be called and then try to open a socket, which
    causes a ZMQError and creates a huge stack trace.
    So, we throw this error such that we can suppress it.
    """
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class MQLLMEngineClient(EngineClient):
|
| 65 |
+
"""A client wrapper for MQLLMEngine that conforms to the
|
| 66 |
+
EngineClient protocol.
|
| 67 |
+
|
| 68 |
+
MQLLMEngine and MQLLMEngineClient are intended to run in separate
|
| 69 |
+
processes communicating via zeromq ipc sockets.
|
| 70 |
+
|
| 71 |
+
The entrypoint to MQLLMEngineClient is through the generate()
|
| 72 |
+
method. On generate() MQLLMEngine does three things:
|
| 73 |
+
- Creates an asyncio output queue
|
| 74 |
+
- Sends a RPCGenerateRequest to the MQLLMEngine via zmq
|
| 75 |
+
- Pulls RequestOutputs from its queue and yields them
|
| 76 |
+
|
| 77 |
+
MQLLMEngine runs two background loops:
|
| 78 |
+
- output_loop: the output loop pulls List[RequestOutput]
|
| 79 |
+
from the MQLLMEngine via zmq (each list is the output
|
| 80 |
+
of one engine_step in the LLMEngine). It then parses
|
| 81 |
+
the list and pushes individual request_outputs into
|
| 82 |
+
the corresponding output_queue such that they can be
|
| 83 |
+
consumed by the .generate() method.
|
| 84 |
+
- health_loop: the health loop queries the health socket
|
| 85 |
+
every N seconds, confirming the engine is healthy
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
def __init__(self, ipc_path: str, engine_config: VllmConfig,
             engine_pid: int):
    """Wire up zmq sockets and tokenizer state for talking to the engine.

    ``ipc_path`` is the base path; per-purpose sockets use the IPC_*_EXT
    suffixes. ``engine_pid`` lets the health loop watch the engine process.
    """
    self.context = zmq.asyncio.Context()
    # Set once the engine is known to be dead/errored; None while healthy.
    self._errored_with: Optional[BaseException] = None

    # Configs needed locally for preprocessing.
    self.model_config = engine_config.model_config
    self.decoding_config = engine_config.decoding_config

    # Tokenizer group + input preprocessor mirror the engine's setup.
    self.tokenizer = init_tokenizer_from_configs(
        model_config=self.model_config,
        scheduler_config=engine_config.scheduler_config,
        parallel_config=engine_config.parallel_config,
        lora_config=engine_config.lora_config)
    self.input_preprocessor = InputPreprocessor(self.model_config,
                                                self.tokenizer)

    # PUSH socket: sends RPCGenerateRequest to the MQLLMEngine.
    self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
    self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")

    # PULL socket: receives streams of RequestOutput from the MQLLMEngine.
    self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
    self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")

    # PULL socket: acks heartbeats from the engine.
    self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
    self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")

    # Data socket is opened on demand (see get_data_socket).
    self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

    # One asyncio queue per in-flight request id.
    self.output_queues: Dict[str, asyncio.Queue] = {}

    # Output-handling loop; started only after the MQLLMEngine is ready
    # so the client can be built in an executor for clean shutdown.
    self.output_loop: Optional[asyncio.Task] = None

    # Health-check loop; also started after the engine is ready.
    self.health_loop: Optional[asyncio.Task] = None
    self._engine_process = psutil.Process(engine_pid)
|
| 133 |
+
|
| 134 |
+
@staticmethod
|
| 135 |
+
def is_unsupported_config(engine_args: AsyncEngineArgs):
|
| 136 |
+
# Pipeline parallel not yet supported
|
| 137 |
+
return engine_args.pipeline_parallel_size > 1
|
| 138 |
+
|
| 139 |
+
@contextmanager
|
| 140 |
+
def get_data_socket(self) -> Iterator[Socket]:
|
| 141 |
+
socket = self.context.socket(zmq.constants.DEALER)
|
| 142 |
+
try:
|
| 143 |
+
socket.connect(self.data_ipc_path)
|
| 144 |
+
yield socket
|
| 145 |
+
finally:
|
| 146 |
+
socket.close(linger=0)
|
| 147 |
+
|
| 148 |
+
async def run_heartbeat_loop(self, timeout: int):
|
| 149 |
+
"""Background loop that continually checks to ensure the engine process
|
| 150 |
+
is still alive.
|
| 151 |
+
"""
|
| 152 |
+
try:
|
| 153 |
+
while True:
|
| 154 |
+
# Check if the engine process is running:
|
| 155 |
+
if not self._engine_process.is_running() or (
|
| 156 |
+
self._engine_process.status() == psutil.STATUS_ZOMBIE):
|
| 157 |
+
# NB: is_running() returns True for zombies
|
| 158 |
+
self._set_errored(
|
| 159 |
+
RuntimeError(
|
| 160 |
+
f"Engine process (pid {self._engine_process.pid}) "
|
| 161 |
+
"died."))
|
| 162 |
+
break
|
| 163 |
+
|
| 164 |
+
if await self.heartbeat_socket.poll(timeout=timeout):
|
| 165 |
+
# Heartbeat received- check the message
|
| 166 |
+
await self._check_success(
|
| 167 |
+
error_message="Heartbeat failed.",
|
| 168 |
+
socket=self.heartbeat_socket)
|
| 169 |
+
|
| 170 |
+
logger.debug("Heartbeat successful.")
|
| 171 |
+
|
| 172 |
+
except asyncio.CancelledError:
|
| 173 |
+
logger.debug("Shutting down MQLLMEngineClient check health loop.")
|
| 174 |
+
|
| 175 |
+
except psutil.NoSuchProcess:
|
| 176 |
+
self._set_errored(
|
| 177 |
+
RuntimeError(
|
| 178 |
+
f"Engine process (pid {self._engine_process.pid}) died."))
|
| 179 |
+
|
| 180 |
+
except Exception as e:
|
| 181 |
+
self._set_errored(e)
|
| 182 |
+
|
| 183 |
+
async def run_output_handler_loop(self):
|
| 184 |
+
"""Get RequestOutputs from Engine and stream to Request Queues"""
|
| 185 |
+
|
| 186 |
+
try:
|
| 187 |
+
while True:
|
| 188 |
+
# Poll, checking for ENGINE_DEAD
|
| 189 |
+
while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
|
| 190 |
+
) == 0:
|
| 191 |
+
logger.debug("Waiting for output from MQLLMEngine.")
|
| 192 |
+
|
| 193 |
+
# If errored, alert all running requests.
|
| 194 |
+
if self.errored:
|
| 195 |
+
for queue_j in tuple(self.output_queues.values()):
|
| 196 |
+
queue_j.put_nowait(
|
| 197 |
+
ENGINE_DEAD_ERROR(self._errored_with))
|
| 198 |
+
return
|
| 199 |
+
|
| 200 |
+
message: Frame = await self.output_socket.recv(copy=False)
|
| 201 |
+
request_outputs = pickle.loads(message.buffer)
|
| 202 |
+
|
| 203 |
+
is_error = isinstance(request_outputs,
|
| 204 |
+
(BaseException, RPCError))
|
| 205 |
+
if is_error:
|
| 206 |
+
if isinstance(request_outputs, RPCError):
|
| 207 |
+
rpc_error: RPCError = request_outputs
|
| 208 |
+
request_id = rpc_error.request_id
|
| 209 |
+
exception = rpc_error.exception
|
| 210 |
+
is_engine_errored = rpc_error.is_engine_errored
|
| 211 |
+
else:
|
| 212 |
+
# MPLLMEngine should always return an RPCError to
|
| 213 |
+
# the output_socket when an issue arises.
|
| 214 |
+
# If we are here, we are in a bad state and
|
| 215 |
+
# should shut down the server.
|
| 216 |
+
error: BaseException = request_outputs
|
| 217 |
+
logger.error(
|
| 218 |
+
"Received Exception %s rather than RPCError from "
|
| 219 |
+
"MPLLMEngine. This should never happen.", error)
|
| 220 |
+
request_id = None
|
| 221 |
+
exception = error
|
| 222 |
+
is_engine_errored = True
|
| 223 |
+
|
| 224 |
+
# Set to error state only on engine critical error
|
| 225 |
+
# (and record only the first one)
|
| 226 |
+
if is_engine_errored and not self._errored_with:
|
| 227 |
+
self._errored_with = exception
|
| 228 |
+
# If engine is errored, no matter the type of exception
|
| 229 |
+
# it will no longer be able to receive new requests,
|
| 230 |
+
# therefore we have to inform that the current
|
| 231 |
+
# processed requests failed as well. Send back a dead
|
| 232 |
+
# engine error give this feedback and also give a
|
| 233 |
+
# 'hint' to the server to shutdown next.
|
| 234 |
+
exception = self.dead_error
|
| 235 |
+
|
| 236 |
+
if request_id is None:
|
| 237 |
+
# If request_id is None, then the engine raised an
|
| 238 |
+
# exception for a batch, and we may not know the
|
| 239 |
+
# request that caused it, neither if it was actually
|
| 240 |
+
# caused by any of them (e.g. CUDA OOM). Therefore we
|
| 241 |
+
# broadcast the same exception for all requests.
|
| 242 |
+
for queue_i in tuple(self.output_queues.values()):
|
| 243 |
+
queue_i.put_nowait(exception)
|
| 244 |
+
else:
|
| 245 |
+
queue = self.output_queues.get(request_id)
|
| 246 |
+
if queue is not None:
|
| 247 |
+
queue.put_nowait(exception)
|
| 248 |
+
# Put each output into the appropriate queue.
|
| 249 |
+
elif isinstance(request_outputs, RPCAdapterLoadedResponse):
|
| 250 |
+
self._add_output(request_outputs)
|
| 251 |
+
else:
|
| 252 |
+
for request_output in request_outputs:
|
| 253 |
+
self._add_output(request_output)
|
| 254 |
+
|
| 255 |
+
except asyncio.CancelledError:
|
| 256 |
+
logger.debug("Shutting down MQLLMEngineClient output handler.")
|
| 257 |
+
|
| 258 |
+
def _add_output(self, request_output: Union[RequestOutput,
|
| 259 |
+
RPCAdapterLoadedResponse]):
|
| 260 |
+
queue = self.output_queues.get(request_output.request_id)
|
| 261 |
+
if queue is not None:
|
| 262 |
+
queue.put_nowait(request_output)
|
| 263 |
+
|
| 264 |
+
async def setup(self):
|
| 265 |
+
"""Setup the client before it starts sending server requests."""
|
| 266 |
+
|
| 267 |
+
# Start output_loop
|
| 268 |
+
if self.output_loop is None:
|
| 269 |
+
# only generate once to avoid multiple concurrent output_loops
|
| 270 |
+
# this will lead to race conditions and wrong orders of tokens
|
| 271 |
+
# returned by the engine
|
| 272 |
+
# setup will be called multiple times during the startup of
|
| 273 |
+
# the engine
|
| 274 |
+
self.output_loop = asyncio.create_task(
|
| 275 |
+
self.run_output_handler_loop())
|
| 276 |
+
|
| 277 |
+
with self.get_data_socket() as socket:
|
| 278 |
+
# Wait until server is ready.
|
| 279 |
+
response = await self._wait_for_server_rpc(socket)
|
| 280 |
+
|
| 281 |
+
self.tracing_flag = response.tracing_enabled
|
| 282 |
+
|
| 283 |
+
# Start health_loop.
|
| 284 |
+
if self.health_loop is None:
|
| 285 |
+
self.health_loop = asyncio.create_task(
|
| 286 |
+
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
|
| 287 |
+
|
| 288 |
+
def close(self):
|
| 289 |
+
"""Destroy the ZeroMQ Context."""
|
| 290 |
+
# Close all sockets and terminate the context.
|
| 291 |
+
self.context.destroy(linger=0)
|
| 292 |
+
|
| 293 |
+
# Cancel background tasks.
|
| 294 |
+
if self.health_loop is not None:
|
| 295 |
+
self.health_loop.cancel()
|
| 296 |
+
if self.output_loop is not None:
|
| 297 |
+
self.output_loop.cancel()
|
| 298 |
+
|
| 299 |
+
def _set_errored(self, e: BaseException):
|
| 300 |
+
logger.exception(repr(e))
|
| 301 |
+
if self._errored_with is None:
|
| 302 |
+
self._errored_with = e
|
| 303 |
+
|
| 304 |
+
@staticmethod
|
| 305 |
+
async def _send_get_data_rpc_request(request: RPCStartupRequest,
|
| 306 |
+
expected_type: Any,
|
| 307 |
+
error_message: str,
|
| 308 |
+
socket: Socket) -> Any:
|
| 309 |
+
"""Send an RPC request that is expecting data back."""
|
| 310 |
+
|
| 311 |
+
# Ping RPCServer with a request.
|
| 312 |
+
await socket.send_multipart((pickle.dumps(request), ), copy=False)
|
| 313 |
+
|
| 314 |
+
# Make sure the server responds in time.
|
| 315 |
+
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
|
| 316 |
+
raise TimeoutError("RPCServer didn't reply within "
|
| 317 |
+
f"{VLLM_RPC_TIMEOUT} ms")
|
| 318 |
+
|
| 319 |
+
# Await the data from the Server.
|
| 320 |
+
frame = await socket.recv(copy=False)
|
| 321 |
+
data = pickle.loads(frame.buffer)
|
| 322 |
+
|
| 323 |
+
if isinstance(data, BaseException):
|
| 324 |
+
raise data
|
| 325 |
+
elif not isinstance(data, expected_type):
|
| 326 |
+
raise ValueError(error_message)
|
| 327 |
+
|
| 328 |
+
return data
|
| 329 |
+
|
| 330 |
+
@staticmethod
|
| 331 |
+
async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
|
| 332 |
+
socket: Socket):
|
| 333 |
+
"""Send one-way RPC request to trigger an action."""
|
| 334 |
+
|
| 335 |
+
if socket.closed:
|
| 336 |
+
raise MQClientClosedError()
|
| 337 |
+
|
| 338 |
+
await socket.send_multipart((pickle.dumps(request), ))
|
| 339 |
+
|
| 340 |
+
async def _await_ack(self, error_message: str, socket: Socket):
|
| 341 |
+
"""Await acknowledgement that a request succeeded."""
|
| 342 |
+
|
| 343 |
+
if socket.closed:
|
| 344 |
+
raise MQClientClosedError()
|
| 345 |
+
|
| 346 |
+
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
|
| 347 |
+
raise TimeoutError("MQLLMEngine didn't reply within "
|
| 348 |
+
f"{VLLM_RPC_TIMEOUT}ms")
|
| 349 |
+
|
| 350 |
+
await self._check_success(error_message, socket)
|
| 351 |
+
|
| 352 |
+
@staticmethod
|
| 353 |
+
async def _check_success(error_message: str, socket: Socket):
|
| 354 |
+
"""Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""
|
| 355 |
+
|
| 356 |
+
if socket.closed:
|
| 357 |
+
raise MQClientClosedError()
|
| 358 |
+
|
| 359 |
+
frame = await socket.recv(copy=False)
|
| 360 |
+
response = pickle.loads(frame.buffer)
|
| 361 |
+
|
| 362 |
+
# Raise error if unsuccessful
|
| 363 |
+
if isinstance(response, BaseException):
|
| 364 |
+
raise response
|
| 365 |
+
elif (not isinstance(response, str)
|
| 366 |
+
or response != VLLM_RPC_SUCCESS_STR):
|
| 367 |
+
raise ValueError(error_message)
|
| 368 |
+
|
| 369 |
+
async def get_input_preprocessor(self) -> InputPreprocessor:
|
| 370 |
+
return self.input_preprocessor
|
| 371 |
+
|
| 372 |
+
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
|
| 373 |
+
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
|
| 374 |
+
|
| 375 |
+
async def get_decoding_config(self) -> DecodingConfig:
|
| 376 |
+
return self.decoding_config
|
| 377 |
+
|
| 378 |
+
async def get_model_config(self) -> ModelConfig:
|
| 379 |
+
return self.model_config
|
| 380 |
+
|
| 381 |
+
async def is_tracing_enabled(self) -> bool:
|
| 382 |
+
return self.tracing_flag
|
| 383 |
+
|
| 384 |
+
async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
|
| 385 |
+
"""Wait for the RPCServer to start up."""
|
| 386 |
+
|
| 387 |
+
return await self._send_get_data_rpc_request(
|
| 388 |
+
request=RPCStartupRequest.IS_SERVER_READY,
|
| 389 |
+
expected_type=RPCStartupResponse,
|
| 390 |
+
error_message="Unable to start RPC Server",
|
| 391 |
+
socket=socket)
|
| 392 |
+
|
| 393 |
+
async def abort(self, request_id: str):
|
| 394 |
+
"""Send an ABORT_REQUEST signal to the RPC Server"""
|
| 395 |
+
|
| 396 |
+
with suppress(MQClientClosedError):
|
| 397 |
+
await self._send_one_way_rpc_request(
|
| 398 |
+
request=RPCAbortRequest(request_id), socket=self.input_socket)
|
| 399 |
+
|
| 400 |
+
async def do_log_stats(
|
| 401 |
+
self,
|
| 402 |
+
scheduler_outputs: Optional[SchedulerOutputs] = None,
|
| 403 |
+
model_output: Optional[List[SamplerOutput]] = None,
|
| 404 |
+
) -> None:
|
| 405 |
+
"""
|
| 406 |
+
Ignore do_log_stats (handled on MQLLMEngine polling)
|
| 407 |
+
"""
|
| 408 |
+
pass
|
| 409 |
+
|
| 410 |
+
async def check_health(self):
|
| 411 |
+
"""
|
| 412 |
+
The check health loop probes the health status of the
|
| 413 |
+
Engine's health every N seconds and sets _errored_with
|
| 414 |
+
if the engine is unhealthy.
|
| 415 |
+
"""
|
| 416 |
+
if self._errored_with is not None:
|
| 417 |
+
raise self._errored_with
|
| 418 |
+
|
| 419 |
+
@property
|
| 420 |
+
def is_running(self) -> bool:
|
| 421 |
+
return not self.errored
|
| 422 |
+
|
| 423 |
+
@property
|
| 424 |
+
def is_stopped(self) -> bool:
|
| 425 |
+
return self.errored
|
| 426 |
+
|
| 427 |
+
@property
|
| 428 |
+
def errored(self) -> bool:
|
| 429 |
+
return self._errored_with is not None
|
| 430 |
+
|
| 431 |
+
@property
|
| 432 |
+
def dead_error(self) -> BaseException:
|
| 433 |
+
return ENGINE_DEAD_ERROR(self._errored_with)
|
| 434 |
+
|
| 435 |
+
@overload
|
| 436 |
+
def generate(
|
| 437 |
+
self,
|
| 438 |
+
prompt: PromptType,
|
| 439 |
+
sampling_params: SamplingParams,
|
| 440 |
+
request_id: str,
|
| 441 |
+
lora_request: Optional[LoRARequest] = None,
|
| 442 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 443 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 444 |
+
priority: int = 0,
|
| 445 |
+
) -> AsyncGenerator[RequestOutput, None]:
|
| 446 |
+
...
|
| 447 |
+
|
| 448 |
+
@overload
|
| 449 |
+
@deprecated("'inputs' will be renamed to 'prompt")
|
| 450 |
+
def generate(
|
| 451 |
+
self,
|
| 452 |
+
*,
|
| 453 |
+
inputs: PromptType,
|
| 454 |
+
sampling_params: SamplingParams,
|
| 455 |
+
request_id: str,
|
| 456 |
+
lora_request: Optional[LoRARequest] = None,
|
| 457 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 458 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 459 |
+
priority: int = 0,
|
| 460 |
+
) -> AsyncGenerator[RequestOutput, None]:
|
| 461 |
+
...
|
| 462 |
+
|
| 463 |
+
@deprecate_kwargs(
|
| 464 |
+
"inputs",
|
| 465 |
+
additional_message="Please use the 'prompt' parameter instead.",
|
| 466 |
+
)
|
| 467 |
+
def generate(
|
| 468 |
+
self,
|
| 469 |
+
prompt: Optional[PromptType] = None,
|
| 470 |
+
sampling_params: Optional[SamplingParams] = None,
|
| 471 |
+
request_id: Optional[str] = None,
|
| 472 |
+
lora_request: Optional[LoRARequest] = None,
|
| 473 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 474 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 475 |
+
priority: int = 0,
|
| 476 |
+
*,
|
| 477 |
+
inputs: Optional[PromptType] = None # DEPRECATED
|
| 478 |
+
) -> AsyncGenerator[RequestOutput, None]:
|
| 479 |
+
"""Generate outputs for a request.
|
| 480 |
+
|
| 481 |
+
Generate outputs for a request. This method is a coroutine. It adds the
|
| 482 |
+
request into the waiting queue of the LLMEngine and streams the outputs
|
| 483 |
+
from the LLMEngine to the caller.
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
|
| 487 |
+
for more details about the format of each input.
|
| 488 |
+
sampling_params: The sampling parameters of the request.
|
| 489 |
+
request_id: The unique id of the request.
|
| 490 |
+
lora_request: LoRA request to use for generation, if any.
|
| 491 |
+
trace_headers: OpenTelemetry trace headers.
|
| 492 |
+
prompt_adapter_request: Prompt Adapter request to use
|
| 493 |
+
for generation, if any.
|
| 494 |
+
priority: Priority of the request (lower means earlier handling).
|
| 495 |
+
Any priority other than 0 will lead to an error if the
|
| 496 |
+
scheduling policy is not "priority".
|
| 497 |
+
"""
|
| 498 |
+
if inputs is not None:
|
| 499 |
+
prompt = inputs
|
| 500 |
+
assert (prompt is not None and sampling_params is not None
|
| 501 |
+
and request_id is not None)
|
| 502 |
+
|
| 503 |
+
return self._process_request(prompt, sampling_params, request_id,
|
| 504 |
+
lora_request, trace_headers,
|
| 505 |
+
prompt_adapter_request, priority)
|
| 506 |
+
|
| 507 |
+
@overload
|
| 508 |
+
def encode(
|
| 509 |
+
self,
|
| 510 |
+
prompt: PromptType,
|
| 511 |
+
pooling_params: PoolingParams,
|
| 512 |
+
request_id: str,
|
| 513 |
+
lora_request: Optional[LoRARequest] = None,
|
| 514 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 515 |
+
priority: int = 0,
|
| 516 |
+
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
| 517 |
+
...
|
| 518 |
+
|
| 519 |
+
@overload
|
| 520 |
+
@deprecated("'inputs' will be renamed to 'prompt")
|
| 521 |
+
def encode(
|
| 522 |
+
self,
|
| 523 |
+
*,
|
| 524 |
+
inputs: PromptType,
|
| 525 |
+
pooling_params: PoolingParams,
|
| 526 |
+
request_id: str,
|
| 527 |
+
lora_request: Optional[LoRARequest] = None,
|
| 528 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 529 |
+
priority: int = 0,
|
| 530 |
+
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
| 531 |
+
...
|
| 532 |
+
|
| 533 |
+
@deprecate_kwargs(
|
| 534 |
+
"inputs",
|
| 535 |
+
additional_message="Please use the 'prompt' parameter instead.",
|
| 536 |
+
)
|
| 537 |
+
def encode(
|
| 538 |
+
self,
|
| 539 |
+
prompt: Optional[PromptType] = None,
|
| 540 |
+
pooling_params: Optional[PoolingParams] = None,
|
| 541 |
+
request_id: Optional[str] = None,
|
| 542 |
+
lora_request: Optional[LoRARequest] = None,
|
| 543 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 544 |
+
priority: int = 0,
|
| 545 |
+
*,
|
| 546 |
+
inputs: Optional[PromptType] = None # DEPRECATED
|
| 547 |
+
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
| 548 |
+
"""Generate outputs for a request from a pooling model.
|
| 549 |
+
|
| 550 |
+
Generate outputs for a request. This method is a coroutine. It adds the
|
| 551 |
+
request into the waiting queue of the LLMEngine and streams the outputs
|
| 552 |
+
from the LLMEngine to the caller.
|
| 553 |
+
|
| 554 |
+
Args:
|
| 555 |
+
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
|
| 556 |
+
for more details about the format of each input.
|
| 557 |
+
pooling_params: The pooling parameters of the request.
|
| 558 |
+
request_id: The unique id of the request.
|
| 559 |
+
lora_request: LoRA request to use for generation, if any.
|
| 560 |
+
trace_headers: OpenTelemetry trace headers.
|
| 561 |
+
|
| 562 |
+
Yields:
|
| 563 |
+
The output `PoolingRequestOutput` objects from the LLMEngine
|
| 564 |
+
for the request.
|
| 565 |
+
"""
|
| 566 |
+
if inputs is not None:
|
| 567 |
+
prompt = inputs
|
| 568 |
+
assert (prompt is not None and pooling_params is not None
|
| 569 |
+
and request_id is not None)
|
| 570 |
+
|
| 571 |
+
return cast(
|
| 572 |
+
AsyncGenerator[PoolingRequestOutput, None],
|
| 573 |
+
self._process_request(prompt,
|
| 574 |
+
pooling_params,
|
| 575 |
+
request_id,
|
| 576 |
+
lora_request,
|
| 577 |
+
trace_headers,
|
| 578 |
+
priority=priority))
|
| 579 |
+
|
| 580 |
+
async def _process_request(
|
| 581 |
+
self,
|
| 582 |
+
prompt: PromptType,
|
| 583 |
+
params: Union[SamplingParams, PoolingParams],
|
| 584 |
+
request_id: str,
|
| 585 |
+
lora_request: Optional[LoRARequest] = None,
|
| 586 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 587 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 588 |
+
priority: int = 0,
|
| 589 |
+
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
|
| 590 |
+
PoolingRequestOutput, None]]:
|
| 591 |
+
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
|
| 592 |
+
|
| 593 |
+
# If already dead, error out.
|
| 594 |
+
if self._errored_with is not None:
|
| 595 |
+
raise ENGINE_DEAD_ERROR(self._errored_with)
|
| 596 |
+
|
| 597 |
+
# Ensure the request id is unique among running requests
|
| 598 |
+
if request_id in self.output_queues:
|
| 599 |
+
raise ValueError(f"Request {request_id} already exists")
|
| 600 |
+
|
| 601 |
+
# Constructing guided decoding logits processors is expensive, so we do
|
| 602 |
+
# it here to avoid contending with cpu resources and the GIL on the
|
| 603 |
+
# backend process.
|
| 604 |
+
if isinstance(params, SamplingParams) and \
|
| 605 |
+
params.guided_decoding is not None:
|
| 606 |
+
params = await \
|
| 607 |
+
build_guided_decoding_logits_processor_async(
|
| 608 |
+
sampling_params=params,
|
| 609 |
+
tokenizer=await self.get_tokenizer(lora_request),
|
| 610 |
+
default_guided_backend=(self.decoding_config.guided_decoding_backend
|
| 611 |
+
if self.decoding_config
|
| 612 |
+
else DecodingConfig.guided_decoding_backend),
|
| 613 |
+
model_config=self.model_config
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
# 1) Create output queue for this requests.
|
| 617 |
+
queue: asyncio.Queue[Union[RequestOutput,
|
| 618 |
+
BaseException]] = asyncio.Queue()
|
| 619 |
+
self.output_queues[request_id] = queue
|
| 620 |
+
|
| 621 |
+
try:
|
| 622 |
+
# 2) Detach logits processors so that they can be pickled
|
| 623 |
+
# separately (may require cloudpickle which is slower)
|
| 624 |
+
if isinstance(params, SamplingParams) and params.logits_processors:
|
| 625 |
+
# Defensive shallow copy
|
| 626 |
+
params = copy.copy(params)
|
| 627 |
+
logits_processors = params.logits_processors
|
| 628 |
+
params.logits_processors = None
|
| 629 |
+
lp_bytes = cloudpickle.dumps(logits_processors)
|
| 630 |
+
else:
|
| 631 |
+
lp_bytes = None
|
| 632 |
+
|
| 633 |
+
request_bytes = pickle.dumps(
|
| 634 |
+
RPCProcessRequest(
|
| 635 |
+
prompt=prompt,
|
| 636 |
+
params=params,
|
| 637 |
+
request_id=request_id,
|
| 638 |
+
lora_request=lora_request,
|
| 639 |
+
trace_headers=trace_headers,
|
| 640 |
+
prompt_adapter_request=prompt_adapter_request,
|
| 641 |
+
priority=priority,
|
| 642 |
+
))
|
| 643 |
+
|
| 644 |
+
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
|
| 645 |
+
parts = (request_bytes,
|
| 646 |
+
lp_bytes) if lp_bytes else (request_bytes, )
|
| 647 |
+
await self.input_socket.send_multipart(parts, copy=False)
|
| 648 |
+
|
| 649 |
+
# 4) Stream the RequestOutputs from the output queue. Note
|
| 650 |
+
# that the output_loop pushes RequestOutput objects to this
|
| 651 |
+
# queue after pulling them from the zmq socket.
|
| 652 |
+
finished = False
|
| 653 |
+
try:
|
| 654 |
+
while not finished:
|
| 655 |
+
request_output = await queue.get()
|
| 656 |
+
|
| 657 |
+
if isinstance(request_output, BaseException):
|
| 658 |
+
raise request_output
|
| 659 |
+
|
| 660 |
+
finished = request_output.finished
|
| 661 |
+
yield request_output
|
| 662 |
+
finally:
|
| 663 |
+
# Request was canceled by the client.
|
| 664 |
+
if not finished and not self.errored:
|
| 665 |
+
await self.abort(request_id)
|
| 666 |
+
finally:
|
| 667 |
+
self.output_queues.pop(request_id)
|
| 668 |
+
|
| 669 |
+
async def start_profile(self) -> None:
|
| 670 |
+
"""Start profiling the engine"""
|
| 671 |
+
|
| 672 |
+
await self._send_one_way_rpc_request(
|
| 673 |
+
request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket)
|
| 674 |
+
|
| 675 |
+
async def stop_profile(self) -> None:
|
| 676 |
+
"""Stop profiling the engine"""
|
| 677 |
+
|
| 678 |
+
await self._send_one_way_rpc_request(
|
| 679 |
+
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
|
| 680 |
+
|
| 681 |
+
async def reset_prefix_cache(self) -> None:
|
| 682 |
+
"""Reset the prefix cache"""
|
| 683 |
+
|
| 684 |
+
await self._send_one_way_rpc_request(
|
| 685 |
+
request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
|
| 686 |
+
socket=self.input_socket)
|
| 687 |
+
|
| 688 |
+
async def add_lora(self, lora_request: LoRARequest) -> None:
|
| 689 |
+
"""Load a new LoRA adapter into the engine for future requests."""
|
| 690 |
+
# Uses the same I/O as generate requests
|
| 691 |
+
request = RPCLoadAdapterRequest(lora_request)
|
| 692 |
+
|
| 693 |
+
# Create output queue for this requests.
|
| 694 |
+
queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
|
| 695 |
+
self.output_queues[request.request_id] = queue
|
| 696 |
+
|
| 697 |
+
# Send the request
|
| 698 |
+
request_bytes = pickle.dumps(request)
|
| 699 |
+
await self.input_socket.send_multipart((request_bytes, ), copy=False)
|
| 700 |
+
|
| 701 |
+
# Wait for the response
|
| 702 |
+
request_output = await queue.get()
|
| 703 |
+
self.output_queues.pop(request.request_id)
|
| 704 |
+
|
| 705 |
+
# Raise on error, otherwise happily return None
|
| 706 |
+
if isinstance(request_output, BaseException):
|
| 707 |
+
raise request_output
|
.venv/lib/python3.11/site-packages/vllm/engine/multiprocessing/engine.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import pickle
|
| 4 |
+
import signal
|
| 5 |
+
from contextlib import contextmanager
|
| 6 |
+
from typing import Iterator, List, Optional, Union
|
| 7 |
+
|
| 8 |
+
import cloudpickle
|
| 9 |
+
import zmq
|
| 10 |
+
|
| 11 |
+
from vllm import AsyncEngineArgs, SamplingParams
|
| 12 |
+
from vllm.engine.llm_engine import LLMEngine
|
| 13 |
+
# yapf conflicts with isort for this block
|
| 14 |
+
# yapf: disable
|
| 15 |
+
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
| 16 |
+
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
| 17 |
+
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
|
| 18 |
+
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
| 19 |
+
RPCAdapterLoadedResponse, RPCError,
|
| 20 |
+
RPCLoadAdapterRequest,
|
| 21 |
+
RPCProcessRequest,
|
| 22 |
+
RPCResetPrefixCacheRequest,
|
| 23 |
+
RPCStartupRequest, RPCStartupResponse,
|
| 24 |
+
RPCUProfileRequest)
|
| 25 |
+
# yapf: enable
|
| 26 |
+
from vllm.logger import init_logger
|
| 27 |
+
from vllm.outputs import RequestOutput
|
| 28 |
+
from vllm.usage.usage_lib import UsageContext
|
| 29 |
+
|
| 30 |
+
logger = init_logger(__name__)
|
| 31 |
+
|
| 32 |
+
POLLING_TIMEOUT_MS = 10000
|
| 33 |
+
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class MQLLMEngine:
|
| 37 |
+
"""A multiprocessing wrapper for :class:`LLMEngine`.
|
| 38 |
+
|
| 39 |
+
This class is used to wrap the :class:`LLMEngine` class to enable use
|
| 40 |
+
in concurrnet manner. It runs a background loop and uses zeromq to
|
| 41 |
+
receive new requests and stream outputs incrementally via ipc.
|
| 42 |
+
|
| 43 |
+
The :class:`LLMEngine` generate or encode process is kicked off when a new
|
| 44 |
+
RPCProcessRequest is received by the input_socket.
|
| 45 |
+
|
| 46 |
+
The self.engine_loop checks the input_socket for new requests,
|
| 47 |
+
adds them to the LLMEngine if there are any, calls the internal
|
| 48 |
+
:class:`LLMEngine.step()`, and sends the RequestOutputs back over
|
| 49 |
+
the output_socket.
|
| 50 |
+
|
| 51 |
+
If use_async_sockets is set, the logic associated with reading new
|
| 52 |
+
requests from the socket and sending data to the socket is passed
|
| 53 |
+
as a callback to the llm_engine, which calls the logic asynchronously
|
| 54 |
+
such that the IPC can be overlapped with the GPU.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
ipc_path: Base path for zeromq interprocess messaging
|
| 58 |
+
use_async_sockets: Whether to make send/recv async with GPU
|
| 59 |
+
log_requests: Whether to log the requests.
|
| 60 |
+
*args: Arguments for :class:`LLMEngine`.
|
| 61 |
+
**kwargs: Arguments for :class:`LLMEngine`.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self,
|
| 65 |
+
ipc_path: str,
|
| 66 |
+
use_async_sockets: bool,
|
| 67 |
+
*args,
|
| 68 |
+
log_requests: bool = True,
|
| 69 |
+
**kwargs) -> None:
|
| 70 |
+
# For MQLLMEngine, we can use cached outputs, since each new request
|
| 71 |
+
# output is immediately pickled and send over the socket, which frees
|
| 72 |
+
# the python object to be reused again.
|
| 73 |
+
kwargs['use_cached_outputs'] = True
|
| 74 |
+
|
| 75 |
+
self.engine = LLMEngine(*args, **kwargs)
|
| 76 |
+
self.log_requests = log_requests
|
| 77 |
+
|
| 78 |
+
self.use_async_sockets = use_async_sockets
|
| 79 |
+
if self.use_async_sockets:
|
| 80 |
+
self.engine.process_request_outputs_callback = \
|
| 81 |
+
self._async_socket_engine_callback
|
| 82 |
+
|
| 83 |
+
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
| 84 |
+
|
| 85 |
+
# Receive input from the client.
|
| 86 |
+
self.input_socket = self.ctx.socket(zmq.constants.PULL)
|
| 87 |
+
self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")
|
| 88 |
+
|
| 89 |
+
# Send output stream back to client.
|
| 90 |
+
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
|
| 91 |
+
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
|
| 92 |
+
|
| 93 |
+
# Send heartbeats back to client.
|
| 94 |
+
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
|
| 95 |
+
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
|
| 96 |
+
|
| 97 |
+
# IPC path for the data socket.
|
| 98 |
+
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
|
| 99 |
+
|
| 100 |
+
# Error state.
|
| 101 |
+
self._errored_with: Optional[BaseException] = None
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def dead_error(self) -> BaseException:
|
| 105 |
+
if self._errored_with is not None:
|
| 106 |
+
return ENGINE_DEAD_ERROR(self._errored_with)
|
| 107 |
+
else:
|
| 108 |
+
return ENGINE_DEAD_ERROR()
|
| 109 |
+
|
| 110 |
+
@classmethod
|
| 111 |
+
def from_engine_args(cls, engine_args: AsyncEngineArgs,
|
| 112 |
+
usage_context: UsageContext, ipc_path: str):
|
| 113 |
+
"""Creates an MQLLMEngine from the engine arguments."""
|
| 114 |
+
# Setup plugins for each process
|
| 115 |
+
from vllm.plugins import load_general_plugins
|
| 116 |
+
load_general_plugins()
|
| 117 |
+
|
| 118 |
+
engine_config = engine_args.create_engine_config(usage_context)
|
| 119 |
+
executor_class = LLMEngine._get_executor_cls(engine_config)
|
| 120 |
+
|
| 121 |
+
use_async_sockets = engine_config.model_config.use_async_output_proc
|
| 122 |
+
|
| 123 |
+
return cls(ipc_path=ipc_path,
|
| 124 |
+
use_async_sockets=use_async_sockets,
|
| 125 |
+
vllm_config=engine_config,
|
| 126 |
+
executor_class=executor_class,
|
| 127 |
+
log_requests=not engine_args.disable_log_requests,
|
| 128 |
+
log_stats=not engine_args.disable_log_stats,
|
| 129 |
+
usage_context=usage_context)
|
| 130 |
+
|
| 131 |
+
def start(self):
|
| 132 |
+
try:
|
| 133 |
+
try:
|
| 134 |
+
logger.debug("Starting Startup Loop.")
|
| 135 |
+
self.run_startup_loop()
|
| 136 |
+
logger.debug("Starting Engine Loop.")
|
| 137 |
+
self.run_engine_loop()
|
| 138 |
+
except Exception as e:
|
| 139 |
+
logger.exception(repr(e))
|
| 140 |
+
except KeyboardInterrupt:
|
| 141 |
+
logger.debug("Shutting down MQLLMEngine.")
|
| 142 |
+
finally:
|
| 143 |
+
logger.debug("MQLLMEngine is shut down.")
|
| 144 |
+
self.cleanup()
|
| 145 |
+
|
| 146 |
+
def cleanup(self):
    """Cleanup zeromq state on shutdown."""
    # Closes all sockets and destroys context.
    self.ctx.destroy(linger=0)
    # Drop the engine reference so executor/GPU resources can be freed.
    del self.engine
|
| 151 |
+
|
| 152 |
+
@contextmanager
def make_data_socket(
        self) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
    """Yield a ROUTER socket bound to the data IPC path.

    The socket is always closed (linger=0) on exit, even on error.
    """
    socket = self.ctx.socket(zmq.constants.ROUTER)
    try:
        socket.bind(self.data_ipc_path)
        yield socket
    finally:
        socket.close(linger=0)
|
| 161 |
+
|
| 162 |
+
def run_startup_loop(self) -> None:
    """Startup loop for sending data from Engine -> Client.

    Waits for the client's readiness query on the data socket and
    replies with either an RPCStartupResponse or the raised exception.
    """

    with self.make_data_socket() as socket:
        response: Union[RPCStartupResponse, BaseException]
        try:
            identity, message = socket.recv_multipart(copy=False)
            request: RPCStartupRequest = pickle.loads(message.buffer)

            # Handle the query from the Client.
            if request == RPCStartupRequest.IS_SERVER_READY:
                tracing_enabled = self.engine.is_tracing_enabled()
                response = RPCStartupResponse(
                    tracing_enabled=tracing_enabled)

        except Exception as e:
            # NOTE(review): if recv_multipart itself raised, `identity`
            # is unbound and the send below fails with NameError — but
            # there is no identity to reply to in that case anyway.
            response = e

        # ROUTER sockets need the client identity frame prepended.
        socket.send_multipart((identity, pickle.dumps(response)),
                              copy=False)
|
| 182 |
+
|
| 183 |
+
def run_engine_loop(self):
    """Core busy loop of the LLMEngine.

    Alternates between draining client input and stepping the engine.
    While idle, polls the input socket and emits heartbeat/health
    messages so the client can detect a hung engine.
    """

    while True:
        if not self.engine.has_unfinished_requests():
            # Poll until there is work to do.
            while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                # When there's no work, check on engine health and send
                # health status back to client
                self._health_check()
                self.engine.do_log_stats()
                logger.debug("Waiting for new requests in engine loop.")

        # Handle any input from the client.
        self.handle_new_input()

        # Engine step.
        request_outputs = self.engine_step()

        # Send request outputs (if async, done in engine_step callback).
        if not self.use_async_sockets:
            self._send_outputs(request_outputs)
|
| 205 |
+
|
| 206 |
+
def engine_step(self) -> List[RequestOutput]:
    """Engine step wrapper with error handling.

    Runs one scheduling/execution step. SystemExit propagates untouched;
    any other failure is recorded as the engine's fatal error, reported
    to the client via an RPCError, and re-raised.
    """
    try:
        return self.engine.step()
    except SystemExit:
        raise
    except BaseException as e:
        self._set_errored(e)
        rpc_err = RPCError(request_id=None,
                           is_engine_errored=True,
                           exception=e)
        self._send_outputs(rpc_err)
        # Bare `raise` preserves the active exception and its traceback
        # without appending a redundant re-raise frame (vs `raise e`).
        raise
|
| 219 |
+
|
| 220 |
+
def handle_new_input(self):
    """Handle new input from the socket.

    Drains every message currently queued on the input socket and
    dispatches it by RPC request type. Any failure here is treated as
    fatal: the engine is marked errored, the client is notified via the
    heartbeat socket, and the exception propagates.
    """
    try:
        # Non-blocking poll: process everything already queued, then return.
        while self.input_socket.poll(timeout=0) != 0:
            frames = self.input_socket.recv_multipart(copy=False)
            request = pickle.loads(frames[0].buffer)

            if isinstance(request, RPCProcessRequest):
                if len(frames) > 1:
                    # Use cloudpickle for logits processors
                    # (plain pickle cannot serialize arbitrary callables).
                    assert isinstance(request.params, SamplingParams)
                    lprocs = cloudpickle.loads(frames[1].buffer)
                    request.params.logits_processors = lprocs
                self._handle_process_request(request)
            elif isinstance(request, RPCAbortRequest):
                self._handle_abort_request(request)
            elif isinstance(request, RPCUProfileRequest):
                if request == RPCUProfileRequest.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
            elif isinstance(request, RPCLoadAdapterRequest):
                self._handle_load_adapter_request(request)
            elif isinstance(request, RPCResetPrefixCacheRequest):
                self.reset_prefix_cache()
            else:
                raise ValueError("Unknown RPCRequest Type: "
                                 f"{type(request)}")

    except Exception as e:
        self._set_errored(e)
        self._send_unhealthy(e)
        raise e
|
| 253 |
+
|
| 254 |
+
def _handle_process_request(self, request: RPCProcessRequest):
    """Handle RPCProcessRequest by adding it to the LLMEngine."""
    request_id = request.request_id

    if self._errored_with is not None:
        # NOTE(review): replies with a dead-engine error but then still
        # falls through to add_request below — confirm intentional.
        rpc_err = RPCError(request_id=request_id,
                           is_engine_errored=True,
                           exception=ENGINE_DEAD_ERROR(self._errored_with))
        self._send_outputs(rpc_err)

    try:
        self.engine.add_request(
            request_id=request_id,
            prompt=request.prompt,
            params=request.params,
            lora_request=request.lora_request,
            trace_headers=request.trace_headers,
            prompt_adapter_request=request.prompt_adapter_request,
            priority=request.priority)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    except Exception as e:
        # We do not set self._errored = True here, since the error
        # is due to an issue adding this request to the engine,
        # rather than an issue with the engine itself.
        is_errored = self._errored_with is not None
        rpc_err = RPCError(request_id=request_id,
                           is_engine_errored=is_errored,
                           exception=e)
        self._send_outputs(rpc_err)

        # Remove request from the engine.
        self.engine.abort_request(request_id)
|
| 289 |
+
|
| 290 |
+
def _handle_abort_request(self, request: RPCAbortRequest):
    """Abort the identified request(s) in the engine and optionally log."""
    self.engine.abort_request(request.request_id)
    if self.log_requests:
        logger.info("Aborted request %s.", request.request_id)
|
| 294 |
+
|
| 295 |
+
def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
    """Load a LoRA adapter into the engine and report the outcome.

    Failure is reported as a non-fatal RPCError (the engine itself is
    still healthy); success is acknowledged with RPCAdapterLoadedResponse.
    """
    try:
        self.engine.add_lora(request.lora_request)
    except BaseException as e:
        # Send back an error if the adapter fails to load
        rpc_err = RPCError(request_id=request.request_id,
                           is_engine_errored=False,
                           exception=e)
        self._send_outputs(rpc_err)
        return
    # Otherwise, send back the successful load message
    self._send_outputs(
        RPCAdapterLoadedResponse(request_id=request.request_id))
|
| 308 |
+
|
| 309 |
+
def _health_check(self):
    """Check engine health and report it on the heartbeat socket."""
    # If the engine has already errored, report the recorded failure and
    # stop: without the early return, a subsequent successful
    # check_health() would immediately send HEALTHY after UNHEALTHY,
    # giving the client contradictory status for a dead engine.
    if self._errored_with is not None:
        self._send_unhealthy(self._errored_with)
        return
    try:
        self.engine.check_health()
        self._send_healthy()
    except Exception as e:
        self._set_errored(e)
        self._send_unhealthy(e)
|
| 319 |
+
|
| 320 |
+
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
    """Send outputs back to the engine client. These can be:
    - Exceptions
    - A list of generation outputs
    - A response from loading a lora adapter

    Empty/falsy outputs are silently skipped.
    """
    if outputs:
        try:
            from ray.exceptions import RayTaskError

            # RayTaskError might not be picklable here. We need to unpack
            # the underlying exception as the real exception in the output.
            if (isinstance(outputs, RPCError)
                    and isinstance(outputs.exception, RayTaskError)):
                outputs.exception = outputs.exception.cause
        except ImportError:
            # Ray is optional; nothing to unwrap without it.
            pass

        output_bytes = pickle.dumps(outputs)
        self.output_socket.send_multipart((output_bytes, ), copy=False)
|
| 340 |
+
|
| 341 |
+
def _send_healthy(self):
    """Send HEALTHY message to RPCClient."""
    # Skip silently if the socket was already closed during shutdown.
    if not self.heartbeat_socket.closed:
        self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
|
| 345 |
+
|
| 346 |
+
def _send_unhealthy(self, error: BaseException):
    """Send UNHEALTHY message (the pickled error) to RPCClient."""
    # Skip silently if the socket was already closed during shutdown.
    if not self.heartbeat_socket.closed:
        error_bytes = pickle.dumps(error)
        self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
|
| 351 |
+
|
| 352 |
+
def _async_socket_engine_callback(self,
                                  request_outputs: REQUEST_OUTPUTS_T):
    """Callback used by engine to make socket handling async with GPU."""
    # Runs while the forward pass is in flight: push outputs out and
    # pull any newly arrived client requests in.
    self._send_outputs(request_outputs)
    self.handle_new_input()
|
| 357 |
+
|
| 358 |
+
def _set_errored(self, e: BaseException):
    """Log and set errored status if this is the first issue."""
    # First error wins: later failures are usually cascades, so keep
    # the root cause.
    if self._errored_with is None:
        self._errored_with = e
|
| 362 |
+
|
| 363 |
+
def start_profile(self) -> None:
    """Delegate START_PROFILE to the wrapped LLMEngine."""
    self.engine.start_profile()
|
| 365 |
+
|
| 366 |
+
def stop_profile(self) -> None:
    """Delegate STOP_PROFILE to the wrapped LLMEngine."""
    self.engine.stop_profile()
|
| 368 |
+
|
| 369 |
+
def reset_prefix_cache(self) -> bool:
    """Reset the engine's prefix cache; returns the engine's success flag."""
    return self.engine.reset_prefix_cache()
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def signal_handler(*_) -> None:
    """SIGTERM hook: convert the signal into KeyboardInterrupt so the
    engine's start() loop unwinds through its normal shutdown path."""
    shutdown_exc = KeyboardInterrupt("MQLLMEngine terminated")
    raise shutdown_exc
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
                  ipc_path: str, engine_alive):
    """Entry point for the engine background process.

    Builds the MQLLMEngine, installs the SIGTERM handler, and runs it.
    On any failure the shared `engine_alive` flag is cleared before the
    exception propagates so the parent process can detect the death.
    """
    try:
        engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
                                              usage_context=usage_context,
                                              ipc_path=ipc_path)

        signal.signal(signal.SIGTERM, signal_handler)

        engine.start()

    except BaseException:
        # Use a fixed message: logger.exception(e) passed the exception
        # object as the format string, which breaks (logging errors) when
        # str(e) contains '%'. The traceback is attached automatically.
        logger.exception("MQLLMEngine background process failed")
        engine_alive.value = False
        # Bare raise preserves the original traceback.
        raise
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (201 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/interfaces.cpython-311.pyc
ADDED
|
Binary file (3.78 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/multi_step.cpython-311.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/single_step.cpython-311.pyc
ADDED
|
Binary file (7.06 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/stop_checker.cpython-311.pyc
ADDED
|
Binary file (5.06 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/__pycache__/util.cpython-311.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/interfaces.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Callable, List
|
| 5 |
+
|
| 6 |
+
from vllm.config import SchedulerConfig
|
| 7 |
+
from vllm.core.scheduler import Scheduler
|
| 8 |
+
from vllm.engine.output_processor.stop_checker import StopChecker
|
| 9 |
+
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
|
| 10 |
+
from vllm.transformers_utils.detokenizer import Detokenizer
|
| 11 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 12 |
+
from vllm.utils import Counter
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SequenceGroupOutputProcessor(ABC):
    """Interface for logic that processes new token ids in sequence groups,
    managing detokenization, stop checking, and freeing/forking sequences with
    the scheduler.

    This is highly coupled with the LLMEngine and should be seen as an extension
    of it. The logic is separated to simplify the LLMEngine class and allow
    separate implementations for single-step decoding (which supports beam
    search sequence forking) and multi-step decoding (which does not support
    beam search, but does support speculative decoding).
    """

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: "StopChecker",
    ):
        """Create an output processor.

        This returns a single-step output processor if num_lookahead_slots is
        zero, else returns a multi-step output processor.
        """
        if scheduler_config.num_lookahead_slots == 0:
            # Importing here to avoid cycle.
            from vllm.engine.output_processor.single_step import (
                SingleStepOutputProcessor)
            return SingleStepOutputProcessor(scheduler_config, detokenizer,
                                             scheduler, seq_counter,
                                             stop_checker)
        else:
            # Importing here to avoid cycle.
            from vllm.engine.output_processor.multi_step import (
                MultiStepOutputProcessor)
            # Note: the multi-step processor does not take scheduler_config.
            return MultiStepOutputProcessor(
                detokenizer,
                scheduler,
                seq_counter,
                get_tokenizer_for_seq,
                stop_checker,
            )

    @abstractmethod
    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool) -> None:
        """Process new token ids for the sequence group. Handles logic such as
        detokenization, stop checking, and freeing/forking sequences in the
        scheduler.
        """
        pass

    @abstractmethod
    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Update prompt logprobs received from outputs to seq_group."""
        pass
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/multi_step.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
from typing import Callable, List, cast
|
| 5 |
+
|
| 6 |
+
from vllm.core.scheduler import Scheduler
|
| 7 |
+
from vllm.engine.output_processor.interfaces import (
|
| 8 |
+
SequenceGroupOutputProcessor)
|
| 9 |
+
from vllm.engine.output_processor.single_step import (
|
| 10 |
+
single_step_process_prompt_logprob)
|
| 11 |
+
from vllm.engine.output_processor.stop_checker import StopChecker
|
| 12 |
+
from vllm.logger import init_logger
|
| 13 |
+
from vllm.sampling_params import SamplingParams
|
| 14 |
+
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
|
| 15 |
+
CompletionSequenceGroupOutput, Sequence,
|
| 16 |
+
SequenceGroup, SequenceGroupOutput, SequenceOutput,
|
| 17 |
+
SequenceStatus)
|
| 18 |
+
from vllm.transformers_utils.detokenizer import Detokenizer
|
| 19 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 20 |
+
from vllm.utils import Counter
|
| 21 |
+
|
| 22 |
+
logger = init_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles logic related to
    detokenization and stopping conditions. It specializes to "multi-step
    decoding", where vLLM's worker may generate multiple tokens per invocation.
    This is currently mutually exclusive with advanced sampling techniques like
    beam search, which motivates the separation of this logic from the single
    step output processor.

    This class is responsible for things such as correctly appending all new
    token ids to their sequence, detokenizing new token ids, truncating new
    output tokens after an eos token, and correctly handling the case where the
    number of new output tokens per sequence differs in a single batch.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: StopChecker,
    ):
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Process prompt logprobs associated with each step of a multi-step-
        scheduled computation.

        Args:
          seq_group: the outputs are associated with this :class:`SequenceGroup`
          outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
        """
        for output in outputs:
            # Concatenate single-step prompt logprob processing results.
            assert isinstance(output, CompletionSequenceGroupOutput)
            single_step_process_prompt_logprob(self, seq_group, output)

    @staticmethod
    @functools.lru_cache
    def _log_prompt_logprob_unsupported_warning_once():
        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # If the feature combo become valid
        logger.warning(
            "Prompt logprob is not supported by multi step workers. "
            "(e.g., speculative decode uses multi step workers).")

    def process_outputs(self,
                        sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool = False) -> None:
        """Append new tokens in the outputs to sequences in the sequence group.

        This only supports sequence groups of size 1. It supports greater than
        one new token per sequence.

        This applies logic like stop condition checking and detokenization.
        It also handles cases where there are tokens emitted after
        the EOS token.

        is_async - Indicates whether this postprocessor runs in
            parallel with the GPU forward pass and is processing
            tokens from the previous step. If this is true, then
            no tokens need to be appended since it is already done
            externally (before the next schedule() call)
        """
        # Sequences can be in RUNNING or FINISHED_ABORTED state
        # once scheduled, as a sequence is moved to FINSIHED_ABORTED
        # if a client disconnects from the api server.
        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
        if seqs is None:
            seqs = sequence_group.get_seqs(
                status=SequenceStatus.FINISHED_ABORTED)

        assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = seqs[0]
        seq_id = seq.seq_id
        # This method is defined in the more generic
        # SequenceGroupOutputProcessor, but here we assume that the outputs are
        # of a more specific type.
        assert all([
            isinstance(output, CompletionSequenceGroupOutput)
            for output in outputs
        ])
        compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
        # All samples must belong to the single sequence of this group.
        assert all([
            seq_id == output.samples[0].parent_seq_id
            for output in compl_outputs
        ])

        if is_async:
            # Async case: We process tokens one by one. Here, we know the token
            # was already appended, so we only need to do the rest of the
            # postprocessor: Detokenization + stopping logic
            self._process_decode_and_stop(seq, sequence_group.sampling_params)
        else:
            # Standard multi-step case

            # Since there's only one sequence per sequence group,
            # we can take the first sample.
            samples = [output.samples[0] for output in compl_outputs]

            # entries in sample tokens may be invalid (eg. due to spec decode
            # rejecting tokens).
            valid_samples = [
                sample for sample in samples
                if sample.output_token != VLLM_INVALID_TOKEN_ID
            ]

            # When both spec-decode and pre-fill chunking are enabled, we
            # don't have guaranteed samples here (e.g. all -1s).
            if valid_samples:
                self._process_seq_outputs(seq, valid_samples,
                                          sequence_group.sampling_params)

    def _process_decode_and_stop(self, seq: Sequence,
                                 sampling_params: SamplingParams) -> None:
        # Detokenize the newly appended token(s), then evaluate stop
        # conditions based on the number of new characters produced.
        new_char_count = 0
        if sampling_params.detokenize and self.detokenizer:
            new_char_count = self.detokenizer.decode_sequence_inplace(
                seq, sampling_params)

        # TODO(sang): Support lora.
        self.stop_checker.maybe_stop_sequence(
            seq,
            new_char_count=new_char_count,
            sampling_params=sampling_params,
        )

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        # Append up to len(valid_samples) new tokens, respecting max_tokens
        # and EOS truncation, running decode+stop checks incrementally.
        output_token_ids = [sample.output_token for sample in valid_samples]
        output_logprobs = [sample.logprobs for sample in valid_samples]

        # Truncate to max_tokens if necessary.
        remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
                                                         len(output_token_ids))
        if remaining_tokens < 0:
            output_token_ids = output_token_ids[:remaining_tokens]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # Avoiding .index calls as exception throwing in the happy path
            # is expensive.
            for i in range(len(output_token_ids)):
                if output_token_ids[i] == eos_token_id:
                    output_token_ids = output_token_ids[:i + 1]
                    break

        is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
        # Incrementally append tokens to the sequence, as if we had only one new
        # token.
        for output_token_id, output_logprob in zip(output_token_ids,
                                                   output_logprobs):
            seq.append_token_id(
                token_id=output_token_id,
                logprobs=output_logprob,
            )

            if is_prefill_sampled_token:
                is_prefill_sampled_token = False
            else:
                # Update num_computed_tokens iff the sampled token is not from
                # a prefill step.
                seq.data.update_num_computed_tokens(1)

            self._process_decode_and_stop(seq, sampling_params)

            if seq.is_finished():
                break
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/single_step.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
from vllm.config import SchedulerConfig
|
| 6 |
+
from vllm.core.scheduler import Scheduler
|
| 7 |
+
from vllm.engine.output_processor.interfaces import (
|
| 8 |
+
SequenceGroupOutputProcessor)
|
| 9 |
+
from vllm.engine.output_processor.stop_checker import StopChecker
|
| 10 |
+
from vllm.logger import init_logger
|
| 11 |
+
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup,
|
| 12 |
+
SequenceGroupOutput)
|
| 13 |
+
from vllm.transformers_utils.detokenizer import Detokenizer
|
| 14 |
+
from vllm.utils import Counter
|
| 15 |
+
|
| 16 |
+
logger = init_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def single_step_process_prompt_logprob(
        sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
        output: CompletionSequenceGroupOutput) -> None:
    """Process prompt logprobs associated with the :class:`SequenceGroupOutput`
    for a given step.

    Do nothing if the output has no prompt logprobs.

    Account for the fact that transformers do not compute first-token logprobs.

    Args:
      sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
      seq_group: the output is associated with this :class:`SequenceGroup`
      output: the :class:`SequenceGroupOutput` for a single scheduler step
    """
    prompt_logprobs = output.prompt_logprobs

    # If this is the first (or only) "chunk" of the prefill, we need
    # to prepend None to the list of prompt logprobs. The reason for this
    # is that for N prompt tokens, the Sampler will generate N-1 total
    # prompt logprobs during prefill since the token at idx 0 will not
    # have a logprob associated with it.
    if prompt_logprobs is not None:
        if not seq_group.prompt_logprobs:
            prompt_logprobs = [None] + prompt_logprobs
            seq_group.prompt_logprobs = []

        assert hasattr(sg_output_proc, 'detokenizer')
        if (seq_group.sampling_params.detokenize
                and sg_output_proc.detokenizer):
            sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
                seq_group,
                prompt_logprobs,
                # Offset by what was accumulated in prior prefill chunks.
                position_offset=len(seq_group.prompt_logprobs))

        seq_group.prompt_logprobs.extend(prompt_logprobs)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles "output processing" logic,
    which happens after the model returns generated token ids and before
    scheduling of the next batch. Output processing logic includes
    detokenization, and determining if a sequence is finished (e.g. via max len
    or eos token).

    The SingleStepOutputProcessor is specialized to the case where the model
    emits at most a single token per invocation, which precludes configurations
    such as speculative decoding or multi-step decoding. This enables beam
    search sampling, which requires forking/finishing/freeing sequences in a way
    that is currently difficult to schedule multiple steps ahead of time.
    """

    def __init__(self, scheduler_config: SchedulerConfig,
                 detokenizer: Detokenizer, scheduler: List[Scheduler],
                 seq_counter: Counter, stop_checker: StopChecker):
        self.scheduler_config = scheduler_config
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool) -> None:
        """Append all new tokens to sequences in the sequence group. Fork any
        surviving beam candidates; free any unsurviving ones.

        Invokes detokenizer to detokenize new tokens, and also marks sequences
        as finished if they meet stop conditions.

        is_async - Indicates whether this postprocessor runs in
            parallel with the GPU forward pass and is processing
            tokens from the previous step. If this is true, then
            no tokens need to be appended since it is already done
            externally (before the next schedule() call)
        """
        assert (len(outputs) == 1
                ), f"{type(self)} does not support multiple outputs per step"
        return self._process_sequence_group_outputs(sequence_group, outputs[0],
                                                    is_async)

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Process prompt logprobs associated with one step of a single-step-
        scheduled computation.

        Args:
          seq_group: the output is associated with this :class:`SequenceGroup`
          outputs: the :class:`SequenceGroupOutput` for a single scheduler step
        """
        assert len(outputs) == 1, "Single step should only have 1 output."
        output = outputs[0]
        assert isinstance(output, CompletionSequenceGroupOutput)
        single_step_process_prompt_logprob(self, seq_group, output)

    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput,
                                        is_async: bool) -> None:
        # Append the (single) sampled token, detokenize, evaluate stop
        # conditions, and free the sequence in every scheduler if finished.
        sampling_params = seq_group.sampling_params

        sample = outputs.samples[0]
        seq = seq_group.first_seq
        if not is_async:
            # In the async path the token was already appended externally.
            seq.append_token_id(sample.output_token, sample.logprobs)
        if sampling_params.detokenize and self.detokenizer:
            new_char_count = self.detokenizer.decode_sequence_inplace(
                seq, sampling_params)
        else:
            new_char_count = 0
        self.stop_checker.maybe_stop_sequence(
            seq,
            new_char_count,
            sampling_params,
            lora_req=seq_group.lora_request,
        )
        if seq.is_finished():
            for scheduler in self.scheduler:
                scheduler.free_seq(seq)
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/stop_checker.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Callable, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from vllm.lora.request import LoRARequest
|
| 6 |
+
from vllm.sampling_params import SamplingParams
|
| 7 |
+
from vllm.sequence import Sequence, SequenceStatus
|
| 8 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class StopChecker:
|
| 12 |
+
"""LLMEngine helper class which separates out the logic involving stop
|
| 13 |
+
checking. This checks things such as: whether the eos token was emitted,
|
| 14 |
+
whether the max_tokens has been consumed, whether a stop string has been
|
| 15 |
+
emitted, or if we have exceeded the max model len.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, max_model_len: int,
|
| 19 |
+
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
|
| 20 |
+
# Do not use it directly, but use `self._get_max_model_len`.
|
| 21 |
+
self._max_model_len = max_model_len
|
| 22 |
+
self.get_tokenizer_for_seq = get_tokenizer_for_seq
|
| 23 |
+
|
| 24 |
+
def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
|
| 25 |
+
if lora_req and lora_req.long_lora_max_len:
|
| 26 |
+
return lora_req.long_lora_max_len
|
| 27 |
+
else:
|
| 28 |
+
return self._max_model_len
|
| 29 |
+
|
| 30 |
+
def maybe_stop_sequence(
|
| 31 |
+
self,
|
| 32 |
+
seq: Sequence,
|
| 33 |
+
new_char_count: int,
|
| 34 |
+
sampling_params: SamplingParams,
|
| 35 |
+
lora_req: Optional[LoRARequest] = None,
|
| 36 |
+
) -> None:
|
| 37 |
+
"""Stop the finished sequences.
|
| 38 |
+
|
| 39 |
+
new_char_count is the number of chars added to the
|
| 40 |
+
sequence's output text for the newly generated token
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
# Check if the minimum number of tokens has been generated yet;
|
| 44 |
+
# skip the stop string/token checks if not
|
| 45 |
+
if seq.get_output_len() < sampling_params.min_tokens:
|
| 46 |
+
return
|
| 47 |
+
|
| 48 |
+
# Check if the sequence has generated the EOS token.
|
| 49 |
+
if ((not sampling_params.ignore_eos)
|
| 50 |
+
and seq.get_last_token_id() == seq.eos_token_id):
|
| 51 |
+
# Remove the last EOS token unless explicitly specified
|
| 52 |
+
# This prevents unintended exposure of the EOS token
|
| 53 |
+
if new_char_count and (
|
| 54 |
+
not sampling_params.include_stop_str_in_output):
|
| 55 |
+
seq.output_text = seq.output_text[:-new_char_count]
|
| 56 |
+
seq.status = SequenceStatus.FINISHED_STOPPED
|
| 57 |
+
return
|
| 58 |
+
|
| 59 |
+
# Check if a stop token was encountered.
|
| 60 |
+
# This assumes a single token produced per step.
|
| 61 |
+
last_token_id = seq.get_last_token_id()
|
| 62 |
+
if last_token_id in (sampling_params.stop_token_ids or ()):
|
| 63 |
+
if new_char_count and (
|
| 64 |
+
not sampling_params.include_stop_str_in_output):
|
| 65 |
+
# Remove last token
|
| 66 |
+
seq.output_text = seq.output_text[:-new_char_count]
|
| 67 |
+
seq.status = SequenceStatus.FINISHED_STOPPED
|
| 68 |
+
seq.stop_reason = last_token_id
|
| 69 |
+
return
|
| 70 |
+
|
| 71 |
+
# Check if any stop strings are matched.
|
| 72 |
+
stop = self.check_stop_strings(
|
| 73 |
+
seq.output_text, new_char_count, sampling_params.stop,
|
| 74 |
+
sampling_params.include_stop_str_in_output)
|
| 75 |
+
if stop is not None:
|
| 76 |
+
stop_str, truncate_to = stop
|
| 77 |
+
if truncate_to != -1:
|
| 78 |
+
seq.output_text = seq.output_text[:truncate_to]
|
| 79 |
+
seq.status = SequenceStatus.FINISHED_STOPPED
|
| 80 |
+
seq.stop_reason = stop_str
|
| 81 |
+
return
|
| 82 |
+
|
| 83 |
+
# Check if the sequence has reached max_model_len.
|
| 84 |
+
if seq.get_len() > self._get_max_model_len(lora_req):
|
| 85 |
+
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
| 86 |
+
return
|
| 87 |
+
|
| 88 |
+
# Check if the sequence has reached max_tokens.
|
| 89 |
+
if seq.get_output_len() == sampling_params.max_tokens:
|
| 90 |
+
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
| 91 |
+
return
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def check_stop_strings(
|
| 95 |
+
output_text: str,
|
| 96 |
+
new_char_count: int,
|
| 97 |
+
stop: List[str],
|
| 98 |
+
include_in_output: bool,
|
| 99 |
+
) -> Optional[Tuple[str, int]]:
|
| 100 |
+
"""Check if any stop strings are matched and truncate sequence
|
| 101 |
+
output text accordingly.
|
| 102 |
+
|
| 103 |
+
Returns tuple (stop_string, offset) if matched or else None.
|
| 104 |
+
|
| 105 |
+
Where stop_string is the matched stop string and offset is the
|
| 106 |
+
length to which output_text should be truncated, or -1 for no
|
| 107 |
+
truncation.
|
| 108 |
+
"""
|
| 109 |
+
if not new_char_count or not stop:
|
| 110 |
+
return None
|
| 111 |
+
|
| 112 |
+
for stop_str in stop:
|
| 113 |
+
stop_string_len = len(stop_str)
|
| 114 |
+
# Avoid searching already-searched text.
|
| 115 |
+
stop_index = output_text.find(stop_str,
|
| 116 |
+
-new_char_count - stop_string_len)
|
| 117 |
+
if stop_index == -1:
|
| 118 |
+
continue
|
| 119 |
+
|
| 120 |
+
if include_in_output:
|
| 121 |
+
# Truncate to end of stop string.
|
| 122 |
+
stop_index += stop_string_len
|
| 123 |
+
if stop_index >= len(output_text):
|
| 124 |
+
# No truncation required.
|
| 125 |
+
return stop_str, -1
|
| 126 |
+
|
| 127 |
+
# Truncate the output text to either the beginning
|
| 128 |
+
# or end of the stop string.
|
| 129 |
+
return stop_str, stop_index
|
| 130 |
+
return None
|
.venv/lib/python3.11/site-packages/vllm/engine/output_processor/util.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
from typing import Sequence as GenericSequence
|
| 5 |
+
from typing import cast
|
| 6 |
+
|
| 7 |
+
from vllm.model_executor.layers.sampler import SamplerOutput
|
| 8 |
+
from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_output_by_sequence_group(
|
| 12 |
+
outputs: GenericSequence[SamplerOutput],
|
| 13 |
+
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
|
| 14 |
+
"""Helper method which transforms a 2d list organized by
|
| 15 |
+
[step][sequence group] into [sequence group][step].
|
| 16 |
+
"""
|
| 17 |
+
output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [
|
| 18 |
+
[] for _ in range(num_seq_groups)
|
| 19 |
+
]
|
| 20 |
+
for step in outputs:
|
| 21 |
+
sequence_group_output: CompletionSequenceGroupOutput
|
| 22 |
+
for i, sequence_group_output in enumerate(step):
|
| 23 |
+
output_by_sequence_group[i].append(sequence_group_output)
|
| 24 |
+
|
| 25 |
+
# Cast to the more generic type that CompletionSequenceGroupOutput
|
| 26 |
+
# inherits from.
|
| 27 |
+
return cast(List[List[SequenceGroupOutput]], output_by_sequence_group)
|
.venv/lib/python3.11/site-packages/vllm/engine/protocol.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from typing import AsyncGenerator, List, Mapping, Optional
|
| 6 |
+
|
| 7 |
+
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
| 8 |
+
from vllm.config import DecodingConfig, ModelConfig
|
| 9 |
+
from vllm.core.scheduler import SchedulerOutputs
|
| 10 |
+
from vllm.inputs.data import PromptType, TokensPrompt
|
| 11 |
+
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
|
| 12 |
+
from vllm.inputs.preprocess import InputPreprocessor
|
| 13 |
+
from vllm.logger import init_logger
|
| 14 |
+
from vllm.lora.request import LoRARequest
|
| 15 |
+
from vllm.model_executor.layers.sampler import SamplerOutput
|
| 16 |
+
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
| 17 |
+
from vllm.pooling_params import PoolingParams
|
| 18 |
+
from vllm.prompt_adapter.request import PromptAdapterRequest
|
| 19 |
+
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
| 20 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 21 |
+
from vllm.utils import collect_from_async_generator, random_uuid
|
| 22 |
+
|
| 23 |
+
logger = init_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class EngineClient(ABC):
|
| 27 |
+
"""Protocol class for Clients to Engine"""
|
| 28 |
+
|
| 29 |
+
@property
|
| 30 |
+
@abstractmethod
|
| 31 |
+
def is_running(self) -> bool:
|
| 32 |
+
...
|
| 33 |
+
|
| 34 |
+
@property
|
| 35 |
+
@abstractmethod
|
| 36 |
+
def is_stopped(self) -> bool:
|
| 37 |
+
...
|
| 38 |
+
|
| 39 |
+
@property
|
| 40 |
+
@abstractmethod
|
| 41 |
+
def errored(self) -> bool:
|
| 42 |
+
...
|
| 43 |
+
|
| 44 |
+
@property
|
| 45 |
+
@abstractmethod
|
| 46 |
+
def dead_error(self) -> BaseException:
|
| 47 |
+
...
|
| 48 |
+
|
| 49 |
+
@abstractmethod
|
| 50 |
+
def generate(
|
| 51 |
+
self,
|
| 52 |
+
prompt: PromptType,
|
| 53 |
+
sampling_params: SamplingParams,
|
| 54 |
+
request_id: str,
|
| 55 |
+
lora_request: Optional[LoRARequest] = None,
|
| 56 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 57 |
+
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
| 58 |
+
priority: int = 0,
|
| 59 |
+
) -> AsyncGenerator[RequestOutput, None]:
|
| 60 |
+
"""Generate outputs for a request."""
|
| 61 |
+
...
|
| 62 |
+
|
| 63 |
+
async def beam_search(
|
| 64 |
+
self,
|
| 65 |
+
prompt: PromptType,
|
| 66 |
+
request_id: str,
|
| 67 |
+
params: BeamSearchParams,
|
| 68 |
+
) -> AsyncGenerator[RequestOutput, None]:
|
| 69 |
+
|
| 70 |
+
beam_width = params.beam_width
|
| 71 |
+
max_tokens = params.max_tokens
|
| 72 |
+
ignore_eos = params.ignore_eos
|
| 73 |
+
temperature = params.temperature
|
| 74 |
+
length_penalty = params.length_penalty
|
| 75 |
+
include_stop_str_in_output = params.include_stop_str_in_output
|
| 76 |
+
|
| 77 |
+
preprocessor = await self.get_input_preprocessor()
|
| 78 |
+
tokenizer_group = preprocessor.get_tokenizer_group()
|
| 79 |
+
tokenizer = await tokenizer_group.get_lora_tokenizer_async()
|
| 80 |
+
|
| 81 |
+
if is_explicit_encoder_decoder_prompt(prompt):
|
| 82 |
+
raise NotImplementedError
|
| 83 |
+
else:
|
| 84 |
+
processed_inputs = preprocessor._prompt_to_llm_inputs(
|
| 85 |
+
prompt,
|
| 86 |
+
request_id=request_id,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
prompt_token_ids = processed_inputs["prompt_token_ids"]
|
| 90 |
+
prompt_text = processed_inputs.get("prompt")
|
| 91 |
+
multi_modal_data = processed_inputs.get("multi_modal_data")
|
| 92 |
+
mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
|
| 93 |
+
|
| 94 |
+
tokenized_length = len(prompt_token_ids)
|
| 95 |
+
|
| 96 |
+
sort_beams_key = create_sort_beams_key_function(
|
| 97 |
+
tokenizer.eos_token_id, length_penalty)
|
| 98 |
+
|
| 99 |
+
beam_search_params = SamplingParams(
|
| 100 |
+
logprobs=2 * beam_width,
|
| 101 |
+
max_tokens=1,
|
| 102 |
+
temperature=temperature,
|
| 103 |
+
)
|
| 104 |
+
all_beams = [
|
| 105 |
+
BeamSearchSequence(tokens=prompt_token_ids,
|
| 106 |
+
cum_logprob=0,
|
| 107 |
+
logprobs=[],
|
| 108 |
+
multi_modal_data=multi_modal_data,
|
| 109 |
+
mm_processor_kwargs=mm_processor_kwargs)
|
| 110 |
+
]
|
| 111 |
+
completed = []
|
| 112 |
+
|
| 113 |
+
for _ in range(max_tokens):
|
| 114 |
+
prompts_batch = [
|
| 115 |
+
TokensPrompt(prompt_token_ids=beam.tokens,
|
| 116 |
+
multi_modal_data=beam.multi_modal_data,
|
| 117 |
+
mm_processor_kwargs=beam.mm_processor_kwargs)
|
| 118 |
+
for beam in all_beams
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
tasks = []
|
| 122 |
+
|
| 123 |
+
request_id = f"beam_search-{random_uuid()}"
|
| 124 |
+
for i, individual_prompt in enumerate(prompts_batch):
|
| 125 |
+
request_id_item = f"{request_id}-{i}"
|
| 126 |
+
task = asyncio.create_task(
|
| 127 |
+
collect_from_async_generator(
|
| 128 |
+
self.generate(individual_prompt, beam_search_params,
|
| 129 |
+
request_id_item)))
|
| 130 |
+
tasks.append(task)
|
| 131 |
+
|
| 132 |
+
output = await asyncio.gather(*tasks)
|
| 133 |
+
|
| 134 |
+
output = [x[0] for x in output]
|
| 135 |
+
|
| 136 |
+
new_beams = []
|
| 137 |
+
for i, current_beam in enumerate(all_beams):
|
| 138 |
+
result = output[i]
|
| 139 |
+
|
| 140 |
+
if result.outputs[0].logprobs is not None:
|
| 141 |
+
logprobs = result.outputs[0].logprobs[0]
|
| 142 |
+
for token_id, logprob_obj in logprobs.items():
|
| 143 |
+
if token_id == tokenizer.eos_token_id and \
|
| 144 |
+
not ignore_eos:
|
| 145 |
+
completed.append(
|
| 146 |
+
BeamSearchSequence(
|
| 147 |
+
tokens=current_beam.tokens +
|
| 148 |
+
[token_id] if include_stop_str_in_output
|
| 149 |
+
else current_beam.tokens,
|
| 150 |
+
logprobs=current_beam.logprobs +
|
| 151 |
+
[logprobs],
|
| 152 |
+
cum_logprob=current_beam.cum_logprob +
|
| 153 |
+
logprob_obj.logprob,
|
| 154 |
+
finish_reason="stop",
|
| 155 |
+
stop_reason=tokenizer.eos_token_id))
|
| 156 |
+
else:
|
| 157 |
+
new_beams.append(
|
| 158 |
+
BeamSearchSequence(
|
| 159 |
+
tokens=current_beam.tokens + [token_id],
|
| 160 |
+
logprobs=current_beam.logprobs +
|
| 161 |
+
[logprobs],
|
| 162 |
+
cum_logprob=current_beam.cum_logprob +
|
| 163 |
+
logprob_obj.logprob,
|
| 164 |
+
multi_modal_data=current_beam.
|
| 165 |
+
multi_modal_data,
|
| 166 |
+
mm_processor_kwargs=current_beam.
|
| 167 |
+
mm_processor_kwargs))
|
| 168 |
+
|
| 169 |
+
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
|
| 170 |
+
all_beams = sorted_beams[:beam_width]
|
| 171 |
+
|
| 172 |
+
completed.extend(all_beams)
|
| 173 |
+
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
|
| 174 |
+
best_beams = sorted_completed[:beam_width]
|
| 175 |
+
|
| 176 |
+
for beam in best_beams:
|
| 177 |
+
if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos):
|
| 178 |
+
# Skip the eos token in the text.
|
| 179 |
+
tokens = beam.tokens[tokenized_length:-1]
|
| 180 |
+
else:
|
| 181 |
+
tokens = beam.tokens[tokenized_length:]
|
| 182 |
+
beam.text = tokenizer.decode(tokens)
|
| 183 |
+
|
| 184 |
+
beam_search_output = RequestOutput(
|
| 185 |
+
request_id=request_id,
|
| 186 |
+
prompt=prompt_text,
|
| 187 |
+
outputs=[
|
| 188 |
+
CompletionOutput(text=beam.text,
|
| 189 |
+
cumulative_logprob=beam.cum_logprob,
|
| 190 |
+
token_ids=beam.tokens[tokenized_length:],
|
| 191 |
+
index=i,
|
| 192 |
+
logprobs=beam.logprobs,
|
| 193 |
+
finish_reason=beam.finish_reason if
|
| 194 |
+
beam.finish_reason is not None else "length",
|
| 195 |
+
stop_reason=beam.stop_reason)
|
| 196 |
+
for (i, beam) in enumerate(best_beams)
|
| 197 |
+
],
|
| 198 |
+
finished=True,
|
| 199 |
+
prompt_token_ids=prompt_token_ids,
|
| 200 |
+
prompt_logprobs=None)
|
| 201 |
+
|
| 202 |
+
yield beam_search_output
|
| 203 |
+
|
| 204 |
+
@abstractmethod
|
| 205 |
+
def encode(
|
| 206 |
+
self,
|
| 207 |
+
prompt: PromptType,
|
| 208 |
+
pooling_params: PoolingParams,
|
| 209 |
+
request_id: str,
|
| 210 |
+
lora_request: Optional[LoRARequest] = None,
|
| 211 |
+
trace_headers: Optional[Mapping[str, str]] = None,
|
| 212 |
+
priority: int = 0,
|
| 213 |
+
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
| 214 |
+
"""Generate outputs for a request from a pooling model."""
|
| 215 |
+
...
|
| 216 |
+
|
| 217 |
+
@abstractmethod
|
| 218 |
+
async def abort(self, request_id: str) -> None:
|
| 219 |
+
"""Abort a request.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
request_id: The unique id of the request.
|
| 223 |
+
"""
|
| 224 |
+
...
|
| 225 |
+
|
| 226 |
+
@abstractmethod
|
| 227 |
+
async def get_model_config(self) -> ModelConfig:
|
| 228 |
+
"""Get the model configuration of the vLLM engine."""
|
| 229 |
+
...
|
| 230 |
+
|
| 231 |
+
@abstractmethod
|
| 232 |
+
async def get_decoding_config(self) -> DecodingConfig:
|
| 233 |
+
"""Get the decoding configuration of the vLLM engine."""
|
| 234 |
+
...
|
| 235 |
+
|
| 236 |
+
@abstractmethod
|
| 237 |
+
async def get_input_preprocessor(self) -> InputPreprocessor:
|
| 238 |
+
"""Get the input processor of the vLLM engine."""
|
| 239 |
+
...
|
| 240 |
+
|
| 241 |
+
@abstractmethod
|
| 242 |
+
async def get_tokenizer(
|
| 243 |
+
self,
|
| 244 |
+
lora_request: Optional[LoRARequest] = None,
|
| 245 |
+
) -> AnyTokenizer:
|
| 246 |
+
"""Get the appropriate tokenizer for the request"""
|
| 247 |
+
...
|
| 248 |
+
|
| 249 |
+
@abstractmethod
|
| 250 |
+
async def is_tracing_enabled(self) -> bool:
|
| 251 |
+
...
|
| 252 |
+
|
| 253 |
+
@abstractmethod
|
| 254 |
+
async def do_log_stats(
|
| 255 |
+
self,
|
| 256 |
+
scheduler_outputs: Optional[SchedulerOutputs] = None,
|
| 257 |
+
model_output: Optional[List[SamplerOutput]] = None,
|
| 258 |
+
) -> None:
|
| 259 |
+
...
|
| 260 |
+
|
| 261 |
+
@abstractmethod
|
| 262 |
+
async def check_health(self) -> None:
|
| 263 |
+
"""Raise if unhealthy"""
|
| 264 |
+
...
|
| 265 |
+
|
| 266 |
+
@abstractmethod
|
| 267 |
+
async def start_profile(self) -> None:
|
| 268 |
+
"""Start profiling the engine"""
|
| 269 |
+
...
|
| 270 |
+
|
| 271 |
+
@abstractmethod
|
| 272 |
+
async def stop_profile(self) -> None:
|
| 273 |
+
"""Start profiling the engine"""
|
| 274 |
+
...
|
| 275 |
+
|
| 276 |
+
@abstractmethod
|
| 277 |
+
async def reset_prefix_cache(self) -> None:
|
| 278 |
+
"""Reset the prefix cache"""
|
| 279 |
+
...
|
| 280 |
+
|
| 281 |
+
@abstractmethod
|
| 282 |
+
async def add_lora(self, lora_request: LoRARequest) -> None:
|
| 283 |
+
"""Load a new LoRA adapter into the engine for future requests."""
|
| 284 |
+
...
|
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
|
| 4 |
+
from vllm.transformers_utils.configs.cohere2 import Cohere2Config
|
| 5 |
+
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
| 6 |
+
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
|
| 7 |
+
from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
| 8 |
+
from vllm.transformers_utils.configs.exaone import ExaoneConfig
|
| 9 |
+
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
|
| 10 |
+
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
| 11 |
+
# `FalconConfig` class from the official HuggingFace transformers library.
|
| 12 |
+
from vllm.transformers_utils.configs.falcon import RWConfig
|
| 13 |
+
from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig
|
| 14 |
+
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
|
| 15 |
+
from vllm.transformers_utils.configs.jais import JAISConfig
|
| 16 |
+
from vllm.transformers_utils.configs.medusa import MedusaConfig
|
| 17 |
+
from vllm.transformers_utils.configs.mllama import MllamaConfig
|
| 18 |
+
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
|
| 19 |
+
from vllm.transformers_utils.configs.mpt import MPTConfig
|
| 20 |
+
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
| 21 |
+
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
|
| 22 |
+
from vllm.transformers_utils.configs.olmo2 import Olmo2Config
|
| 23 |
+
from vllm.transformers_utils.configs.solar import SolarConfig
|
| 24 |
+
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
|
| 25 |
+
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
| 26 |
+
|
| 27 |
+
__all__ = [
|
| 28 |
+
"ChatGLMConfig",
|
| 29 |
+
"Cohere2Config",
|
| 30 |
+
"DbrxConfig",
|
| 31 |
+
"DeepseekVLV2Config",
|
| 32 |
+
"MPTConfig",
|
| 33 |
+
"RWConfig",
|
| 34 |
+
"H2OVLChatConfig",
|
| 35 |
+
"InternVLChatConfig",
|
| 36 |
+
"JAISConfig",
|
| 37 |
+
"MedusaConfig",
|
| 38 |
+
"EAGLEConfig",
|
| 39 |
+
"ExaoneConfig",
|
| 40 |
+
"MllamaConfig",
|
| 41 |
+
"MLPSpeculatorConfig",
|
| 42 |
+
"NemotronConfig",
|
| 43 |
+
"NVLM_D_Config",
|
| 44 |
+
"Olmo2Config",
|
| 45 |
+
"SolarConfig",
|
| 46 |
+
"Telechat2Config",
|
| 47 |
+
"UltravoxConfig",
|
| 48 |
+
]
|
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.19 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/arctic.cpython-311.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/chatglm.cpython-311.pyc
ADDED
|
Binary file (2.47 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/cohere2.cpython-311.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/dbrx.cpython-311.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/transformers_utils/configs/__pycache__/deepseek_vl2.cpython-311.pyc
ADDED
|
Binary file (8.29 kB). View file
|
|
|